EpochBasedRunner Reading Notes
This class has many methods; here I only record the ones I have used so far.
I. First, the initialization method used to instantiate this class:
"""
Args:
model (:obj:`torch.nn.Module`): The model to be run.
batch_processor (callable): A callable method that process a data
batch. The interface of this method should be
`batch_processor(model, data, train_mode) -> dict`
optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
optimizer (in most cases) or a dict of optimizers (in models that
requires more than one optimizer, e.g., GAN).
work_dir (str, optional): The working directory to save checkpoints
and logs. Defaults to None.
logger (:obj:`logging.Logger`): Logger used during training.
Defaults to None. (The default value is just for backward
compatibility)
meta (dict | None): A dict records some import information such as
environment info and seed, which will be logged in logger hook.
Defaults to None.
max_epochs (int, optional): Total training epochs.
max_iters (int, optional): Total training iterations.
"""

def __init__(self,
             model,
             batch_processor=None,
             optimizer=None,
             work_dir=None,
             logger=None,
             meta=None,
             max_iters=None,
             max_epochs=None):
    if batch_processor is not None:
        if not callable(batch_processor):
            raise TypeError('batch_processor must be callable, '
                            f'but got {type(batch_processor)}')
        warnings.warn('batch_processor is deprecated, please implement '
                      'train_step() and val_step() in the model instead.')
        # raise an error if `batch_processor` is not None and
        # `model.train_step()` exists.
        if is_module_wrapper(model):
            _model = model.module
        else:
            _model = model
        if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
            raise RuntimeError(
                'batch_processor and model.train_step()/model.val_step() '
                'cannot be both available.')
    else:
        assert hasattr(model, 'train_step')

    # check the type of `optimizer`
    if isinstance(optimizer, dict):
        for name, optim in optimizer.items():
            if not isinstance(optim, Optimizer):
                raise TypeError(
                    f'optimizer must be a dict of torch.optim.Optimizers, '
                    f'but optimizer["{name}"] is a {type(optim)}')
    elif not isinstance(optimizer, Optimizer) and optimizer is not None:
        raise TypeError(
            f'optimizer must be a torch.optim.Optimizer object '
            f'or dict or None, but got {type(optimizer)}')

    # check the type of `logger`
    if not isinstance(logger, logging.Logger):
        raise TypeError(f'logger must be a logging.Logger object, '
                        f'but got {type(logger)}')

    # check the type of `meta`
    if meta is not None and not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')

    self.model = model
    self.batch_processor = batch_processor
    self.optimizer = optimizer
    self.logger = logger
    self.meta = meta
    # create work_dir
    if mmcv.is_str(work_dir):
        self.work_dir = osp.abspath(work_dir)
        mmcv.mkdir_or_exist(self.work_dir)
    elif work_dir is None:
        self.work_dir = None
    else:
        raise TypeError('"work_dir" must be a str or None')

    # get model name from the model class
    if hasattr(self.model, 'module'):
        self._model_name = self.model.module.__class__.__name__
    else:
        self._model_name = self.model.__class__.__name__

    self._rank, self._world_size = get_dist_info()
    self.timestamp = get_time_str()
    self.mode = None
    self._hooks = []
    self._epoch = 0
    self._iter = 0
    self._inner_iter = 0

    if max_epochs is not None and max_iters is not None:
        raise ValueError(
            'Only one of `max_epochs` or `max_iters` can be set.')

    self._max_epochs = max_epochs
    self._max_iters = max_iters
    # TODO: Redesign LogBuffer, it is not flexible and elegant enough
    self.log_buffer = LogBuffer()
The parameters passed when instantiating BaseRunner are documented in the docstring above. A few points deserve attention:
- model must implement a train_step() method;
- batch_processor defaults to None, and in my own usage it has always stayed None, so it is not considered here;
- optimizer may be a dict, i.e., BaseRunner allows multiple optimizers to optimize different parts of the network;
- max_epochs and max_iters cannot both be set; provide exactly one of them.
Now let's look at what the initialization method does:
- checks that the passed-in arguments are valid;
- stores the relevant arguments as attributes of the BaseRunner instance;
- creates the working directory and collects device/distributed information;
- creates attributes used to track training state (epoch, iteration, etc.).
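The argument checks above can be reproduced in a minimal, mmcv-free sketch. Here check_runner_args is a hypothetical helper, and the torch.optim.Optimizer type check is stubbed out with a None test:

```python
def check_runner_args(optimizer=None, max_epochs=None, max_iters=None):
    """Hypothetical sketch of BaseRunner.__init__'s argument checks."""
    if isinstance(optimizer, dict):
        # BaseRunner accepts a dict of optimizers, e.g. for GANs
        for name, optim in optimizer.items():
            if optim is None:  # stand-in for `not isinstance(optim, Optimizer)`
                raise TypeError(f'optimizer["{name}"] must be an Optimizer')
    # max_epochs and max_iters are mutually exclusive
    if max_epochs is not None and max_iters is not None:
        raise ValueError('Only one of `max_epochs` or `max_iters` can be set.')
    return True
```

For example, check_runner_args(max_epochs=12) passes, while supplying both max_epochs and max_iters raises ValueError, just like the real constructor.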
II. Registering Training Hooks
Normally, after instantiating a BaseRunner we call register_training_hooks(), which in turn invokes BaseRunner's seven hook-registration methods. You can look up how hooks work in mmdetection elsewhere; to me, their mechanism resembles interrupts on a microcontroller.
def register_training_hooks(self,
                            lr_config,
                            optimizer_config=None,
                            checkpoint_config=None,
                            log_config=None,
                            momentum_config=None,
                            timer_config=dict(type='IterTimerHook'),
                            custom_hooks_config=None):
    """Register default and custom hooks for training.

    Default and custom hooks include:

    +----------------------+-------------------------+
    | Hooks                | Priority                |
    +======================+=========================+
    | LrUpdaterHook        | VERY_HIGH (10)          |
    +----------------------+-------------------------+
    | MomentumUpdaterHook  | HIGH (30)               |
    +----------------------+-------------------------+
    | OptimizerStepperHook | ABOVE_NORMAL (40)       |
    +----------------------+-------------------------+
    | CheckpointSaverHook  | NORMAL (50)             |
    +----------------------+-------------------------+
    | IterTimerHook        | LOW (70)                |
    +----------------------+-------------------------+
    | LoggerHook(s)        | VERY_LOW (90)           |
    +----------------------+-------------------------+
    | CustomHook(s)        | defaults to NORMAL (50) |
    +----------------------+-------------------------+

    If custom hooks have same priority with default hooks, custom hooks
    will be triggered after default hooks.
    """
    self.register_lr_hook(lr_config)
    self.register_momentum_hook(momentum_config)
    self.register_optimizer_hook(optimizer_config)
    self.register_checkpoint_hook(checkpoint_config)
    self.register_timer_hook(timer_config)
    self.register_logger_hooks(log_config)
    self.register_custom_hooks(custom_hooks_config)
The arguments passed here are all dicts defined in the config files. For example, schedule_1x.py defines:
lr_config = dict(policy='step', warmup='linear', warmup_iters=500, warmup_ratio=0.001, step=[16, 21])
What each key-value pair does will be explained when we encounter it.
Below, the registration process for each type of hook is introduced in turn.
1. register_lr_hook
def register_lr_hook(self, lr_config):
    if lr_config is None:
        return
    elif isinstance(lr_config, dict):
        assert 'policy' in lr_config
        policy_type = lr_config.pop('policy')
        if policy_type == policy_type.lower():
            policy_type = policy_type.title()
        hook_type = policy_type + 'LrUpdaterHook'
        lr_config['type'] = hook_type
        hook = mmcv.build_from_cfg(lr_config, HOOKS)
    else:
        hook = lr_config
    self.register_hook(hook, priority='VERY_HIGH')
This method first inspects its argument: if nothing was passed, no lr-management hook is registered; if a dict was passed, a hook is registered based on the information in the dict (i.e., an object of the corresponding class is instantiated from the dict; see mmcv.build_from_cfg for the details); if an argument was passed but it is not a dict, it is assumed to already be a constructed Hook.
With the config used in my experiments, lr_config = dict(policy='step', warmup='linear', warmup_iters=500, warmup_ratio=0.001, step=[16, 21]), the dict's items are used as arguments to instantiate a StepLrUpdaterHook object; the StepLrUpdaterHook class lives in mmcv/runner/hooks/lr_updater.py.
Finally, self.register_hook is called to store the newly registered hook in the class's _hooks attribute for later use.
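The conversion from a policy string to a hook class name can be checked in isolation. This is a pure-Python sketch of the lines above (lr_policy_to_hook_type is a hypothetical name):

```python
def lr_policy_to_hook_type(policy_type):
    """Mirror register_lr_hook's policy -> class-name conversion:
    an all-lowercase policy gets its first letter capitalized."""
    if policy_type == policy_type.lower():
        policy_type = policy_type.title()
    return policy_type + 'LrUpdaterHook'
```

For example, lr_policy_to_hook_type('step') returns 'StepLrUpdaterHook', which is the class name looked up in the HOOKS registry; an already-capitalized policy such as 'CosineAnnealing' is used as-is.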
def register_hook(self, hook, priority='NORMAL'):
    """Register a hook into the hook list.

    The hook will be inserted into a priority queue, with the specified
    priority (See :class:`Priority` for details of priorities).
    For hooks with the same priority, they will be triggered in the same
    order as they are registered.

    Args:
        hook (:obj:`Hook`): The hook to be registered.
        priority (int or str or :obj:`Priority`): Hook priority.
            Lower value means higher priority.
    """
    assert isinstance(hook, Hook)
    if hasattr(hook, 'priority'):
        raise ValueError('"priority" is a reserved attribute for hooks')
    priority = get_priority(priority)
    hook.priority = priority
    # insert the hook to a sorted list
    inserted = False
    for i in range(len(self._hooks) - 1, -1, -1):
        if priority >= self._hooks[i].priority:
            self._hooks.insert(i + 1, hook)
            inserted = True
            break
    if not inserted:
        self._hooks.insert(0, hook)
A quick word on register_hook: the hook parameter is the Hook to store, and priority is its priority; doesn't this look more and more like microcontroller interrupts? Here the priority simply determines where the hook is placed within self._hooks: hooks closer to the front of the list are executed earlier during training.
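The insertion logic can be demonstrated standalone. FakeHook and the numeric priorities below are illustrative; in mmcv, get_priority maps names like 'VERY_HIGH' to such numbers:

```python
class FakeHook:
    """Stand-in for a real Hook; only carries a name and a priority."""
    def __init__(self, name):
        self.name = name

def register_hook(hooks, hook, priority):
    """Insert `hook` so that `hooks` stays sorted by ascending priority
    value (lower value = higher priority); equal-priority hooks keep
    registration order. Mirrors BaseRunner.register_hook."""
    hook.priority = priority
    for i in range(len(hooks) - 1, -1, -1):
        if priority >= hooks[i].priority:
            hooks.insert(i + 1, hook)
            return
    hooks.insert(0, hook)
```

Registering an lr hook (10), a logger hook (90), and an optimizer hook (40) in that order leaves them sorted as lr, optimizer, logger, matching the priority table above.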
2. register_momentum_hook
Same as register_lr_hook.
def register_momentum_hook(self, momentum_config):
    if momentum_config is None:
        return
    if isinstance(momentum_config, dict):
        assert 'policy' in momentum_config
        policy_type = momentum_config.pop('policy')
        if policy_type == policy_type.lower():
            policy_type = policy_type.title()
        hook_type = policy_type + 'MomentumUpdaterHook'
        momentum_config['type'] = hook_type
        hook = mmcv.build_from_cfg(momentum_config, HOOKS)
    else:
        hook = momentum_config
    self.register_hook(hook, priority='HIGH')
3. register_optimizer_hook
def register_optimizer_hook(self, optimizer_config):
    if optimizer_config is None:
        return
    if isinstance(optimizer_config, dict):
        optimizer_config.setdefault('type', 'OptimizerHook')
        hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
    else:
        hook = optimizer_config
    self.register_hook(hook, priority='ABOVE_NORMAL')
4. register_checkpoint_hook
def register_checkpoint_hook(self, checkpoint_config):
    if checkpoint_config is None:
        return
    if isinstance(checkpoint_config, dict):
        checkpoint_config.setdefault('type', 'CheckpointHook')
        hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
    else:
        hook = checkpoint_config
    self.register_hook(hook, priority='NORMAL')
5. register_timer_hook
def register_timer_hook(self, timer_config):
    if timer_config is None:
        return
    if isinstance(timer_config, dict):
        timer_config_ = copy.deepcopy(timer_config)
        hook = mmcv.build_from_cfg(timer_config_, HOOKS)
    else:
        hook = timer_config
    self.register_hook(hook, priority='LOW')
6. register_logger_hooks
def register_logger_hooks(self, log_config):
    if log_config is None:
        return
    log_interval = log_config['interval']
    for info in log_config['hooks']:
        logger_hook = mmcv.build_from_cfg(
            info, HOOKS, default_args=dict(interval=log_interval))
        self.register_hook(logger_hook, priority='VERY_LOW')
7. register_custom_hooks
def register_custom_hooks(self, custom_config):
    if custom_config is None:
        return
    if not isinstance(custom_config, list):
        custom_config = [custom_config]
    for item in custom_config:
        if isinstance(item, dict):
            self.register_hook_from_cfg(item)
        else:
            self.register_hook(item, priority='NORMAL')
In other words, every hook is a class with its own functionality; what they all have in common is that they provide the before_run, after_run, before_epoch, after_epoch, before_iter, and after_iter methods.
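A minimal sketch of that shared interface (not the real mmcv Hook class, just its shape): every lifecycle method is a no-op, and a concrete hook overrides only what it needs.

```python
class Hook:
    """Sketch of the common hook interface: each lifecycle method
    is a no-op until a subclass overrides it."""
    def before_run(self, runner): pass
    def after_run(self, runner): pass
    def before_epoch(self, runner): pass
    def after_epoch(self, runner): pass
    def before_iter(self, runner): pass
    def after_iter(self, runner): pass

class IterCounterHook(Hook):
    """Example hook (hypothetical) that only cares about after_iter."""
    def __init__(self):
        self.num_iters = 0
    def after_iter(self, runner):
        self.num_iters += 1
```

Because the base class defines every method, the runner can blindly call any lifecycle method on any registered hook.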
III. Training the Model
def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
    """Start running.

    Args:
        data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
            and validation.
        workflow (list[tuple]): A list of (phase, epochs) to specify the
            running order and epochs. E.g, [('train', 2), ('val', 1)] means
            running 2 epochs for training and 1 epoch for validation,
            iteratively.
    """
    assert isinstance(data_loaders, list)
    assert mmcv.is_list_of(workflow, tuple)
    assert len(data_loaders) == len(workflow)
    if max_epochs is not None:
        warnings.warn(
            'setting max_epochs in run is deprecated, '
            'please set max_epochs in runner_config', DeprecationWarning)
        self._max_epochs = max_epochs

    assert self._max_epochs is not None, (
        'max_epochs must be specified during instantiation')

    for i, flow in enumerate(workflow):
        mode, epochs = flow
        if mode == 'train':
            self._max_iters = self._max_epochs * len(data_loaders[i])
            break

    work_dir = self.work_dir if self.work_dir is not None else 'NONE'
    self.logger.info('Start running, host: %s, work_dir: %s',
                     get_host_info(), work_dir)
    self.logger.info('Hooks will be executed in the following order:\n%s',
                     self.get_hook_info())
    self.logger.info('workflow: %s, max: %d epochs', workflow,
                     self._max_epochs)
    self.call_hook('before_run')

    while self.epoch < self._max_epochs:
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if isinstance(mode, str):  # self.train()
                if not hasattr(self, mode):
                    raise ValueError(
                        f'runner has no method named "{mode}" to run an '
                        'epoch')
                epoch_runner = getattr(self, mode)
            else:
                raise TypeError(
                    'mode in workflow must be a str, but got {}'.format(
                        type(mode)))

            for _ in range(epochs):
                if mode == 'train' and self.epoch >= self._max_epochs:
                    break
                epoch_runner(data_loaders[i], **kwargs)

    time.sleep(1)  # wait for some hooks like loggers to finish
    self.call_hook('after_run')
- The parameters are as documented in the docstring;
- on entry, the method first checks the format of its arguments;
- then self._max_iters is derived from self._max_epochs (max_epochs times the length of the training dataloader);
- some startup information is logged;
- call_hook is invoked with 'before_run', which walks over the hooks we registered earlier and executes every hook's before_run method;
- the training loop starts; based on the workflow it decides which method to run for the current phase, self.train or self.val, via epoch_runner = getattr(self, mode), where mode is 'train' or 'val';
- finally, self.call_hook('after_run') is executed.
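The behavior of run()'s outer loop is easier to see in a stripped-down sketch. run_workflow is a hypothetical stand-in that only records which phase runs in what order (it assumes the workflow contains a 'train' phase, otherwise the real loop would also never terminate):

```python
def run_workflow(workflow, max_epochs):
    """Sketch of run()'s outer loop: cycle through (mode, epochs)
    pairs until max_epochs training epochs are done."""
    epoch = 0
    trace = []
    while epoch < max_epochs:
        for mode, epochs in workflow:
            for _ in range(epochs):
                # only training epochs count against max_epochs
                if mode == 'train' and epoch >= max_epochs:
                    break
                trace.append(mode)
                if mode == 'train':
                    epoch += 1
    return trace
```

With workflow = [('train', 2), ('val', 1)] and max_epochs = 3, the phases run as train, train, val, train, val: the final 'train' repetition is skipped once max_epochs is reached, but the trailing validation still runs.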
To expand on points 5 and 7 above: what calling self.call_hook('before_run') and self.call_hook('after_run') actually does depends entirely on the hooks we just registered and on how each hook implements those methods.
def call_hook(self, fn_name):
    for hook in self._hooks:
        getattr(hook, fn_name)(self)
call_hook iterates over all the registered hooks and calls the named method on each of them. For instance, self.call_hook('before_run') invokes the before_run method of every hook.
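Since call_hook is just a getattr dispatch, it is easy to replay with toy hooks (RecorderHook and this standalone call_hook are illustrative, not mmcv code):

```python
class RecorderHook:
    """Toy hook that records which lifecycle methods were invoked."""
    def __init__(self):
        self.calls = []
    def before_run(self, runner):
        self.calls.append('before_run')
    def after_run(self, runner):
        self.calls.append('after_run')

def call_hook(hooks, fn_name, runner=None):
    """Same shape as BaseRunner.call_hook: look up fn_name on every
    hook and call it with the runner."""
    for hook in hooks:
        getattr(hook, fn_name)(runner)
```

Calling call_hook([h1, h2], 'before_run') appends 'before_run' to both recorders, in registration order.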
Finally, there are three more methods: train() and val(), which run() (point 6 above) may dispatch to, plus run_iter(), which both of them call.
def train(self, data_loader, **kwargs):
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._max_iters = self._max_epochs * len(self.data_loader)
    self.call_hook('before_train_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_train_iter')
        self.run_iter(data_batch, train_mode=True, **kwargs)
        self.call_hook('after_train_iter')
        self._iter += 1

    self.call_hook('after_train_epoch')
    self._epoch += 1
@torch.no_grad()
def val(self, data_loader, **kwargs):
    self.model.eval()
    self.mode = 'val'
    self.data_loader = data_loader
    self.call_hook('before_val_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_val_iter')
        self.run_iter(data_batch, train_mode=False)
        self.call_hook('after_val_iter')

    self.call_hook('after_val_epoch')
def run_iter(self, data_batch, train_mode, **kwargs):
    if self.batch_processor is not None:
        outputs = self.batch_processor(
            self.model, data_batch, train_mode=train_mode, **kwargs)
    elif train_mode:
        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
    else:
        outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
    if not isinstance(outputs, dict):
        raise TypeError('"batch_processor()" or "model.train_step()" '
                        'and "model.val_step()" must return a dict')
    if 'log_vars' in outputs:
        self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
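The only contract run_iter imposes is that train_step()/val_step() return a dict, optionally carrying 'log_vars' and 'num_samples' for the log buffer. A toy model satisfying that contract (everything inside is illustrative, including the placeholder loss):

```python
class ToyModel:
    """Minimal object satisfying run_iter's train_step contract."""
    def train_step(self, data_batch, optimizer, **kwargs):
        loss = float(len(data_batch))  # placeholder "loss" value
        return dict(
            loss=loss,
            log_vars=dict(loss=loss),
            num_samples=len(data_batch))
```

Any model exposing train_step with this return shape can be driven by EpochBasedRunner without a batch_processor.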