PyTorch中在反向传播前为什么要手动将梯度清零?
11 个回答
这种模式提供给用户更多的自由度,把梯度玩出花样,比如说梯度累加(gradient accumulation)
传统的训练函数,一个batch是这么训练的:
for i, (image, label) in enumerate(train_loader):
# 1. forward
pred = model(image)
loss = criterion(pred, label)
# 2. backward
loss.backward()
# 3. update parameters of net
optimizer.step()
# 4. reset gradient
optimizer.zero_grad()
- model.forward():前向推理,计算损失函数;
- loss.backward():反向传播,计算当前梯度;
- optimizer.step():根据梯度更新网络参数;
- optimizer.zero_grad():清空梯度;
简单的说就是进来一个 batch 的数据,计算一次梯度,更新一次网络
使用梯度累加是这么写的:
for i, (image, label) in enumerate(train_loader):
# 1. forward
pred = model(image)
loss = criterion(pred, label)
# 2. backward
loss = loss / accumulation_steps
loss.backward()
# 3. update parameters of net
if (i + 1) % accumulation_steps == 0:
# 4.1 update parameters of net
optimizer.step()
# 4.2 reset gradient
optimizer.zero_grad()
- optimizer.zero_grad():清空过往梯度
- model.forward():前向推理,计算损失函数;
- loss.backward():反向传播,计算当前梯度;
- optimizer.step():多次循环步骤 2-3,梯度累加一定次数后,根据梯度更新网络参数,然后清空梯度
总结来说:梯度累加就是,每次获取1个batch的数据,计算1次梯度,梯度不清空,不断累加,累加一定次数后,根据累加的梯度更新网络参数,然后清空梯度,进行下一次循环。
一定条件下,batchsize 越大训练效果越好,梯度累加则实现了 batchsize 的变相扩大,如果accumulation_steps 为 8,则batchsize '变相' 扩大了8倍,使用时需要注意,学习率也要适当放大。
更新1:关于BN是否有影响,BN的估算是在forward阶段就已经完成的,并不冲突
As far as I know, batch norm statistics get updated on each forward pass, so no problem if you don't do .backward() every time.
更新2:根据 @李韶华 的分享,可以适当调低BN自己的momentum参数
bn自己有个momentum参数:x_new_running = (1 - momentum) * x_running + momentum * x_new_observed. momentum越接近0,老的running stats记得越久,所以可以得到更长序列的统计信息