相关文章推荐
爱玩的跑步鞋  ·  [Unity3D] ...·  8 月前    · 
酷酷的鸵鸟  ·  Alternate App Icons ...·  11 月前    · 
无邪的大熊猫  ·  ios开发devece ...·  1 年前    · 
有情有义的蚂蚁  ·  Python反反爬虫 - ...·  1 年前    · 
首发于 通用DL
【Pytorch】CrossEntropy和NLLLoss的关系

【Pytorch】CrossEntropy和NLLLoss的关系

  • CrossEntropy是LogSoftmax+NLLLoss的结果
  • NLLLoss默认参数下是目标位置的LogSoftmax的均值取负号。具体地,NLLLoss可以配置一些参数,如weight、reduction等,公式如下:
torch.nn.NLLLoss( weight=None , size_average=None , ignore_index=- 100 , reduce=None , reduction='mean' )
  • 数据准备
# 数据准备
import torch
from torch.nn import functional as F
torch.manual_seed(2022)
x=torch.rand(3,5).masked_fill(torch.randint(0,2,[3,5]).bool(),float('-inf'))
z=F.log_softmax(x,dim=-1)
y=torch.LongTensor([3,4,-100])
print('x:\n',x)
print('y:\n',y)
print('z:\n',z)
  • 实验对比
print('说明CrossEntropy是LogSoftmax+NLLLoss:')
print('ce:',F.cross_entropy(x,y,ignore_index=-100))
print('ls+nll:',F.nll_loss(z,y,ignore_index=-100))
print()
print('说明ignore_index的作用:')