yandex-research / rtdl

Research on Tabular Deep Learning: Papers & Packages
Apache License 2.0

How to get the probability of each class in multiclass classification? #49

Closed: jerronl closed this issue 1 year ago

jerronl commented 1 year ago

Other than the top class from argmax, can we also get the probability of each class? Currently the predictions often contain many negative values, and I don't know the right way to convert them to probabilities.

Yura52 commented 1 year ago

The predictions are logits (raw, unnormalized scores), which is why they can be negative. To convert them to probabilities, apply softmax over the class dimension:

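```python
import torch
import torch.nn.functional as F

# `model` is your trained model and `x` an input batch (adapt the call to your model's forward signature)
with torch.no_grad():
    logits = model(x)  # shape: (batch_size, n_classes)
probabilities = F.softmax(logits, dim=-1)  # each row is non-negative and sums to 1
```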
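To illustrate why negative logits are not a problem, here is a minimal sketch of what softmax computes for a single sample, written in plain Python so it runs without PyTorch. It is the standard numerically stable softmax, not code from rtdl, and the logit values are hypothetical.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Convert one sample's logits into class probabilities.

    Subtracting the max logit before exponentiating avoids overflow;
    the result is unchanged because softmax is invariant to adding
    the same constant to every logit.
    """
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]  # every term is positive
    total = sum(exps)
    return [e / total for e in exps]  # normalize so the values sum to 1


# Hypothetical logits for one sample with 3 classes; negative values are normal.
logits = [-1.2, 0.3, -3.5]
probs = softmax(logits)

print([round(p, 3) for p in probs])  # → [0.179, 0.803, 0.018]
print(sum(probs))  # 1.0 up to floating-point rounding
# Softmax is monotonic, so argmax over probabilities equals argmax over logits.
print(probs.index(max(probs)) == logits.index(max(logits)))  # True
```

In practice, prefer `F.softmax` on the whole batch tensor: `dim=-1` applies this same normalization to each row (sample) independently, and the predicted class from `argmax` stays the same whether you take it over logits or probabilities.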