materialsvirtuallab / matgl

Graph deep learning library for materials
BSD 3-Clause "New" or "Revised" License

How to make a multi-target regression with m3gnet model? #186

Closed KirillKulaev closed 7 months ago

KirillKulaev commented 8 months ago

Hi, I tried to reproduce the example “Training a M3GNet Formation Energy Model with PyTorch Lightning.ipynb”, but I want to train the model to predict a spectrum as a vector. When I try to train the M3GNet model, I get the error below, even though I set the ntarget parameter.

# set up the architecture of the M3GNet model
model = M3GNet(
    element_types=elem_list,
    is_intensive=True,
    readout_type="set2set",
    ntarget=66,  # one output per point of the spectrum
)
# wrap the model in a PyTorch Lightning module for training
lit_module = ModelLightningModule(model=model)

logger = CSVLogger("logs", name="M3GNet_training")
trainer = pl.Trainer(max_epochs=20, accelerator="gpu", logger=logger)
trainer.fit(model=lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader)

ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/usr/local/lib/python3.10/dist-packages/matgl/graph/data.py", line 31, in collate_fn
    labels = torch.tensor([next(iter(d.values())) for d in labels], dtype=matgl.float_th)  # type: ignore
ValueError: only one element tensors can be converted to Python scalars

https://colab.research.google.com/drive/1L05611HYB6UMb380xYWXp9nBZL51iHYc#scrollTo=6crRrc29Dawl

kenko911 commented 7 months ago

Hi @KirillKulaev, thanks for reporting the error. I would like to clarify your regression problem. I think you are referring to multiple values per target for each structure/graph (e.g. a spectrum vector), rather than the other way around, i.e. multiple targets with a single value each for each structure/graph. To support multiple values per target, collate_fn needs to be modified. I will push the fix soon. Thanks!
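
For readers hitting the same traceback, here is a minimal sketch of why the label line in collate_fn fails for vector labels. It uses plain PyTorch only. The label names ("eform", "spectrum") and the helper first_values are hypothetical, and stacking with torch.stack is just one possible way to batch such labels, not necessarily what the MatGL fix does.

```python
import torch

# Hypothetical per-structure labels, shaped like the {name: value} dicts
# that the failing line in collate_fn iterates over.
scalar_labels = [{"eform": torch.tensor(0.1)}, {"eform": torch.tensor(0.2)}]
vector_labels = [{"spectrum": torch.rand(66)}, {"spectrum": torch.rand(66)}]


def first_values(labels):
    """Take the first value of each label dict, as the original line does."""
    return [next(iter(d.values())) for d in labels]


# One value per structure: torch.tensor can convert each 0-dim tensor
# to a Python scalar, so batching works.
batched_scalars = torch.tensor(first_values(scalar_labels))

# Many values per structure: torch.tensor cannot convert a 66-element
# tensor to a scalar and raises the ValueError from the traceback.
try:
    torch.tensor(first_values(vector_labels))
    raised = False
except ValueError:
    raised = True

# Assumption: stacking along a new batch dimension is one way to get
# labels of shape (batch_size, ntarget) that match a 66-output model.
batched_vectors = torch.stack(first_values(vector_labels))

print(batched_scalars.shape, raised, batched_vectors.shape)
```

With labels batched to shape (batch_size, 66), they line up with the output of a model built with ntarget=66.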

KirillKulaev commented 7 months ago

Thank you very much for your answer!

kenko911 commented 7 months ago

Hi @KirillKulaev, I just pushed the fix for collate_fn, and it should work now. Please pull the latest version of MatGL and use the model-training unit test (tests/utils/test_training.py), which covers multiple values per target with M3GNet, as a reference for modifying your script. I would like to stress that this is an experimental feature, and I am not sure how accurate the M3GNet model will be for your purpose.

KirillKulaev commented 7 months ago

Thank you very much, everything works well.