yunjey / pytorch-tutorial

PyTorch Tutorial for Deep Learning Researchers
MIT License
30.13k stars, 8.12k forks

In logistic regression code, why is the model nn.Linear? Doesn't that make it linear regression? #147

Closed: Gauravtolani closed this issue 5 years ago

Gauravtolani commented 5 years ago

In line 34 of the logistic_regression/main.py code, the model is set to nn.Linear(input_size, num_classes). I am a bit confused about how this works.

Does nn.Linear(input_size, 1) make it a linear regression and nn.Linear(input_size, >1) make it a logistic regression?
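A minimal sketch of the distinction the question is asking about, assuming random data and hypothetical sizes (`input_size = 5`, `batch = 8`): the number of output units does not decide between linear and logistic regression; the loss, and the nonlinearity it applies, does. Both models below share the same `nn.Linear(input_size, 1)` layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical data: 8 samples with 5 features each
input_size, batch = 5, 8
x = torch.randn(batch, input_size)

# One output unit in both cases; only the loss differs
layer = nn.Linear(input_size, 1)
out = layer(x).squeeze(1)  # shape: (batch,)

# Linear regression: real-valued targets, mean squared error on the raw output
y_real = torch.randn(batch)
linreg_loss = nn.MSELoss()(out, y_real)

# Binary logistic regression: 0/1 targets; BCEWithLogitsLoss applies the sigmoid internally
y_binary = torch.randint(0, 2, (batch,)).float()
logreg_loss = nn.BCEWithLogitsLoss()(out, y_binary)

# Probability of class 1 under the logistic model
p1 = torch.sigmoid(out)
```

With more than one output unit, the multiclass analogue is `nn.Linear(input_size, num_classes)` paired with `nn.CrossEntropyLoss()`, which is the setup used in the tutorial.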

PengFoo commented 5 years ago
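```python
import torch
import torch.nn as nn

input_size, num_classes = 28 * 28, 10  # MNIST sizes (assumed for this excerpt)

# Logistic regression model
model = nn.Linear(input_size, num_classes)

# Loss and optimizer
# nn.CrossEntropyLoss() applies log-softmax internally
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)  # lr assumed
```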

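Notice that nn.CrossEntropyLoss() is the combination of nn.LogSoftmax() and nn.NLLLoss(), so you do not need an extra logistic/softmax layer after your linear layer. The model is still logistic regression: nn.Linear produces the class scores (logits), and the softmax that turns them into probabilities is applied inside the loss. If you'd like to make the softmax explicit, add nn.LogSoftmax(dim=1) after the linear layer and use nn.NLLLoss() as your criterion instead. nn.NLLLoss() expects log-probabilities, so feeding it the output of a plain nn.Softmax() would not give the cross-entropy loss.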
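A minimal sketch that checks the claim in the answer above, using hypothetical sizes and random inputs: `nn.CrossEntropyLoss()` on raw logits matches an explicit `nn.LogSoftmax(dim=1)` layer followed by `nn.NLLLoss()`, while passing plain softmax probabilities to `nn.NLLLoss()` gives a different (incorrect) value.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes and random data, for illustration only
input_size, num_classes, batch = 784, 10, 4
x = torch.randn(batch, input_size)
y = torch.randint(0, num_classes, (batch,))

linear = nn.Linear(input_size, num_classes)
logits = linear(x)  # raw, unnormalized class scores

# Setup from the tutorial: CrossEntropyLoss applied directly to the logits
loss_ce = nn.CrossEntropyLoss()(logits, y)

# Explicit alternative: LogSoftmax layer followed by NLLLoss
explicit_model = nn.Sequential(linear, nn.LogSoftmax(dim=1))
loss_nll = nn.NLLLoss()(explicit_model(x), y)

# Incorrect alternative: NLLLoss on plain softmax probabilities
loss_softmax_nll = nn.NLLLoss()(torch.softmax(logits, dim=1), y)

print(torch.allclose(loss_ce, loss_nll))          # True
print(torch.allclose(loss_ce, loss_softmax_nll))  # False

# Softmax is still useful at prediction time to read off class probabilities
probs = torch.softmax(logits, dim=1)
preds = probs.argmax(dim=1)
```

The first variant is the common choice: the model outputs logits, and the loss takes care of the numerically stable log-softmax.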