相关文章推荐
从未表白的毛衣  ·  java poi ...·  1 月前    · 
一直单身的台灯  ·  机器学习加速氧化还原电位和酸度常数计算·  8 月前    · 
机灵的铁链  ·  我对python中的wget模块有问题。-腾 ...·  12 月前    · 
骑白马的毛衣  ·  将textarea值转换为有效的JSON字符 ...·  1 年前    · 
完美的牛肉面  ·  如何将文件流转换成byte[]数组-腾讯云开 ...·  1 年前    · 
高大的黑框眼镜  ·  vuejs ...·  1 年前    · 
Code  ›  optimizer.zero_grad()开发者社区
梯度 target
https://cloud.tencent.com/developer/article/1700045
失眠的红豆
1 年前
作者头像
狼啸风云
0 篇文章

optimizer.zero_grad()

前往专栏
腾讯云
开发者社区
文档 意见反馈 控制台
首页
学习
活动
专区
工具
TVP
文章/答案/技术大牛
发布
首页
学习
活动
专区
工具
TVP
返回腾讯云官网
社区首页 > 专栏 > 计算机视觉理论及其实现 > optimizer.zero_grad()

optimizer.zero_grad()

作者头像
狼啸风云
修改 于 2022-09-02 20:45:17
6.1K 1
修改 于 2022-09-02 20:45:17
举报

传统的训练函数,一个batch是这么训练的:

for i,(images,target) in enumerate(train_loader):
    # 1. input output
    images = images.cuda(non_blocking=True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)
    outputs = model(images)
    loss = criterion(outputs,target)
    # 2. backward
    optimizer.zero_grad()   # reset gradient
    loss.backward()
    optimizer.step()            
  1. 获取loss:输入图像和标签,通过infer计算得到预测值,计算损失函数;
  2. optimizer.zero_grad() 清空过往梯度;
  3. loss.backward() 反向传播,计算当前梯度;
  4. optimizer.step() 根据梯度更新网络参数

简单的说就是进来一个batch的数据,计算一次梯度,更新一次网络,使用梯度累加是这么写的:

for i,(images,target) in enumerate(train_loader):
    # 1. input output
    images = images.cuda(non_blocking=True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)
    outputs = model(images)
    loss = criterion(outputs,target)
    # 2.1 loss regularization
    loss = loss/accumulation_steps   
    # 2.2 back propagation
 
推荐文章
从未表白的毛衣  ·  java poi 环图_mob649e8166858d的技术博客_
1 月前
一直单身的台灯  ·  机器学习加速氧化还原电位和酸度常数计算
8 月前
机灵的铁链  ·  我对python中的wget模块有问题。-腾讯云开发者社区-腾讯云
12 月前
骑白马的毛衣  ·  将textarea值转换为有效的JSON字符串-腾讯云开发者社区-腾讯云
1 年前
完美的牛肉面  ·  如何将文件流转换成byte[]数组-腾讯云开发者社区-腾讯云
1 年前
高大的黑框眼镜  ·  vuejs -来自chokidar的错误(C:\):Error: EBUSY:资源繁忙或锁定,lstat 'C:\hiberfil.sys‘-腾讯云开发者社区-腾讯云
1 年前
今天看啥   ·   Py中国   ·   codingpro   ·   小百科   ·   link之家   ·   卧龙AI搜索
删除内容请联系邮箱 2879853325@qq.com
Code - 代码工具平台
© 2024 ~ 沪ICP备11025650号