Program Entry Point
A complete network training process generally consists of several parts:
- build dataset and data loader
- build model
- design loss
- build optimizer
- workflow: train, validate, checkpoint and log iteratively
The first four steps are implemented in the two functions main and train_detector, all via the Registry factory pattern (see here1 and here2), with each component built from the cfg; a minimal sketch of the pattern follows the list below. (Note that in mmdetection, designing the loss is implemented as part of building the model.)
- main is defined in tools/train.py and is the entry point of the whole program
- train_detector is defined in mmdet/apis/train.py
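As a reminder of how the Registry factory pattern works, here is a minimal, self-contained sketch; the standalone DETECTORS registry and FakeDetector class are made up for illustration (mmdet defines its own registries):

```python
from mmcv.utils import Registry, build_from_cfg

# a toy registry, analogous to mmdet's DETECTORS / DATASETS / RUNNERS
DETECTORS = Registry('detector')

@DETECTORS.register_module()  # maps the name 'FakeDetector' to the class
class FakeDetector:
    def __init__(self, depth=50):
        self.depth = depth

# 'type' selects the registered class; the remaining keys become
# constructor kwargs
cfg = dict(type='FakeDetector', depth=101)
model = build_from_cfg(cfg, DETECTORS)
assert isinstance(model, FakeDetector) and model.depth == 101
```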
Their code, abbreviated, is as follows:
```python
# tools/train.py
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # 1. merge cfg; 2. create work_dir; 3. create the distributed training
    # world with the given GPU ids; 4. collect env ...
    ...
    model = build_detector(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    model.init_weights()
    datasets = [build_dataset(cfg.data.train)]
    ...
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)
```
```python
# mmdet/apis/train.py
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    ...
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed) for ds in dataset
    ]
    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    if 'runner' not in cfg:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        }
    ...
    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))
    ...
    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        ...
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=val_samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        ...
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # user-defined hooks
    ...
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
```
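For orientation, the cfg fields consumed by the two functions above typically look like the following fragment of a standard MMDetection config (the values here are purely illustrative):

```python
# illustrative fragment of a typical MMDetection config file
data = dict(samples_per_gpu=2, workers_per_gpu=2)  # feeds build_dataloader
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
lr_config = dict(policy='step', step=[8, 11])      # becomes StepLrUpdaterHook
checkpoint_config = dict(interval=1)               # CheckpointHook
log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
runner = dict(type='EpochBasedRunner', max_epochs=12)  # feeds build_runner
workflow = [('train', 1)]                          # passed to runner.run
```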
Runner
Of the five steps above (in mmdetection, the loss design is usually folded into building the model), the previous section covered how the first four are created. The last step, however, consists largely of fixed operations, so it can be encapsulated once, with extensibility provided by attaching Hooks (callbacks) at each key point; these two jobs are exactly what Runner implements. The discussion below uses the default and most common EpochBasedRunner as an example. Runner's call chain is as follows: run is the entry function, which calls train/val, and each of those calls run_iter.
```python
@RUNNERS.register_module()
class EpochBasedRunner(BaseRunner):
    """Epoch-based Runner.

    This runner trains models epoch by epoch.
    """

    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)]
                means running 2 epochs for training and 1 epoch for
                validation, iteratively.
        """
        ...
        self.call_hook('before_run')
        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                ...
                epoch_runner = getattr(self, mode)
                ...
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def run_iter(self, data_batch, train_mode, **kwargs):
        if self.batch_processor is not None:
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=train_mode, **kwargs)
        elif train_mode:
            outputs = self.model.train_step(data_batch, self.optimizer,
                                            **kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer,
                                          **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('"batch_processor()" or "model.train_step()"'
                            'and "model.val_step()" must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'],
                                   outputs['num_samples'])
        self.outputs = outputs

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        ...
```
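Note the contract run_iter enforces: model.train_step / model.val_step must return a dict, whose log_vars and num_samples feed the log buffer, and whose loss is consumed later by the optimizer hook. Below is a minimal sketch of a model honoring that contract; MyModel and its loss are made up for illustration:

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    """Toy model illustrating the dict contract run_iter expects."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 1)

    def train_step(self, data_batch, optimizer, **kwargs):
        x, y = data_batch
        loss = nn.functional.mse_loss(self.fc(x), y)
        # 'loss' is what OptimizerHook calls backward() on;
        # 'log_vars' / 'num_samples' feed runner.log_buffer.update(...)
        return dict(
            loss=loss,
            log_vars={'loss': loss.item()},
            num_samples=x.size(0))

    def val_step(self, data_batch, optimizer, **kwargs):
        return self.train_step(data_batch, optimizer, **kwargs)
```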
Hook
The self.call_hook calls in the code above invoke the corresponding hook functions at each stage:
```
self.call_hook('before_run')
self.call_hook('before_train_epoch') / self.call_hook('before_val_epoch')
self.call_hook('before_train_iter')  / self.call_hook('before_val_iter')
self.call_hook('after_train_iter')   / self.call_hook('after_val_iter')
self.call_hook('after_train_epoch')  / self.call_hook('after_val_epoch')
self.call_hook('after_run')
```
The registration and invocation of Hooks are both implemented in Runner's base class BaseRunner, in the functions register_hook and call_hook:
```python
class BaseRunner(metaclass=ABCMeta):
    ...

    def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list."""
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority
        # insert the hook to a sorted list
        inserted = False
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                inserted = True
                break
        if not inserted:
            self._hooks.insert(0, hook)

    def register_hook_from_cfg(self, hook_cfg):
        """Register a hook from its cfg."""
        hook_cfg = hook_cfg.copy()
        priority = hook_cfg.pop('priority', 'NORMAL')
        hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
        self.register_hook(hook, priority=priority)

    def call_hook(self, fn_name):
        """Call all hooks.

        Args:
            fn_name (str): The function name in each hook to be called,
                such as "before_train_epoch".
        """
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def get_hook_info(self):
        ...

    def resume(self, ...):
        ...

    def register_lr_hook(self, lr_config):
        if lr_config is None:
            return
        elif isinstance(lr_config, dict):
            assert 'policy' in lr_config
            policy_type = lr_config.pop('policy')
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'LrUpdaterHook'
            lr_config['type'] = hook_type
            hook = mmcv.build_from_cfg(lr_config, HOOKS)
        else:
            hook = lr_config
        self.register_hook(hook, priority='VERY_HIGH')
```
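Since register_hook keeps self._hooks sorted, call_hook visits hooks with smaller priority values first. For reference, mmcv's named priority levels resolve to fixed integers; the snippet below assumes mmcv's get_priority helper:

```python
from mmcv.runner import get_priority

# mmcv's named priority levels and their numeric values:
# HIGHEST=0, VERY_HIGH=10, HIGH=30, ABOVE_NORMAL=40, NORMAL=50,
# BELOW_NORMAL=60, LOW=70, VERY_LOW=90, LOWEST=100
print(get_priority('VERY_HIGH'))  # 10 -> e.g. LrUpdaterHook runs early
print(get_priority('NORMAL'))     # 50 -> default for custom hooks
```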
The Hook class itself is defined as follows. It is a collection of callbacks, one per insertion point listed above; we can subclass Hook and override the functions for whichever positions we need (a sketch of such a subclass follows the class definition).
```python
class Hook:
    stages = ('before_run', 'before_train_epoch', 'before_train_iter',
              'after_train_iter', 'after_train_epoch', 'before_val_epoch',
              'before_val_iter', 'after_val_iter', 'after_val_epoch',
              'after_run')

    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass

    def before_epoch(self, runner):
        pass

    def after_epoch(self, runner):
        pass

    def before_iter(self, runner):
        pass

    def after_iter(self, runner):
        pass

    def before_train_epoch(self, runner):
        self.before_epoch(runner)

    def before_val_epoch(self, runner):
        self.before_epoch(runner)

    def after_train_epoch(self, runner):
        self.after_epoch(runner)

    def after_val_epoch(self, runner):
        self.after_epoch(runner)

    def before_train_iter(self, runner):
        self.before_iter(runner)

    def before_val_iter(self, runner):
        self.before_iter(runner)

    def after_train_iter(self, runner):
        self.after_iter(runner)

    def after_val_iter(self, runner):
        self.after_iter(runner)

    def every_n_epochs(self, runner, n):
        return (runner.epoch + 1) % n == 0 if n > 0 else False

    def every_n_inner_iters(self, runner, n):
        return (runner.inner_iter + 1) % n == 0 if n > 0 else False

    def every_n_iters(self, runner, n):
        return (runner.iter + 1) % n == 0 if n > 0 else False

    def end_of_epoch(self, runner):
        return runner.inner_iter + 1 == len(runner.data_loader)

    def is_last_epoch(self, runner):
        return runner.epoch + 1 == runner._max_epochs

    def is_last_iter(self, runner):
        return runner.iter + 1 == runner._max_iters

    def get_triggered_stages(self):
        trigger_stages = set()
        for stage in Hook.stages:
            if is_method_overridden(stage, Hook, self):
                trigger_stages.add(stage)
        # some methods will be triggered in multi stages
        # use this dict to map method to stages.
        method_stages_map = {
            'before_epoch': ['before_train_epoch', 'before_val_epoch'],
            'after_epoch': ['after_train_epoch', 'after_val_epoch'],
            'before_iter': ['before_train_iter', 'before_val_iter'],
            'after_iter': ['after_train_iter', 'after_val_iter'],
        }
        for method, map_stages in method_stages_map.items():
            if is_method_overridden(method, Hook, self):
                trigger_stages.update(map_stages)
        return [stage for stage in Hook.stages if stage in trigger_stages]
```
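A subclass only needs to override the stages it cares about. Below is a made-up example hook, PrintLossHook, that prints the loss every interval training iterations:

```python
from mmcv.runner import HOOKS, Hook

@HOOKS.register_module()
class PrintLossHook(Hook):
    """Made-up example: print the loss every `interval` iterations."""

    def __init__(self, interval=50):
        self.interval = interval

    def after_train_iter(self, runner):
        # every_n_iters is inherited from Hook (see above)
        if self.every_n_iters(runner, self.interval):
            loss = runner.outputs['log_vars'].get('loss')
            print(f'iter {runner.iter + 1}: loss={loss}')
```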
In mmdet/apis/train.py: train_detector (which is called by the training entry function main in tools/train.py), the related implementation is as follows:
- lr policy, optimizer, checkpoint, log, and eval are all implemented via Hooks
- you can configure your own hooks through the custom_hooks field of the config file (Hooks are likewise defined with the Registry); a config example follows the code below
```python
# mmdet/apis/train.py
def train_detector(...):
    ...
    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        ...
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)
```
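With such a hook registered in HOOKS, enabling it from a config file is a single custom_hooks entry; PrintLossHook here is the made-up hook sketched earlier:

```python
# in the config file
custom_hooks = [
    dict(type='PrintLossHook', interval=50, priority='NORMAL')
]
```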