pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

Potential issue with CrossEntropyLoss in the Titanic Data Analysis tutorial #1010

Open jensqin opened 2 years ago

jensqin commented 2 years ago

📚 Documentation

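In the Titanic Data Analysis tutorial, CrossEntropyLoss is used inappropriately. According to the PyTorch documentation, CrossEntropyLoss combines LogSoftmax and NLLLoss and expects unnormalized logits as input. The final Softmax layer of TitanicSimpleNNModel is therefore redundant: with it in place, softmax is effectively applied twice before the loss is computed.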
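A minimal sketch of the equivalence described above, using random logits and labels as hypothetical stand-ins for the tutorial's data. `nn.CrossEntropyLoss` on raw logits matches `nn.NLLLoss` applied to `F.log_softmax`, whereas passing softmax probabilities computes a different, double-softmax objective.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical batch: 4 samples, 2 classes (e.g. did not survive / survived)
logits = torch.randn(4, 2)
labels = torch.tensor([0, 1, 1, 0])

criterion = nn.CrossEntropyLoss()

# CrossEntropyLoss on raw logits ...
ce = criterion(logits, labels)
# ... equals NLLLoss applied to LogSoftmax of the same logits
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), labels)
print(torch.allclose(ce, nll))  # True

# Feeding softmax probabilities instead computes
# NLLLoss(LogSoftmax(Softmax(logits))), i.e. softmax is applied twice
double = criterion(F.softmax(logits, dim=1), labels)
print(ce.item(), double.item())

# With 2 classes the double-softmax loss is log(1 + exp(1 - 2 * p_target)),
# so it is trapped in (log(1 + 1/e), log(1 + e)), roughly (0.313, 1.313)
lo, hi = math.log1p(math.exp(-1)), math.log1p(math.e)
print(lo < double.item() < hi)  # True
```

Because the double-softmax loss cannot fall below about 0.313 in the two-class case, it never approaches zero even for confidently correct predictions, and its gradients tend to be damped compared with the loss on logits.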

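import torch.nn as nn

class TitanicSimpleNNModel(nn.Module):
    # __init__ omitted; it defines linear1, sigmoid1, linear2, sigmoid2, linear3 and softmax

    def forward(self, x):
        lin1_out = self.linear1(x)
        sigmoid_out1 = self.sigmoid1(lin1_out)
        sigmoid_out2 = self.sigmoid2(self.linear2(sigmoid_out1))
        # returns softmax probabilities rather than logits
        return self.softmax(self.linear3(sigmoid_out2))

criterion = nn.CrossEntropyLoss()  # applies LogSoftmax internally
# some code

for epoch in range(num_epochs):
    output = net(input_tensor)  # already softmax-normalized
    loss = criterion(output, label_tensor)  # softmax is effectively applied a second time here
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
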
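One possible fix, sketched under assumptions: drop the final Softmax so `forward` returns logits, and apply softmax only where probabilities are actually needed. The layer sizes in `__init__`, the random stand-in data, the Adam optimizer, the learning rate and the epoch count are all hypothetical choices for illustration, not taken from the tutorial.

```python
import torch
import torch.nn as nn


class TitanicSimpleNNModel(nn.Module):
    # Hypothetical layer sizes; the tutorial's actual sizes may differ
    def __init__(self, n_features=12):
        super().__init__()
        self.linear1 = nn.Linear(n_features, 12)
        self.sigmoid1 = nn.Sigmoid()
        self.linear2 = nn.Linear(12, 8)
        self.sigmoid2 = nn.Sigmoid()
        self.linear3 = nn.Linear(8, 2)  # no Softmax layer

    def forward(self, x):
        lin1_out = self.linear1(x)
        sigmoid_out1 = self.sigmoid1(lin1_out)
        sigmoid_out2 = self.sigmoid2(self.linear2(sigmoid_out1))
        # Return raw logits; CrossEntropyLoss applies log-softmax itself
        return self.linear3(sigmoid_out2)


torch.manual_seed(0)
net = TitanicSimpleNNModel()

# Hypothetical random data standing in for the preprocessed Titanic features
input_tensor = torch.randn(32, 12)
label_tensor = torch.randint(0, 2, (32,))

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)
num_epochs = 200

for epoch in range(num_epochs):
    output = net(input_tensor)  # logits, shape (32, 2)
    loss = criterion(output, label_tensor)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def predict_proba(x):
    # Hypothetical helper: softmax applied only when probabilities are needed
    return torch.softmax(net(x), dim=1)


with torch.no_grad():
    probs = predict_proba(input_tensor)
```

A callable like `predict_proba` could be passed as the forward function of a Captum attribution method such as `IntegratedGradients` if attributions should still be computed on probabilities rather than logits; which of the two to attribute is a separate modeling choice.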
NarineK commented 2 years ago

Thank you for raising the question, @jensqin! It looks like we could avoid the softmax here. The PyTorch documentation, though, also calls softmax in its examples: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?highlight=crossentropyloss#torch.nn.CrossEntropyLoss
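For comparison, a sketch modeled on the examples in the `CrossEntropyLoss` documentation (class-probability targets require PyTorch 1.10 or later). Assuming the softmax call referred to is the one in the class-probabilities example, it normalizes the target, while the input passed to the loss remains unnormalized logits.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
loss = nn.CrossEntropyLoss()

# Target as class indices: the input is raw, unnormalized logits
logits_a = torch.randn(3, 5, requires_grad=True)
target_idx = torch.empty(3, dtype=torch.long).random_(5)
output_a = loss(logits_a, target_idx)
output_a.backward()

# Target as class probabilities: softmax is applied to the *target*,
# while the input is still raw logits
logits_b = torch.randn(3, 5, requires_grad=True)
target_prob = torch.randn(3, 5).softmax(dim=1)
output_b = loss(logits_b, target_prob)
output_b.backward()
```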

cc: @vivekmig