terry-r123 closed this issue 2 years ago.
@terry-r123,
Thank you for your kind words, as well as for your interest in our work! I looked into the issue you reported, and oddly enough it does appear that the input batches passed to the `train/val/test_step()` functions have a different shape than I expected. I believe this batch parsing logic may not have been updated from a previous version of the code. In any case, I have updated the `main` branch of this project to use the latest batch parsing logic. If you don't mind, please run `git pull origin main` in your copy of the repository to pull down my latest changes. I hope this fixes the issue you are seeing!
@amorehead Thank you for replying and updating! Before your update, I had also tried changing `graph1, graph2 = train_batch['graph1'], train_batch['graph2']` to `graph1, graph2, examples_list, filepaths = train_batch[0], train_batch[1], train_batch[2], train_batch[3]`, and that worked!
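A minimal sketch, assuming the two batch layouts described in this thread: the older dict with the keys `'graph1'` and `'graph2'`, and the newer positional `(graph1, graph2, examples_list, filepaths)` sequence. The helper name `parse_batch` is hypothetical and is not part of DeepInteract. In this sketch the positional batch is unpacked rather than indexed element by element, so a batch with an unexpected number of elements, or an object that is neither a mapping nor a sequence (for example, a single graph), fails with a plain `TypeError` naming the type it actually received.

```python
from collections.abc import Mapping, Sequence
from typing import Any, Tuple


def parse_batch(batch: Any) -> Tuple[Any, Any, Any, Any]:
    """Return (graph1, graph2, examples_list, filepaths) from a batch.

    Hypothetical helper (not DeepInteract's API). Accepts the older
    dict-style batch as well as the newer positional batch.
    """
    if isinstance(batch, Mapping):
        # Older layout: the two graphs are stored under string keys.
        # examples_list and filepaths may be absent, so fall back to None.
        return (
            batch["graph1"],
            batch["graph2"],
            batch.get("examples_list"),
            batch.get("filepaths"),
        )
    if isinstance(batch, Sequence) and not isinstance(batch, (str, bytes)):
        if len(batch) != 4:
            raise TypeError(f"Expected 4 batch elements, got {len(batch)}")
        # Newer layout: equivalent to batch[0], batch[1], batch[2], batch[3].
        graph1, graph2, examples_list, filepaths = batch
        return graph1, graph2, examples_list, filepaths
    # Anything else (e.g. a single graph object) is reported explicitly.
    raise TypeError(f"Unrecognized batch structure: {type(batch).__name__}")
```

Inside a `validation_step`, the call would read `graph1, graph2, examples_list, filepaths = parse_batch(batch)`. Accepting both layouts is only useful while old and new code paths coexist; once every dataloader yields the positional form, plain tuple unpacking is enough.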
Thanks for the great DeepInteract project! When I run:

```
python3 lit_model_train.py --lr 1e-3 --weight_decay 1e-2
```

I get the following error:
```
Traceback (most recent call last):
  File "lit_model_train.py", line 223, in <module>
    main(args)
  File "lit_model_train.py", line 174, in main
    trainer.fit(model=model, datamodule=picp_data_module)
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 552, in fit
    self._run(model)
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 917, in _run
    self._dispatch()
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 985, in _dispatch
    self.accelerator.start_training(self)
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 92, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 161, in start_training
    self._results = trainer.run_stage()
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 995, in run_stage
    return self._run_train()
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1030, in _run_train
    self._run_sanity_check(self.lightning_module)
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1114, in _run_sanity_check
    self._evaluation_loop.run()
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 109, in advance
    dl_outputs = self.epoch_loop.run(
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 111, in advance
    output = self.evaluation_step(batch, batch_idx, dataloader_idx)
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 158, in evaluation_step
    output = self.trainer.accelerator.validation_step(step_kwargs)
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 211, in validation_step
    return self.training_type_plugin.validation_step(*step_kwargs.values())
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 392, in validation_step
    return self.model(*args, **kwargs)
  File "/home/user/miniconda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/user/miniconda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 624, in forward
    output = self.module(*inputs, **kwargs)
  File "/home/user/miniconda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/user/miniconda/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 93, in forward
    output = self.module.validation_step(*inputs, **kwargs)
  File "/ryc/DeepInteract/project/utils/deepinteract_modules.py", line 1923, in validation_step
    graph1, graph2 = batch['graph1'], batch['graph2']
  File "/home/user/miniconda/lib/python3.8/site-packages/dgl/heterograph.py", line 2152, in __getitem__
    raise DGLError('Invalid key "{}". Must be one of the edge types.'.format(orig_key))
dgl._ffi.base.DGLError: Invalid key "graph1". Must be one of the edge types.
```
Exception in thread Thread-2:
Traceback (most recent call last):
File "/home/user/miniconda/lib/python3.8/threading.py", line 932, in _bootstrap_inner