相关文章推荐
有腹肌的卡布奇诺  ·  js ...·  2 月前    · 
独立的书签  ·  docker ...·  2 天前    · 

在PyTorch中,nll_loss的输入是什么?

1 人关注

我正在看这里的教程。 https://pytorch.org/tutorials/beginner/fgsm_tutorial.html

import torch.nn.functional as F
loss = F.nll_loss(output, target)

在上述两行代码中,"目标 "到底是什么?他们为目标加载了数据集,但从未讨论过它到底是什么。文档也很难理解。

python
pytorch
JobHunter69
JobHunter69
发布于 2019-07-27
1 个回答
Anubhav Singh
Anubhav Singh
发布于 2019-07-27
已采纳
0 人赞同

Check yourself by running below code:

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
        batch_size=1, shuffle=True)
for data, target in test_loader:
    print(data, target)
    break

这里,data基本上是一个灰度的MNIST图像,target09之间的标签。

因此,在loss = F.nll_loss(output, target)中,output是模型预测(模型在给出图像/数据时的预测),target是给定图像的实际标签。

此外,在上述例子中,检查以下几行。

output = model(data) # shape [1, 10]
init_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
# If the initial prediction is wrong, don't bother attacking, just move on
if init_pred.item() != target.item():
   continue