记录一次Loss爆炸的经历

背景

最近正在做关于某个之前没用过的预训练模型的项目,写完代码在本地( macbook )跑通后把代码丢上了服务器,然后就跑去快乐地刷b站了,结果过了一会回来发现 loss 变成了 NaN ,在google上简单地搜索一下了并没有找到合适的解决方案,于是开始了漫漫调试路

结论

这里先给出bug的成因和解决方案,对于debug的过程感兴趣的小伙伴可以看下一节

  • 成因 :由于模型在训练过程中使用了半精度,而 AdamW 的默认 eps 1e-8 ,因此出现数值溢出,在进行一次反向传播后,部分模型参数变成 inf ,导致输出的 logits 变成 NaN ,从而导致 loss 变成 NaN
  • 解决方案 :(以下方案任选其一)
    • 方案一 :混合精度训练的时候, AdamW eps 设置成 1e-3
    • 方案二 :读入模型参数的时候, torch.load(ckpt_path, map_location='cpu') ,然后再把模型放到对应的显卡上,这样模型参数的类型就不是 fp16 ,而是 fp32 ,就不会发生数值溢出的问题了
    • 方案三 :在代码中加上对半精度训练的处理代码,即引入 autocast scaler

初步分析

由于之前遇到过显卡抽风,导致 loss 变成 NaN 的情况,所以这次果断地选择换一块显卡再跑跑看。结果非常悲剧地还是NaN,不信邪的我又把代码丢上Colab,仍然是 NaN ,于是不得不接受代码出bug了的悲惨事实

debug的第一步当然是在本地复现出bug,然而模型在训练了若干个batch之后, loss 仍旧在非常正常地下降,而之前在服务器或者Colab上运行时都是训了一个batch后, loss 就立刻变成了 NaN

此时的我颇感惊讶,思考了一下本地环境和服务器环境的区别,感觉主要的区别是 batch size 不同,猜测可能是服务器上 batch size 开太大了,导致 loss 数值上溢(本地 batch size 是2,服务器是128)

然而当时在本地跑的时候, loss 其实也就零点几,而且计算 loss 时也是取的 mean 而不是 sum ,所以内心其实就已经不太认可这个猜测,但姑且还是先尝试一下,事实证明确实不是 batch size 的问题, loss 仍然在训练了一个batch后变成了 NaN

结合第一个batch可以正确得到 loss ,而第二个batch开始 loss 变成 NaN 的现象, 猜测可能是第一个batch的 loss 把模型参数训拉垮了 ,导致第二个batch的 loss 激增。

逐步排查

此时PLM部分的参数的学习率设置的是 5e-5 ,非PLM部分的参数的学习率设置的是 1e-3 ,之前BERT和T5用这些参数都是ok的,所以首先排除超参数不合理的可能性

然而在排除这一可能性之后,我陷入了深深的懵逼之中,注意力也从 “为什么本地无法复现bug” 转向了 “为什么模型会被训拉垮”

思考了一下,得出另一个猜想—— 是不是某条数据太难了,导致 loss 飙升成 NaN ,从而破坏了模型参数,导致后续的 loss 都变成 NaN (其实这条猜想和前面本地可以正常训练的事实矛盾了,但当时并没有意识到这一点,不过幸好当时没意识到,否则就不能一步步debug下去,最终找到问题根源了)

于是决定把学习率都设置成0,看看在完全固定模型参数的前提下 loss 是不是都可以被正常的计算出来,结果发现并非如此, loss 仍然在经过一个batch之后变成了 NaN ,此时的我陷入了更深的懵逼之中—— 难道我把学习率调成0之后, AdamW 仍然会更新模型参数? 此时跑去看了一下pytorch的官方文档,发现在学习率为0的情况下, AdamW 不应该会更新模型参数才对

于是又思考了一下,决定把优化器从 AdamW 换成 SGD 再试试,扔上服务器之后发现在学习率为0的情况下,模型可以正确的计算 loss ,又试着把学习率调回去,发现模型也可以正确的进行训练,并在验证集上取得良好性能,此时我更更更更懵逼了,总不能 AdamW 还不如 SGD 吧,这也太不符合常识了,但目前的事实又确实应证了我前面的猜想—— AdamW 把模型参数训拉了

又又又思考了一下,决定把计算 loss 前的 logits golds 先print出来,如果确实是模型参数训拉了,导致 logits 太差,造成 loss 过大变成 NaN 的话,这样做至少能通过这个 logits golds 在本地复现出这个bug了

问题解决

在打印出logits后发现 logits 全是 NaN ,也就是说之前的 loss 变成 NaN 其实是 logits 导致的(这里吐槽一句,pytorch的 cross_entropy 函数在 input NaN 的时候居然不报错,要不是我习惯把 loss 更新在进度条上,我估计都发现不了问题)

logits 全是 NaN 肯定是模型参数的问题,那么首先 打印 全部的模型参数,看一看具体是哪个参数的问题,发现大多数参数其实都变成了 inf ,于是随便挑了其中一个参数,在计算完第一个batch之后,把它的梯度 打印 出来

然而,思路在这边又断开了——这个参数的梯度是正常的,而在执行完 optim.step() 后,这个参数也确实变成了inf。此时的我一脸的黑人问号,非常不信邪地在本地又来了一遍,结果是本地的参数也是正常的,区别在于,本地执行完 optim.step() 后,参数仍然是正常的

正当我生无可恋地随意翻看服务器上的日志时,我突然发现这些模型的类别居然全部都是 float16 ,于是我赶紧回去看了眼本地的日志,发现类别全部都是 float32 ,此时的我终于意识到很可能是模型在cuda下开启了半精度加速,从而导致的数值溢出

在google上搜索“pytorch 半精度 loss NaN”这几个关键字后,终于找到了答案——由于模型在训练过程中使用了半精度,而 AdamW 的默认 eps 1e-8 ,因此出现数值溢出,解决方案是将 eps 修改为 1e-3

总结

在知道答案后再回头看的话会发现一切的根源都是我太菜了,虽然早就知道混合精度的存在,但一直懒得去学,所以直到遇见今天这个问题之前,甚至都不知道混合精度是cuda下面的东西,如果能知道这一基本常识的话,恐怕在发现本地可以正常训练的时候就能意识到是半精度的问题了。debug用掉的这一晚上时间就当是交学费了。

编辑于 2022-10-14 17:09

文章被以下专栏收录