pytorch 保存模型和加载模型

2023年6月9日09:05:21
def save_model(save_dir, phase, name, epoch, f1score, model):
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    save_dir = os.path.join(save_dir, args.model)
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    save_dir = os.path.join(save_dir, phase)
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    state_dict = model.state_dict()
    for key in state_dict.keys():
        state_dict[key] = state_dict[key].cpu()
    state_dict_all = {
        'state_dict': state_dict,
        'epoch': epoch,
        'f1score': f1score,
    }
    torch.save(state_dict_all, os.path.join(save_dir, '{:s}.ckpt'.format(name)))
    if 'best' in name and f1score > 0.3:
        torch.save(state_dict_all, os.path.join(save_dir, '{:s}_{:s}.ckpt'.format(name, str(epoch))))

pytorch 保存模型

pytorch 加载模型进行继续训练

    if args.resume:
        state_dict = torch.load(args.resume)
        model.load_state_dict(state_dict['state_dict'])
        best_f1score = state_dict['f1score']
        start_epoch = state_dict['epoch'] + 1

 

  • 作者:eilot_c
  • 原文链接:https://blog.csdn.net/eilot_c/article/details/103293552
    更新时间:2023年6月9日09:05:21 ,共 828 字。