9 个回答
首先我不是很清楚你这个第一轮指得是epoch还是iteration,如果是epoch网上有很多方法(调学习率啥啥啥的),但是如果你是第一个iteration后就出现nan,那或许我接下来说的东西可能会对你有帮助( 是可能哈,不是一定 )。
(第一次玩知乎,写得啰嗦或啥的别怪,主要是为了把问题阐述清楚,因为第一个iteration后出现nan的解决方法基本找不到,我先po个照片关于第一个iteration后就出nan的情况,不会排版,大家将就看吧。 我会写得很多,写那么多也是为了记录自己科研之路痛苦和快乐)
如果你很急着看解决方案,你可以看完我的实验环境后(紧随图后),直接跳到最后解决方案那一块
先说下我自己做的实验,ubuntu16.04上训练VGG16,用的是pytorch(1.4.0),数据集则是CIFAR-10。大家在网上抄模型的时候千万注意,虽然很感谢他们资源共享,但是有些人的模型直接把VGG16全连接层都给删除了(欲哭无泪),但是也就是因为这个错误的模型帮助了我发现了问题的入口哈哈哈。建议大家使用torchvision里面的模型,因为pytorch自己写的模型里面的模型初始化还是做得蛮好的(我自己在这个实验中也是用的torchvision的VGG16模型 pretrain=False)。
前提条件都介绍完了,进行实验后就发现上图的问题,第一个iteration后出现nan,查看模型各层的weight和grad后也全为nan。将同样的模型以及代码放到windows环境下,居然发现啥事没有!!!额的个亲娘!就开始了我漫长的查找解决方案过程,整个过程中我试过:
- 调整学习率(没用)
- 梯度裁剪(没用)
- 改batch_size(没用,我都改到过bsz=1仍然没用)
- 检查原始数据是否存在缺失值nan(不可能呀,我用的官方数据集呀,查了之后也木有任何问题)
- 检查浮点数精度(没用)
- 是否有除0出现(这个排查得不彻底,但是凭我第六感感觉就不可能,如果有除零出现,为啥Windows啥事没有)
- batchNorm(没用)
- ....有些都不记得了反正网上说的我都试过了(结果没成功)
诶!会不会是模型参数多了?改!改卷积层,该删的删,总算是把模型搞小了,一跑,nan,就像我身上的肥肉一般对我不离不弃。心态崩了,我太难了。
重新看梯度,发现最底层两层全连接层梯度正常,到了nn.Linear(512,4096)(大概是这一层)发现梯度全为nan,那说明是不是全连接层出问题了,开始实验.。全连接层只留下nn.Linear(512,4096),奇怪的事发生了,nan问题没有出现,第一次iteration后也没有出现nan,输出正常(但当然我的num_classes=10,这样肯定是不对的),那试试nn.Linear(512,10000)(全连接只有这一层,卷积层啥也没改),我去也没问题!!!我就不停的改nn.linear(512,x)中的x,发现当x = 292时就开始出现nan(x =293正常)。试试再加几层!比如
nn.Linear(512,4096)
(自己脑补ReLU dropout等基本操作,这不写了)
nn.Linear(4096,4096)
........(很多很多层乱写的)
nn.Linear(100000,293)
我去,这居然都能输出正常值,没有nan,但是我只要把293改为292或者292以下的数,就会产生nan.诶这又是为何?当我把这个玄学问题告诉师姐的时候,师姐说她那里也跑的VGG16,能跑通呀!(环境啥都一样,pytorch版本,Ubuntu版本等)打开VGG16一看,这个跑通的VGG16就是我前面说的网上复制的VGG16(没有全连接层的fake vgg16),依照这个重大发现,我修改我自己的torchvision下的VGG16,将全连接层全部删除,我去!居然跑通了!困扰我快一个星期的问题总算是找到问题的入口了---->全连接层!!!
(你能阅读到这说明你也是被这个问题困扰了很久的小伙伴,给你个爱的拥抱)
在师姐和我的努力下,我们按照这个思想把曾经一直没跑通的模型googlenet、resnet、densenet的全连接层全部去掉(删掉所有的nn.Linear),瞬间开心,都跑通了,没出现iteration第一个后出现nan问题!!!!那就开始想解决办法吧。
查来查去这个url的问题讨论算是解决了我的问题,开心!周末可以去吃烤串了!!!(英语感谢写得很蹩脚的那个是我)
https:// discuss.pytorch.org/t/w ell-formed-input-into-a-simple-linear-layer-output-nan/74720
(恭喜你看到这了,你一定也是一个像我一样的小可爱)
最终方案:升级你的numpy!!!!啊哈哈哈哈哈哈哈,谁能想到!我之前的numpy版本是1.18.1,升级后的numpy版本是1.19.2,这种第一个iteration后出现nan问题再也没出现了(目前实验了VGG16、Googlenet、resnet50、densenet)。哎!问题解决了是很开心,但是问题的本质是什么呢?有大佬来解读下吗?
欢迎一起讨论哈哈哈,我是个科研小白,很多东西也都是第一次接触,有啥说的不对的,互相讨论,这样又能学到新知识了!开心~。不说了,继续做实验去了。
(补充一下:我所有模型实验都是在cpu上跑的)