【PyTorch】The Relationship Between CrossEntropy and NLLLoss
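- CrossEntropy is LogSoftmax followed by NLLLoss: F.cross_entropy(x, y) gives the same value as F.nll_loss(F.log_softmax(x, dim=-1), y).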
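- With default arguments, NLLLoss takes the LogSoftmax value at each sample's target index and returns the negative of their mean (samples whose target equals ignore_index are left out). NLLLoss also accepts parameters such as weight and reduction; its signature is: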
torch.nn.NLLLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')
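To make the default formula and the weight/reduction options concrete, here is a minimal sketch that recomputes NLLLoss by hand and checks it against F.nll_loss. The helper manual_nll and the logits, targets and class weights are hypothetical illustration values, not part of PyTorch; the sketch assumes a 2-D input z of log-probabilities with shape (N, C). Following the PyTorch documentation, each non-ignored row contributes -w[y_i] * z[i, y_i]; reduction='sum' adds these terms, and reduction='mean' divides that sum by the total of w[y_i] over the non-ignored rows, so with weight=None it reduces to the plain negative mean.

```python
import torch
import torch.nn.functional as F


def manual_nll(z, y, weight=None, ignore_index=-100, reduction="mean"):
    """Hypothetical hand-written NLLLoss for z of shape (N, C) holding log-probabilities."""
    if weight is None:
        # weight=None: every class gets weight 1
        weight = torch.ones(z.shape[1], dtype=z.dtype)
    keep = y != ignore_index                              # rows that take part in the loss
    safe_y = torch.where(keep, y, torch.zeros_like(y))    # placeholder index for ignored rows
    picked = z.gather(1, safe_y.unsqueeze(1)).squeeze(1)  # z[i, y_i]
    zeros = torch.zeros_like(picked)
    w = torch.where(keep, weight[safe_y], zeros)          # w[y_i], 0 for ignored rows
    losses = torch.where(keep, -w * picked, zeros)        # per-row loss, 0 for ignored rows
    if reduction == "none":
        return losses
    if reduction == "sum":
        return losses.sum()
    return losses.sum() / w.sum()                         # 'mean': divide by the total weight


torch.manual_seed(0)
logits = torch.randn(4, 5)
z = F.log_softmax(logits, dim=-1)
y = torch.tensor([1, 0, -100, 4])            # third row uses the default ignore_index
w = torch.tensor([1.0, 2.0, 0.5, 1.0, 3.0])  # made-up class weights

for red in ("mean", "sum", "none"):
    ref = F.nll_loss(z, y, weight=w, reduction=red)
    print(red, torch.allclose(ref, manual_nll(z, y, weight=w, reduction=red)))

# default arguments: negative mean of z[i, y_i] over the non-ignored rows
print(torch.allclose(F.nll_loss(z, y), -(z[0, 1] + z[1, 0] + z[3, 4]) / 3))
```

Each comparison should print True. Because 'mean' divides by the total weight rather than by the number of rows, a weighted NLLLoss is a weighted average, not the plain average of the weighted terms.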
- Data preparation
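# Data preparation
import torch
from torch.nn import functional as F
torch.manual_seed(2022)
# random logits for 3 samples and 5 classes; each entry is masked to -inf with probability 1/2
x=torch.rand(3,5).masked_fill(torch.randint(0,2,[3,5]).bool(),float('-inf'))
# log-probabilities, which is the input NLLLoss expects
z=F.log_softmax(x,dim=-1)
# class targets; the third sample uses the default ignore_index, -100
y=torch.tensor([3,4,-100],dtype=torch.long)
print('x:\n',x)
print('y:\n',y)
print('z:\n',z)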
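The data above masks logits with -inf. A small deterministic sketch of what log_softmax does with such entries (the hand-picked logits are illustrative assumptions, not the values produced by the seed above): masked positions keep log-probability -inf, i.e. probability 0, and the remaining probabilities in each row still sum to 1. Two caveats follow: if a target points at a masked position, that row's loss is inf, and a row that is masked entirely turns into nan, so each target in y should point at an unmasked position or be ignore_index.

```python
import torch
import torch.nn.functional as F

NEG_INF = float("-inf")
# hand-picked logits: -inf marks a masked (forbidden) class
x = torch.tensor([[1.0, NEG_INF, 0.5, 2.0],
                  [NEG_INF, NEG_INF, 0.3, NEG_INF],
                  [NEG_INF, NEG_INF, NEG_INF, NEG_INF]])
z = F.log_softmax(x, dim=-1)

print(z[0])                     # the masked entry stays -inf
print(z[0].exp().sum())         # probabilities of a partly masked row sum to 1
print(z[1])                     # a single unmasked class gets log-probability 0
print(torch.isnan(z[2]).all())  # a fully masked row becomes nan

# a target that points at a masked class gives an infinite loss
print(F.nll_loss(z[:1], torch.tensor([1])))
```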
- Comparison experiments
print('CrossEntropy equals LogSoftmax + NLLLoss:')
print('ce:',F.cross_entropy(x,y,ignore_index=-100))
print('ls+nll:',F.nll_loss(z,y,ignore_index=-100))
print()
print('Effect of ignore_index:')