kexinhuang12345 / MolTrans

MolTrans: Molecular Interaction Transformer for Drug Target Interaction Prediction (Bioinformatics)
https://academic.oup.com/bioinformatics/advance-article/doi/10.1093/bioinformatics/btaa880/5929692
BSD 3-Clause "New" or "Revised" License

match batch size of data in model config #14

Closed: printomi closed this pull request 3 years ago

printomi commented 3 years ago

This fixes the following exception, raised when training with a batch size that differs from the one set in the model config:

$ python train.py --task biosnap --batch-size 4
Let's use 4 GPUs!
--- Data Preparation ---
Traceback (most recent call last):
  File "train.py", line 206, in <module>
    model_max, loss_history = main()
  File "train.py", line 156, in main
    auc, auprc, f1, logits, loss = test(testing_generator, model_max)
  File "train.py", line 64, in test
    loss = loss_fct(logits, label)
  File "/home/user/anaconda3/envs/MolTrans/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/deepp/anaconda3/envs/MolTrans/lib/python3.7/site-packages/torch/nn/modules/loss.py", line 612, in forward
    return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
  File "/home/user/anaconda3/envs/MolTrans/lib/python3.7/site-packages/torch/nn/functional.py", line 2886, in binary_cross_entropy
    "Please ensure they have the same size.".format(target.size(), input.size())
ValueError: Using a target size (torch.Size([4])) that is different to the input size (torch.Size([16])) is deprecated. Please ensure they have the same size.
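The traceback shows the model emitting 16 outputs for a batch of 4 labels. That is consistent with the PR title: the model still uses the batch size from its config rather than the one passed with `--batch-size`. The following is a minimal sketch of the "match the config to the data" idea in plain Python, without PyTorch. It is not MolTrans's actual code. The names `default_model_config`, `parse_args`, `build_config`, and `check_same_size` are hypothetical, and the default of 16 is an assumption inferred from the input size in the traceback.

```python
import argparse


def default_model_config():
    # Hypothetical stand-in for the repository's model config.
    # The value 16 is an assumption inferred from the traceback's input size.
    return {"batch_size": 16}


def parse_args(argv=None):
    # Hypothetical CLI mirroring the flags used in the failing command.
    parser = argparse.ArgumentParser(description="MolTrans-style training (sketch)")
    parser.add_argument("--task", default="biosnap")
    parser.add_argument("--batch-size", dest="batch_size", type=int, default=16)
    return parser.parse_args(argv)


def build_config(args):
    config = default_model_config()
    # Keep the model's batch size in sync with the data loader's batch size,
    # so each batch yields as many outputs as there are labels.
    config["batch_size"] = args.batch_size
    return config


def check_same_size(logits, labels):
    # Simplified, length-only version of the check BCELoss performs:
    # input and target must have the same size.
    if len(logits) != len(labels):
        raise ValueError(
            f"target size {len(labels)} differs from input size {len(logits)}"
        )


if __name__ == "__main__":
    args = parse_args(["--task", "biosnap", "--batch-size", "4"])
    config = build_config(args)
    print(config["batch_size"])  # prints 4
```

With the override in place, a model built from `config` would be sized for 4 samples per batch, matching the 4 labels. This avoids the 16-vs-4 mismatch that `check_same_size` (and, in the real run, `binary_cross_entropy`) rejects.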
kexinhuang12345 commented 3 years ago

Thanks!