AstraZeneca / SubTab

The official implementation of the paper, "SubTab: Subsetting Features of Tabular Data for Self-Supervised Representation Learning"
Apache License 2.0

Using only contrastive loss on MNIST error #13

Closed evgeniael closed 2 years ago

evgeniael commented 2 years ago

Firstly, thank you for this great paper.

I would like to point out that when I train the model on the MNIST dataset, keeping everything the same as your initial implementation but turning the reconstruction and distance losses off (set to false in the configuration file) and keeping only the contrastive loss set to true, I get the following error as soon as training starts.

Traceback (most recent call last):
  File "train.py", line 98, in <module>
    run_with_profiler(main, config) if config["profile"] else main(config)
  File "train.py", line 71, in main
    train(config, ds_loader, save_weights=True)
  File "train.py", line 34, in train
    model.fit(data_loader)
  File "C:\SubTab\src\model.py", line 116, in fit
    self.update_autoencoder(x_tilde_list, Xorig)
  File "C:\SubTab\src\model.py", line 233, in update_autoencoder
    tloss, closs, rloss, zloss = self.joint_loss(z, Xrecon, Xorig)
  File "C:.conda\envs\modelling-dev\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\SubTab\utils\loss_functions.py", line 144, in forward
    recon_loss = getMSEloss(xrecon, xorig) if self.options["reconstruction"] else getBCELoss(xrecon, xorig)
  File "C:\SubTab\utils\loss_functions.py", line 38, in getBCELoss
    return F.binary_cross_entropy(prediction, label, reduction='sum') / bs
  File "C:.conda\envs\modelling-dev\lib\site-packages\torch\nn\functional.py", line 3065, in binary_cross_entropy
    return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
RuntimeError: all elements of input should be between 0 and 1

After printing the first input, I notice that some predictions have small negative values, which breaks the calculation of the binary cross-entropy loss.

e.g.

tensor([[ 0.0011,  0.0684,  0.0063,  ...,  0.0269, -0.1042,  0.0020],
        [-0.0473, -0.0807,  0.0676,  ..., -0.0229, -0.1546,  0.0397],
        [-0.0291,  0.0879,  0.0631,  ...,  0.0322, -0.0247,  0.0595],
        ...,
        [ 0.0105,  0.1674,  0.0220,  ..., -0.0655, -0.1474, -0.1166],
        [-0.0297,  0.0714,  0.0102,  ..., -0.0438, -0.0500,  0.0241],
        [ 0.0124,  0.1329,  0.0307,  ..., -0.0530, -0.0293,  0.0392]], grad_fn=<...>)
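For context on why those negative values are fatal: binary cross-entropy takes the log of the predicted probability, so any prediction outside [0, 1] is mathematically undefined. A minimal pure-Python sketch of the per-element formula (not the repo's code, which uses `F.binary_cross_entropy`) shows the failure mode:

```python
import math

def bce(p, y):
    # Per-element binary cross-entropy: -(y*log(p) + (1-y)*log(1-p)).
    # log() is undefined for p <= 0 (and for 1-p <= 0), which is why
    # F.binary_cross_entropy requires every input to lie in [0, 1].
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

print(bce(0.9, 1.0))  # a valid probability works fine

try:
    bce(-0.1, 1.0)    # a small negative prediction, as in the tensor above
except ValueError:
    print("math domain error: prediction is not a valid probability")
```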

Is there something I am doing wrong? Thank you in advance for your help.

talipucar commented 2 years ago

Hi evgeniael,

Thanks for your kind words. Regarding the issue, please look at line 143 in "utils > loss_functions.py"

 recon_loss = getMSEloss(xrecon, xorig) if self.options["reconstruction"] else getBCELoss(xrecon, xorig)

Since you set "reconstruction: False" in the config, this line still computes a reconstruction loss, just using binary cross-entropy instead of MSE. To fix this, you can change the code to the following:

recon_loss = getMSEloss(xrecon, xorig) if self.options["reconstruction"] else 0

Hopefully, this will resolve your issue. Keep in mind that, in this case, your network architecture would still contain a decoder, although it would not be trained. You can choose to comment out the decoder in AEWrapper in model_utils.py, or add an "if" condition there.
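The "if" condition mentioned above could look something like the following sketch. The class and attribute names here are purely illustrative, not the repo's actual AEWrapper; the point is only to show the decoder being skipped when reconstruction is disabled:

```python
# Illustrative sketch only -- not the actual AEWrapper from model_utils.py.
# Idea: build and run the decoder only when "reconstruction" is enabled,
# so a contrastive-only run carries no unused decoder.
class AEWrapperSketch:
    def __init__(self, options):
        self.options = options
        self.encoder = lambda x: [v * 0.5 for v in x]  # stand-in encoder
        # Guard: skip decoder construction entirely when reconstruction is off
        if options.get("reconstruction", False):
            self.decoder = lambda z: [v * 2.0 for v in z]  # stand-in decoder
        else:
            self.decoder = None

    def forward(self, x):
        z = self.encoder(x)
        # Only produce a reconstruction when a decoder exists
        x_recon = self.decoder(z) if self.decoder is not None else None
        return z, x_recon
```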

Please bear in mind that I ran the contrastive-only configuration only as an ablation study; I did not expect it to be used in practice, since contrastive loss alone may not be effective for data with binary classes (especially if they are highly imbalanced), which is a very common setting in tabular data.

I will close this issue, but please feel free to re-open it if the problem is not resolved or you have further questions.