PyTorch 之模型的保存与加载

1. torch.save
  • obj: 对象
  • f:输出路径
  • 2. torch.load
  • f: 文件路径
  • map_location: 指定存放位置,cpu or gpu
  • 方法1:保存整个module (耗时,占内存)
    torch.save(net.path)
    
    path_model = './model.pkl'
    net_load = torch.load(path_model)
    
    方法2:保存模型参数(官方推荐)
    state_dict = net.state_dict()
    torch.save(state_dict, path)
    
    path_state_dict = './model_state_dict.pkl'
    state_dict_load = torch.load(path_state_dict)
    net.load_state_dict(state_dict_load)
    
    3. 断点续存训练

    保存断点(在epoch循环中):

    if (epoch + 1) % checkpoint_interval == 0:  # 每隔checkpoint_interval保存一次
        checkpoint = {"model_state_dict": net.state_dict()  # 模型数据
                      "optimizer_state_dict": optimizer.state_dict()  # 优化器数据
                      "epoch": epoch  # 迭代次数
        path_checkpoint = './checkpoint_{}_epoch.pkl'.format(epoch)
        torch.save(checkpoint, path_checkpoint)
    

    断点恢复:

    path_checkpoint = './checkpoint_4_epoch.pkl'
    checkpoint = torch.load(path_checkpoint)