BioinfoMachineLearning / DeepInteract

A geometric deep learning pipeline for predicting protein interface contacts. (ICLR 2022)
https://zenodo.org/record/6671582
GNU General Public License v3.0

[BUG?] Invalid key "graph1". Must be one of the edge types. #7

Closed terry-r123 closed 2 years ago

terry-r123 commented 2 years ago

Thanks for the great DeepInteract! When I run the following command:

python3 lit_model_train.py --lr 1e-3 --weight_decay 1e-2

I get the following:

Traceback (most recent call last):
  File "lit_model_train.py", line 223, in <module>
    main(args)
  File "lit_model_train.py", line 174, in main
    trainer.fit(model=model, datamodule=picp_data_module)
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 552, in fit
    self._run(model)
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 917, in _run
    self._dispatch()
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 985, in _dispatch
    self.accelerator.start_training(self)
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 92, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 161, in start_training
    self._results = trainer.run_stage()
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 995, in run_stage
    return self._run_train()
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1030, in _run_train
    self._run_sanity_check(self.lightning_module)
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1114, in _run_sanity_check
    self._evaluation_loop.run()
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 109, in advance
    dl_outputs = self.epoch_loop.run(
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 111, in advance
    output = self.evaluation_step(batch, batch_idx, dataloader_idx)
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 158, in evaluation_step
    output = self.trainer.accelerator.validation_step(step_kwargs)
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 211, in validation_step
    return self.training_type_plugin.validation_step(*step_kwargs.values())
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 392, in validation_step
    return self.model(*args, **kwargs)
  File "/home/user/miniconda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/user/miniconda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 624, in forward
    output = self.module(*inputs, **kwargs)
  File "/home/user/miniconda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 93, in forward
    output = self.module.validation_step(*inputs, **kwargs)
  File "/ryc/DeepInteract/project/utils/deepinteract_modules.py", line 1923, in validation_step
    graph1, graph2 = batch['graph1'], batch['graph2']
  File "/home/user/miniconda/lib/python3.8/site-packages/dgl/heterograph.py", line 2152, in __getitem__
    raise DGLError('Invalid key "{}". Must be one of the edge types.'.format(orig_key))
dgl._ffi.base.DGLError: Invalid key "graph1". Must be one of the edge types.
Exception in thread Thread-2:
Traceback (most recent call last):
  File "/home/user/miniconda/lib/python3.8/threading.py", line 932, in _bootstrap_inner

amorehead commented 2 years ago

@terry-r123,

Thank you for your kind words as well as for your interest in our work! I took a look at the issue you reported, and oddly enough it does appear that the input batches to the train/val/test_step() functions have a different structure from what I would have expected. I believe this batch parsing logic may not have been updated from a previous version of the code. In any case, I have updated the main branch of this project to use the latest batch parsing logic. If you would not mind, please run "git pull origin main" in your copy of the repository to pull down my latest changes. I hope this fixes the issue you are seeing!

terry-r123 commented 2 years ago

@amorehead Thank you for replying and updating! Before your update, I had also tried changing "graph1, graph2 = train_batch['graph1'], train_batch['graph2']" to "graph1, graph2, examples_list, filepaths = train_batch[0], train_batch[1], train_batch[2], train_batch[3]", and that works!
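
For anyone who hits the same error on an older checkout, the change described above can be sketched as a small helper that accepts either batch layout. This is only an illustration, not the code in the repository: the function name unpack_batch is hypothetical, the dictionary keys "examples_list" and "filepaths" are assumed names, and plain strings stand in for the DGL graphs so the snippet runs without DGL installed.

```python
from collections.abc import Mapping, Sequence
from typing import Any, Tuple


def unpack_batch(batch: Any) -> Tuple[Any, Any, Any, Any]:
    """Return (graph1, graph2, examples_list, filepaths) from a batch.

    Hypothetical helper: it supports the positional layout described in
    this thread and, as a fallback, a dictionary layout keyed by name.
    """
    if isinstance(batch, Mapping):
        # Dictionary-style batch. "examples_list" and "filepaths" are
        # assumed key names; missing keys yield None.
        return (
            batch["graph1"],
            batch["graph2"],
            batch.get("examples_list"),
            batch.get("filepaths"),
        )
    if (
        isinstance(batch, Sequence)
        and not isinstance(batch, (str, bytes))
        and len(batch) >= 4
    ):
        # Positional batch: index by position instead of by string key.
        return batch[0], batch[1], batch[2], batch[3]
    raise TypeError(f"Unsupported batch type: {type(batch).__name__}")


if __name__ == "__main__":
    # Placeholder values stand in for real graphs and metadata.
    positional = ["g1", "g2", ["example"], ["complex.dill"]]
    graph1, graph2, examples_list, filepaths = unpack_batch(positional)
    print(graph1, graph2, examples_list, filepaths)
```

Inside a training_step() or validation_step(), the call would replace the string-keyed lookup, so a positional batch is never indexed with 'graph1'.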