Alcoholrithm / TabularS3L

A PyTorch Lightning-based library for self- and semi-supervised learning on tabular data.

Train loss is negative #17

Open akhila-s-rao opened 1 week ago

akhila-s-rao commented 1 week ago

Hi,

When doing first-phase training with DAE and VIME (using unlabeled data), I got a negative CrossEntropyLoss for the categorical features, which resulted in a negative training and validation loss. This led me to print the predictions for the categorical features, and they came out as a tensor of size [batch_size, num_categoricals]. However, that isn't the expected input to PyTorch's CrossEntropyLoss, which expects logits in a vector of size num_classes per sample. I don't quite understand how the reconstruction of categorical features works in your implementation. Am I missing something here? As I understand it, the train_loss should not be negative (which is what I see for both DAE and VIME).

This is what printing the predictions over the categorical features looks like (in functional/dae.py):

cat feature preds: tensor([[ 1.5704], [-0.6245], [ 0.4721], [-0.1746], [ 1.0408], [ 2.5116], [ 1.1048], [-1.1651], [ 4.2188], [ 0.7524], [-0.1088],
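
For comparison, here is the shape contract that PyTorch's nn.CrossEntropyLoss expects, as a minimal standalone snippet (not code from this library):

```python
import torch
import torch.nn as nn

# CrossEntropyLoss expects raw logits of shape [batch_size, num_classes]
# and integer class indices of shape [batch_size].
batch_size, num_classes = 4, 3
logits = torch.randn(batch_size, num_classes)           # one logit per class
targets = torch.randint(0, num_classes, (batch_size,))  # hard class labels

loss = nn.CrossEntropyLoss()(logits, targets)
print(loss.item())  # always >= 0 for class-index targets
```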

Help is appreciated! Thanks!! :)
P.S. Super useful project!

akhila-s-rao commented 1 week ago

I dug a bit deeper and think I have identified a bug. I found it in the DAE model implementation, but it likely exists in the other approaches as well.

The reconstruction_head uses an MLP from commons/ that does not have a dedicated output layer: its last layer is a plain linear layer, and it is used for all features, continuous and categorical alike. However, each categorical feature needs its own output layer of size num_classes. So when predictions are made by passing X through the encoder and then the reconstruction_head, the categorical features are effectively treated like continuous features, even though a different loss function is applied to each feature type. When the predictions from that linear output layer are passed to CrossEntropyLoss, it misbehaves and gives negative values as the loss.
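
For illustration, the kind of per-feature classification head I have in mind could look roughly like this (a sketch only; the class name, hidden_dim, and cat_cardinalities are made up and not the library's actual API):

```python
import torch
import torch.nn as nn

class CategoricalReconstructionHead(nn.Module):
    """Hypothetical head: one linear layer per categorical feature, each
    emitting num_classes logits so CrossEntropyLoss gets the shape it expects."""

    def __init__(self, hidden_dim: int, cat_cardinalities: list[int]):
        super().__init__()
        self.heads = nn.ModuleList(
            [nn.Linear(hidden_dim, n_classes) for n_classes in cat_cardinalities]
        )
        self.ce = nn.CrossEntropyLoss()

    def forward(self, z: torch.Tensor) -> list[torch.Tensor]:
        # z: encoder output of shape [batch_size, hidden_dim]
        return [head(z) for head in self.heads]  # each: [batch_size, n_classes_i]

    def loss(self, logits: list[torch.Tensor], cat_targets: torch.Tensor) -> torch.Tensor:
        # cat_targets: integer class indices, shape [batch_size, num_categoricals]
        return sum(
            self.ce(feat_logits, cat_targets[:, i])
            for i, feat_logits in enumerate(logits)
        )
```

The continuous features would keep the existing linear output plus MSE; only the categorical features would be routed through per-feature num_classes-sized heads like the ones above.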

I can submit a pull request to fix this bug if you like. Assuming I am not completely wrong here, of course, hehe.