1. torch.save
2. torch.load
Method 1: save the entire module (slow, memory-heavy)
path_model = './model.pkl'
torch.save(net, path_model)
net_load = torch.load(path_model)
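A minimal runnable sketch of Method 1; the TinyNet class and its layer sizes are illustrative assumptions, not part of the original notes.
import torch
import torch.nn as nn

class TinyNet(nn.Module):              # hypothetical toy model, only for illustration
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
    def forward(self, x):
        return self.fc(x)

net = TinyNet()
path_model = './model.pkl'
torch.save(net, path_model)            # pickles the whole object: architecture + parameters
net_load = torch.load(path_model)      # the TinyNet class must be importable when loading
# on some newer PyTorch versions torch.load defaults to weights_only=True,
# so loading a whole pickled module may need torch.load(path_model, weights_only=False)
print(net_load)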
Method 2: save only the model parameters (state_dict, officially recommended)
path_state_dict = './model_state_dict.pkl'
state_dict = net.state_dict()
torch.save(state_dict, path_state_dict)
state_dict_load = torch.load(path_state_dict)
net.load_state_dict(state_dict_load)
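A minimal sketch of Method 2, again assuming the illustrative TinyNet class from above; load_state_dict needs a model instance whose layer names and shapes match the saved dict.
import torch
import torch.nn as nn

class TinyNet(nn.Module):              # same hypothetical toy model as above
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
    def forward(self, x):
        return self.fc(x)

net = TinyNet()
path_state_dict = './model_state_dict.pkl'
torch.save(net.state_dict(), path_state_dict)    # only an OrderedDict of parameter tensors is written

net_new = TinyNet()                              # rebuild the architecture first
state_dict_load = torch.load(path_state_dict)
net_new.load_state_dict(state_dict_load)         # copy the saved parameters into the new instance
Because only tensors are stored, the state_dict file is smaller and does not depend on pickling the model class itself.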
3. Resuming training from a checkpoint
Saving a checkpoint (inside the epoch loop):
if (epoch + 1) % checkpoint_interval == 0:              # save once every checkpoint_interval epochs
    checkpoint = {
        "model_state_dict": net.state_dict(),           # model parameters
        "optimizer_state_dict": optimizer.state_dict(), # optimizer state (e.g. momentum buffers)
        "epoch": epoch                                  # current epoch index
    }
    path_checkpoint = './checkpoint_{}_epoch.pkl'.format(epoch)
    torch.save(checkpoint, path_checkpoint)
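For context, a compact sketch of how this block sits inside a training loop; the model, loss, dummy data and MAX_EPOCH value are placeholder assumptions, not from the original notes.
import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(4, 2)                                   # placeholder model
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
criterion = nn.MSELoss()
MAX_EPOCH, checkpoint_interval = 10, 5

for epoch in range(MAX_EPOCH):
    x, y = torch.randn(8, 4), torch.randn(8, 2)         # dummy batch
    optimizer.zero_grad()
    loss = criterion(net(x), y)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % checkpoint_interval == 0:          # periodically dump model + optimizer + epoch
        checkpoint = {
            "model_state_dict": net.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch
        }
        torch.save(checkpoint, './checkpoint_{}_epoch.pkl'.format(epoch))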
Restoring from a checkpoint:
path_checkpoint = './checkpoint_4_epoch.pkl'
checkpoint = torch.load(path_checkpoint)
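The loaded dict is then copied back into the model and optimizer, and training resumes from the next epoch; this continuation is a sketch of the usual pattern (start_epoch and the placeholder model/optimizer are assumptions matching the save sketch above).
import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(4, 2)                                          # must match the saved architecture
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
MAX_EPOCH = 10

path_checkpoint = './checkpoint_4_epoch.pkl'
checkpoint = torch.load(path_checkpoint)

net.load_state_dict(checkpoint['model_state_dict'])            # restore model parameters
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # restore optimizer state
start_epoch = checkpoint['epoch']                              # last saved epoch index

for epoch in range(start_epoch + 1, MAX_EPOCH):                # resume from the next epoch
    pass  # training step goes here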