训练深度学习网络时候,出现Nan是什么原因,怎么才能避免?
48 个回答
最近做了一组实验,每次在固定的迭代次数段,都会loss突然变nan,导致acc骤降,慢慢变0。
于是找啊找啊找bug……
很难受,在意志力的坚持下,找到海枯石烂终于知道了!
loss突然变nan的原因,很可惜并不是这里其他所有答主所说的 “因为梯度爆炸”、“lr过大”、“不收敛” 等等原因,而是因为 training sample中出现了脏数据 !
脏数据的出现导致我的logits计算出了0,0传给 log(x|x=0) \rightarrow ∞, 即nan。
所以我通过设置batch_size = 1,shuffle = False,一步一步地将sample定位到了 所有可能的脏数据,删掉 。期间,删了好几个还依然会loss断崖为nan,不甘心,一直定位一直删。终于tm work out!
之所以会这样,是因为我的实验是实际业务上的 真实数据 ,有实际经验的就知道的,现实的数据非常之脏,基本上数据预处理占据我80%的精力。
好怀念以前可以天真快乐的在open dataset上做task跑模型的时候,真是啥都不用管,专注模型算法……