coleygroup / molpal

active learning for accelerated high-throughput virtual screening
MIT License

Multi-task MPNN target sizes do not match #9

Closed. miquelduranfrigola closed this issue 2 years ago.

miquelduranfrigola commented 2 years ago

Hi! Thanks for a wonderful repository.

I am trying to train a multitask regression using your MPNN class:

model = MPNN(ncpu=12, num_tasks=2)

I am testing it with a target NumPy array of shape (10000, 2).

When I run model.train(smis, targets) I get the warning:

Validation sanity check ... /home/mduranfrigola/github/ersilia-os/zaira-chem-lite/zairachemlite/model/molpal/models/mpnn/ptl/model.py:36: UserWarning: Using a target size (torch.Size([50, 1, 2])) that is different to the input size (torch.Size([50, 2])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  "rmse": lambda X, Y: torch.sqrt(F.mse_loss(X, Y, reduction="none")),
Training:   0%|          | 0/50 [00:00<?, ?epoch/s
/home/mduranfrigola/miniconda3/envs/molpal/lib/python3.8/site-packages/torch/nn/modules/loss.py:520: UserWarning: Using a target size (torch.Size([50, 1, 2])) that is different to the input size (torch.Size([50, 2])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.

and then, correspondingly, the following error:

/home/mduranfrigola/github/ersilia-os/zaira-chem-lite/zairachemlite/model/molpal/models/mpnn/ptl/model.py:36: UserWarning: Using a target size (torch.Size([1, 1, 2])) that is different to the input size (torch.Size([1, 2])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  "rmse": lambda X, Y: torch.sqrt(F.mse_loss(X, Y, reduction="none")),
Traceback (most recent call last):
  File "__init__.py", line 33, in <module>
    mdl.train(smiles, targets)
  File "/home/mduranfrigola/github/ersilia-os/zaira-chem-lite/zairachemlite/model/molpal/models/mpnmodels.py", line 181, in train
    trainer.fit(lit_model, train_dataloader, val_dataloader)
  File "/home/mduranfrigola/miniconda3/envs/molpal/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 458, in fit
    self._run(model)
  File "/home/mduranfrigola/miniconda3/envs/molpal/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 756, in _run
    self.dispatch()
  File "/home/mduranfrigola/miniconda3/envs/molpal/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 797, in dispatch
    self.accelerator.start_training(self)
  File "/home/mduranfrigola/miniconda3/envs/molpal/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 96, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/mduranfrigola/miniconda3/envs/molpal/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 144, in start_training
    self._results = trainer.run_stage()
  File "/home/mduranfrigola/miniconda3/envs/molpal/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 807, in run_stage
    return self.run_train()
  File "/home/mduranfrigola/miniconda3/envs/molpal/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 869, in run_train
    self.train_loop.run_training_epoch()
  File "/home/mduranfrigola/miniconda3/envs/molpal/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 576, in run_training_epoch
    self.trainer.run_evaluation(on_epoch=True)
  File "/home/mduranfrigola/miniconda3/envs/molpal/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 988, in run_evaluation
    self.evaluation_loop.evaluation_epoch_end(outputs)
  File "/home/mduranfrigola/miniconda3/envs/molpal/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 213, in evaluation_epoch_end
    model.validation_epoch_end(outputs)
  File "/home/mduranfrigola/github/ersilia-os/zaira-chem-lite/zairachemlite/model/molpal/models/mpnn/ptl/model.py", line 88, in validation_epoch_end
    val_loss = torch.cat(outputs).mean()
RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 50 but got size 1 for tensor number 40 in the list.
Exception ignored in: <function tqdm.__del__ at 0x7fa07b9a1790>
Traceback (most recent call last):
  File "/home/mduranfrigola/.local/lib/python3.8/site-packages/tqdm/std.py", line 1124, in __del__
  File "/home/mduranfrigola/.local/lib/python3.8/site-packages/tqdm/std.py", line 1337, in close
  File "/home/mduranfrigola/.local/lib/python3.8/site-packages/tqdm/std.py", line 1516, in display
  File "/home/mduranfrigola/.local/lib/python3.8/site-packages/tqdm/std.py", line 1127, in __repr__
  File "/home/mduranfrigola/.local/lib/python3.8/site-packages/tqdm/std.py", line 1477, in format_dict
TypeError: cannot unpack non-iterable NoneType object

Is there anything I am doing wrong?

Many thanks!

davidegraff commented 2 years ago

Hi @miquelduranfrigola,

We currently don’t support multitask optimization in molpal, so setting the number of tasks to anything other than 1 will break the code. How are you currently using the code?

miquelduranfrigola commented 2 years ago

Hi @davidegraff

Thanks for the fast reply! I was interested in the fact that you are using PyTorch Lightning, so my plan was to use the MPNN class of MolPal as a drop-in replacement for my ChemProp multi-task regression models, which are typically slow. I hope this makes sense?

Thanks! M

davidegraff commented 2 years ago

Yeah, so there is currently a bug in molpal/models/mpnmodels.py#L191. This block:

def make_datasets(
    self, xs: Iterable[str], ys: Sequence[float]
) -> Tuple[MoleculeDataset, MoleculeDataset]:
    """Split xs and ys into train and validation datasets"""

    data = MoleculeDataset([
        MoleculeDatapoint(smiles=[x], targets=[y])
        for x, y in zip(xs, ys)
    ])
    ...

assumes that ys is an array of single-task target values, i.e., a vector of length n rather than an array of shape n x m. If you iterate through an n x m array and wrap each row in a list (as in targets=[y]), then each point's target becomes a one-element list holding an array of m floats, rather than m separate targets. That is where the extra dimension in the target size torch.Size([50, 1, 2]) comes from.
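To make the shape problem concrete, here is a minimal NumPy sketch of the arithmetic (NumPy stands in for the PyTorch tensors, and the random data is purely illustrative; the batch size of 50, the final batch of 1, and the 2 tasks are taken from the warnings in the original report). Wrapping each row in a list gives every datapoint a target of shape (1, m), so a batch stacks to (50, 1, 2) instead of (50, 2). Subtracting that from predictions of shape (50, 2) broadcasts to (50, 50, 2), while the last batch of size 1 gives (1, 1, 2), and the two can no longer be concatenated along dimension 0. This is analogous to the torch.cat(outputs) failure in validation_epoch_end.

```python
import numpy as np

n_tasks = 2
rng = np.random.default_rng(0)
ys = rng.random((101, n_tasks))  # illustrative (n, m) multitask targets

# The buggy wrapping: targets=[y] where y is already a row of m values
wrapped = [[y] for y in ys]
print(np.asarray(wrapped[0]).shape)  # (1, 2): one "target" holding m floats

def batch_sq_error(start: int, stop: int) -> np.ndarray:
    """Elementwise squared error for one batch, like mse_loss(..., reduction="none")."""
    preds = rng.random((stop - start, n_tasks))  # model output: (B, m)
    targets = np.asarray(wrapped[start:stop])    # stacked targets: (B, 1, m)
    return (preds - targets) ** 2                # broadcasts to (B, B, m)

full = batch_sq_error(0, 50)     # a full batch of 50
last = batch_sq_error(100, 101)  # the final batch of 1
print(full.shape, last.shape)    # (50, 50, 2) (1, 1, 2)

try:
    np.concatenate([full, last])  # mirrors torch.cat(outputs) over per-batch losses
except ValueError as err:
    print("concatenation fails:", err)
```

With unwrapped targets of shape (B, m), the same subtraction would stay at (B, m) and every batch would concatenate cleanly.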

I just fixed this by checking the target shape in MPNN.train (#10). Let me know if the problem persists after the latest commit.
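The thread does not show the actual change in #10, so the following is only a hypothetical sketch of the kind of shape check described. The helper name `as_target_lists` is invented here and is not part of molpal. It normalizes either a length-n vector or an n x m array into one list of per-task floats per molecule, which, assuming the `MoleculeDatapoint` used by molpal accepts a list of per-task floats as in Chemprop, could be passed as `targets=t` directly instead of `targets=[y]`.

```python
from typing import List, Sequence

import numpy as np

def as_target_lists(ys: Sequence) -> List[List[float]]:
    """Hypothetical helper: one list of m floats per molecule, for m >= 1 tasks."""
    arr = np.asarray(ys, dtype=float)
    if arr.ndim == 1:       # single task: a vector of length n
        arr = arr[:, None]  # -> shape (n, 1)
    elif arr.ndim != 2:
        raise ValueError(f"expected targets of shape (n,) or (n, m), got {arr.shape}")
    return arr.tolist()     # n lists, each holding m floats

print(as_target_lists([1.0, 2.0]))                # [[1.0], [2.0]]
print(as_target_lists([[1.0, 2.0], [3.0, 4.0]]))  # [[1.0, 2.0], [3.0, 4.0]]
```

Normalizing up front keeps the single-task path unchanged (each molecule still gets a one-element target list) while letting multitask arrays produce m separate targets per molecule.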

miquelduranfrigola commented 2 years ago

Hi @davidegraff, it works nicely now with the latest commit! Many thanks for your help!