mrdbourke / pytorch-deep-learning

Materials for the Learn PyTorch for Deep Learning: Zero to Mastery course.
https://learnpytorch.io
MIT License

03_pytorch_computer_vision.ipynb — baseline model #743

Open sascharo opened 11 months ago

sascharo commented 11 months ago

Why don't we use a softmax function to turn the logits into prediction probabilities here?

nickyreinert commented 11 months ago

It's a little confusing, because in some places there's an extra softmax and in others there isn't. I don't know whether Daniel explained this in one of the earlier sessions, so let me try:

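If you compare the logits with the softmax'ed logits, you'll notice that the rank order is the same. That's because softmax only "normalizes" the raw values coming out of the model (the logits): it exponentiates each value and divides by the sum of all the exponentials. Exponentiation is strictly increasing and every value is divided by the same positive number, so the largest logit always ends up as the largest probability. The result is a set of prediction probabilities that sum to 1.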
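As a quick check of both claims (softmax keeps the ranking, and the outputs sum to 1), here is a minimal sketch with made-up logits for a batch of 2 samples and 3 classes. The numbers are arbitrary, and `dim=1` assumes the usual `[batch, classes]` layout of a classifier's output.

```python
import torch

# Made-up logits: 2 samples (rows) x 3 classes (columns)
logits = torch.tensor([[2.0, -1.0, 0.5],
                       [0.1, 0.3, -0.2]])

# softmax(x)_i = exp(x_i) / sum_j exp(x_j), computed across the class dimension
probs = torch.softmax(logits, dim=1)
manual = logits.exp() / logits.exp().sum(dim=1, keepdim=True)

print(probs)
print(probs.sum(dim=1))  # each row sums to 1 (up to floating-point error)

# Sorting each row gives the same class order before and after softmax
print(torch.equal(logits.argsort(dim=1), probs.argsort(dim=1)))
```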

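Argmax just takes a tensor and returns the index of its highest value. It doesn't care whether those values are probabilities or unnormalized floats like logits, so taking the argmax of the logits gives the same predicted class as taking the argmax of the softmax output. If you only need the predicted labels, you can leave softmax out; you only need it when you want the probabilities themselves.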
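A minimal sketch of the argmax point, assuming a tiny stand-in model (a hypothetical `nn.Flatten` + `nn.Linear` stack fed random 28x28 inputs, not the notebook's actual baseline model) that outputs a `[batch, 10]` tensor of logits:

```python
import torch
from torch import nn

torch.manual_seed(42)

# Hypothetical stand-in classifier: 28x28 grayscale image -> 10 class logits
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
X = torch.rand(8, 1, 28, 28)  # a fake batch of 8 images

model.eval()
with torch.inference_mode():
    logits = model(X)  # shape [8, 10], raw unnormalized scores

# dim=1 is the class dimension: one predicted label per image
labels_from_logits = logits.argmax(dim=1)
labels_from_probs = torch.softmax(logits, dim=1).argmax(dim=1)

print(labels_from_logits)
print(torch.equal(labels_from_logits, labels_from_probs))
```

If you want to report how confident the model is, keep the `torch.softmax(logits, dim=1)` values; for the labels alone it makes no difference.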

Daniel probably explained this when he introduced sigmoid and softmax for binary classification. It may become clearer if you compare all the relevant functions side by side:

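```python
import torch

# Two values, like the logits of a two-class (binary) problem: 0.0 and 0.5
binary = torch.arange(0, 1, 0.5)
print(f"A `binary` tensor: {binary}")
print(f"...after sigmoid: {binary.sigmoid()}")  # each value squashed into (0, 1) on its own
print(f"...after softmax: {binary.softmax(dim=0)}")  # values rescaled so they sum to 1
print(f"...after argmax: {binary.argmax(dim=0)}")  # index of the largest value
print(f"...after sigmoid and argmax: {binary.sigmoid().argmax(dim=0)}")  # same index as above

# Four values, like the logits of a multi-class problem: 0.0, 0.3, 0.6, 0.9
non_binary = torch.arange(0, 1, 0.3)
print(f"A `non-binary` tensor: {non_binary}")
print(f"...after sigmoid: {non_binary.sigmoid()}")
print(f"...after softmax: {non_binary.softmax(dim=0)}")
print(f"...after argmax: {non_binary.argmax(dim=0)}")
print(f"...after softmax and argmax: {non_binary.softmax(dim=0).argmax(dim=0)}")  # same index as above
```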

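Remember: sigmoid is used for binary classification, where each prediction is a yes/no decision, and softmax is used for multi-class classification. Sigmoid works on each value independently, so its outputs don't sum to 1, whereas softmax normalizes across the whole `dim`. The two are closely related: sigmoid is the two-class special case of softmax, since sigmoid(x) = softmax([0, x])[1].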

And this is what you get:

[Screenshot of the printed output]
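To make the sigmoid/softmax relationship concrete, here is a small sketch with made-up logits. It checks the identity sigmoid(x) = softmax([0, x])[1] and contrasts the two usual ways of turning logits into labels: rounding the sigmoid output at 0.5 for a binary problem, and taking the argmax over the classes for a multi-class problem.

```python
import torch

# Made-up logits from a hypothetical binary classifier (one output per sample)
x = torch.tensor([-2.0, -0.3, 0.7, 3.0])

# sigmoid(x) = 1 / (1 + exp(-x)) = exp(x) / (exp(0) + exp(x)),
# i.e. a softmax over the pair [0, x], read off at index 1
via_sigmoid = torch.sigmoid(x)
pairs = torch.stack([torch.zeros_like(x), x], dim=1)  # shape [4, 2]
via_softmax = torch.softmax(pairs, dim=1)[:, 1]
print(torch.allclose(via_sigmoid, via_softmax))

# Binary: threshold the sigmoid probability at 0.5 to get 0/1 labels
binary_labels = torch.round(via_sigmoid)

# Multi-class: take the argmax over the class dimension instead
multi_logits = torch.tensor([[1.0, 3.0, 0.2]])  # made-up: 1 sample, 3 classes
multi_labels = torch.softmax(multi_logits, dim=1).argmax(dim=1)

print(binary_labels)
print(multi_labels)
```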

// Edited to make it less confusing