I am looking at the tutorial here: https://pytorch.org/tutorials/beginner/fgsm_tutorial.html

import torch.nn.functional as F
loss = F.nll_loss(output, target)

In the above two lines of code, what exactly is "target"? They load the dataset that target comes from, but never explain exactly what it is. The documentation is also hard to understand.

import torch
from torchvision import datasets, transforms

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
    ])),
    batch_size=1, shuffle=True)
for data, target in test_loader:
    print(data, target)
    break

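Here, data is a grayscale MNIST image (a tensor of shape [1, 1, 28, 28], where the leading 1 is the batch dimension) and target is its label, a digit between 0 and 9 (a tensor of shape [1]).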
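To see why target prints as a one-element tensor rather than a bare integer, here is a minimal sketch that swaps MNIST for a hypothetical synthetic TensorDataset (random 28x28 images and random labels, so nothing is downloaded). With batch_size=1 the DataLoader still adds a leading batch dimension to both data and target.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for MNIST: 8 random single-channel 28x28 "images"
# and 8 random integer labels in [0, 9].
images = torch.rand(8, 1, 28, 28)
labels = torch.randint(0, 10, (8,))

loader = DataLoader(TensorDataset(images, labels), batch_size=1, shuffle=True)

for data, target in loader:
    # data: one image plus a batch dimension -> shape [1, 1, 28, 28]
    # target: one class index plus a batch dimension -> shape [1], dtype int64
    print(data.shape, target.shape, target.dtype)
    break
```

This should print torch.Size([1, 1, 28, 28]) torch.Size([1]) torch.int64, which matches the "tensor([3])" style of output you get from the real MNIST loader.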

So, in loss = F.nll_loss(output, target), output is the model's prediction (what the model outputs when given the image data; in this tutorial, log-probabilities over the 10 classes) and target is the true label of that image.
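Concretely, F.nll_loss expects output to hold log-probabilities of shape [N, C] and target to hold N integer class indices of shape [N]. The sketch below uses made-up logits instead of the tutorial's model (an assumption for illustration) and checks that, for a single example, the loss is just the negated log-probability at the target index.

```python
import torch
import torch.nn.functional as F

# Hypothetical raw scores for one example over 10 classes (shape [1, 10]).
logits = torch.randn(1, 10)
# Convert to log-probabilities, as a model ending in log_softmax would produce.
output = F.log_softmax(logits, dim=1)

# The target is a class index, not an image: shape [1], dtype int64.
target = torch.tensor([3])

loss = F.nll_loss(output, target)

# With the default reduction='mean' and a single example, the loss equals
# -output[0, 3]: the negative log-probability assigned to the true class.
manual = -output[0, target.item()]
print(loss.item(), manual.item())
```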

Furthermore, in the above example, check below lines:

output = model(data) # shape [1, 10]
init_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
# If the initial prediction is wrong, don't bother attacking, just move on
if init_pred.item() != target.item():
   continue
# Calculate the loss
loss = F.nll_loss(output, target)

In the above code, only the output-target pairs for which the model predicts correctly are passed to the F.nll_loss loss function. If the model predicts the wrong label, all the remaining operations (including the loss calculation) are skipped and the loop continues with the next example in test_loader.
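The skip-on-wrong-prediction logic can be tried without a trained model. In this sketch the examples list is a hypothetical stand-in for test_loader: every example shares the same log-probability output, which "predicts" class 2, and only the targets differ, so exactly the examples labelled 2 reach F.nll_loss.

```python
import torch
import torch.nn.functional as F

# Hypothetical model output for a batch of one: class 2 gets the highest score.
logits = torch.zeros(1, 10)
logits[0, 2] = 5.0
output = F.log_softmax(logits, dim=1)

# Hypothetical stand-in for test_loader: (output, target) pairs, batch size 1.
examples = [(output, torch.tensor([t])) for t in (2, 7, 2, 0)]

losses = []
skipped = 0
for output, target in examples:
    # Index of the max log-probability = predicted class, shape [1, 1]
    init_pred = output.max(1, keepdim=True)[1]
    # Wrong initial prediction: skip everything below, including the loss
    if init_pred.item() != target.item():
        skipped += 1
        continue
    losses.append(F.nll_loss(output, target).item())

print(f"loss computed for {len(losses)} examples, skipped {skipped}")
# prints "loss computed for 2 examples, skipped 2"
```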

I did print out target, but it just gave me something like "tensor([3])". I'm confused: if target is just the label, why does it have a .item() function? I think nll_loss is taking in an image and also a class object. Isn't that correct? – JobHunter69 Jul 27, 2019 at 18:41

Since I am trying to generate my own target, I am wondering what to feed into nll_loss. – JobHunter69 Jul 27, 2019 at 18:42

torch.Tensor.item() is used to get a Python number from a tensor containing a single value. – Anubhav Singh Jul 27, 2019 at 18:54
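For the follow-up questions in the comments: target is an ordinary integer tensor, and .item() only unwraps its single value into a Python int. nll_loss does not take an image; it takes the model's log-probability output plus integer class indices. A minimal sketch of feeding your own target, where the chosen class 7 and the random output are purely hypothetical:

```python
import torch
import torch.nn.functional as F

target = torch.tensor([3])    # what the DataLoader yields with batch_size=1
label = target.item()         # plain Python int
print(target, label, type(label))  # tensor([3]) 3 <class 'int'>

# A hand-made target: a 1-D int64 (long) tensor of class indices, one per example.
my_target = torch.tensor([7], dtype=torch.long)

# Hypothetical model output: log-probabilities for a batch of one, 10 classes.
output = F.log_softmax(torch.randn(1, 10), dim=1)
loss = F.nll_loss(output, my_target)
print(loss.item())
```

For a batch of N examples, the target would be a long tensor of shape [N] holding one class index per example.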