def save_model(save_dir, phase, name, epoch, f1score, model,
               model_name=None, best_threshold=0.3):
    """Save a CPU copy of ``model``'s weights together with training metadata.

    The checkpoint is written to
    ``<save_dir>/<model_name>/<phase>/<name>.ckpt`` as a dict with the keys
    ``'state_dict'``, ``'epoch'`` and ``'f1score'``. When ``name`` contains
    ``'best'`` and ``f1score > best_threshold``, an additional epoch-tagged
    copy ``<name>_<epoch>.ckpt`` is written to the same directory.

    Args:
        save_dir: Root checkpoint directory; created (with parents) if missing.
        phase: Name of the sub-directory under the model directory.
        name: Checkpoint base file name, without the ``.ckpt`` extension.
        epoch: Stored in the checkpoint and used in the tagged file name.
        f1score: Stored in the checkpoint and compared to ``best_threshold``.
        model: Object exposing ``state_dict()``, e.g. a ``torch.nn.Module``.
        model_name: Name of the per-model sub-directory. Defaults to the
            module-level ``args.model``, which was the original behaviour.
        best_threshold: Exclusive lower bound on ``f1score`` for writing the
            epoch-tagged copy of a "best" checkpoint.
    """
    if model_name is None:
        # Backward-compatible fallback to the global CLI arguments.
        model_name = args.model
    save_dir = os.path.join(save_dir, model_name, phase)
    # One makedirs call creates every missing level and, unlike an
    # exists()/mkdir() pair, does not fail if another process creates the
    # directory between the check and the creation.
    os.makedirs(save_dir, exist_ok=True)

    state_dict = model.state_dict()
    # Move tensors to CPU in place instead of building a new dict: for an
    # nn.Module the returned OrderedDict carries a ``_metadata`` attribute
    # that load_state_dict() uses, and a fresh dict would drop it.
    for key in list(state_dict):
        state_dict[key] = state_dict[key].cpu()
    checkpoint = {
        'state_dict': state_dict,
        'epoch': epoch,
        'f1score': f1score,
    }
    torch.save(checkpoint, os.path.join(save_dir, '{:s}.ckpt'.format(name)))
    if 'best' in name and f1score > best_threshold:
        # Keep an epoch-tagged copy so later "best" saves do not overwrite it.
        tagged_path = os.path.join(
            save_dir, '{:s}_{:s}.ckpt'.format(name, str(epoch)))
        torch.save(checkpoint, tagged_path)
# PyTorch 保存模型 (saving the model)
# PyTorch 加载模型并继续训练 (loading the model to resume training)
# Resume training from a checkpoint whose dict has the keys 'state_dict',
# 'epoch' and 'f1score' (the layout written by save_model()).
if args.resume:
    # map_location='cpu' lets a checkpoint saved with GPU tensors load on any
    # machine; load_state_dict() then copies the weights onto whatever device
    # the model's parameters already live on.
    # NOTE(review): torch.load unpickles the file, so only resume from trusted
    # checkpoints. On newer torch versions (weights_only=True by default) this
    # also assumes 'epoch'/'f1score' are plain Python numbers or tensors.
    state_dict = torch.load(args.resume, map_location='cpu')
    model.load_state_dict(state_dict['state_dict'])
    best_f1score = state_dict['f1score']
    # Continue with the epoch after the one recorded in the checkpoint.
    # NOTE(review): no optimizer/scheduler state is restored here, so those
    # restart from scratch on resume.
    start_epoch = state_dict['epoch'] + 1