futianfan closed this 2 years ago.
I think this error is caused by the torch_geometric version. We use torch_geometric == 1.7.2 in our code. Can you check the version you used in your conda env?
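One way to check which versions are installed in the active conda env is a small helper like the one below. It is only a sketch and is not part of the DMCG code. `installed_versions` is a hypothetical helper name. It assumes each package exposes a `__version__` attribute, which torch and torch_geometric do. It uses `importlib.import_module` rather than `importlib.metadata`, so it also runs on the Python 3.7 seen in the traceback.

```python
import importlib
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Hypothetical helper: map each package name to its __version__, or None.

    None means the package could not be imported or has no __version__ attribute.
    """
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            module = importlib.import_module(name)
        except ImportError:
            # Not installed in the current environment.
            versions[name] = None
            continue
        versions[name] = getattr(module, "__version__", None)
    return versions


if __name__ == "__main__":
    for pkg, ver in installed_versions(["torch", "torch_geometric"]).items():
        print(f"{pkg}: {ver or 'not installed'}")
```

Running `pip show torch_geometric` inside the env is an equivalent command-line check.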
Thanks. What is your PyTorch version?
It works now, thanks!
The solution is in https://github.com/pyg-team/pytorch_geometric/issues/4358.
In e2c/dataset.py, change

    class CustomData(Data):
        def __cat_dim__(self, key, value,):

to

    class CustomData(Data):
        def __cat_dim__(self, key, value, *args, **kwargs):

This lets the override accept the extra argument that newer torch_geometric versions pass to `__cat_dim__`.
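The traceback below shows why the extra `*args, **kwargs` matter. torch_geometric's collate step calls `data_list[0].__cat_dim__(key, elem, stores[0])`, which passes one more positional argument than the old `(self, key, value)` override accepts.

The sketch below reproduces that mismatch without torch_geometric installed. `Data`, `OldCustomData`, `NewCustomData`, and `collate_cat_dim` are hypothetical stand-ins, not the real PyG or DMCG classes, and the override bodies are placeholders.

```python
from typing import Any


class Data:
    """Hypothetical stand-in for torch_geometric.data.Data (not the real class)."""

    def __cat_dim__(self, key: str, value: Any, *args: Any, **kwargs: Any) -> int:
        # Placeholder default: concatenate along dimension 0.
        return 0


class OldCustomData(Data):
    # Old-style signature: accepts only (key, value), like the original override.
    def __cat_dim__(self, key, value):
        return super().__cat_dim__(key, value)


class NewCustomData(Data):
    # Fixed signature: accepts any extra arguments and forwards them to the base class.
    def __cat_dim__(self, key, value, *args, **kwargs):
        return super().__cat_dim__(key, value, *args, **kwargs)


def collate_cat_dim(data: Data, key: str, value: Any, store: Any) -> int:
    # Mimics the call in the traceback: data_list[0].__cat_dim__(key, elem, stores[0])
    return data.__cat_dim__(key, value, store)


if __name__ == "__main__":
    store = {}  # hypothetical placeholder for the storage object passed by collate
    try:
        collate_cat_dim(OldCustomData(), "x", None, store)
    except TypeError as err:
        print("old signature fails:", err)
    print("new signature returns:", collate_cat_dim(NewCustomData(), "x", None, store))
```

Forwarding `*args, **kwargs` to `super()` keeps the override working whether the caller passes the extra argument or not.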
Hi, I ran the code following your instructions and got the following error. Do you know how to fix it?

    Traceback (most recent call last):
      File "train.py", line 359, in <module>
        main()
      File "train.py", line 190, in main
        remove_hs=args.remove_hs,
      File "/net/sunlab/psunlab1/molecular_data/graphnn/DMCG/confgen/e2c/dataset.py", line 61, in __init__
        super().__init__(self.folder, transform, pre_transform)
      File "/nethome/tfu42/.conda/envs/dmcg2/lib/python3.7/site-packages/torch_geometric/data/in_memory_dataset.py", line 57, in __init__
        super().__init__(root, transform, pre_transform, pre_filter)
      File "/nethome/tfu42/.conda/envs/dmcg2/lib/python3.7/site-packages/torch_geometric/data/dataset.py", line 88, in __init__
        self._process()
      File "/nethome/tfu42/.conda/envs/dmcg2/lib/python3.7/site-packages/torch_geometric/data/dataset.py", line 171, in _process
        self.process()
      File "/net/sunlab/psunlab1/molecular_data/graphnn/DMCG/confgen/e2c/dataset.py", line 87, in process
        self.process_confgf()
      File "/net/sunlab/psunlab1/molecular_data/graphnn/DMCG/confgen/e2c/dataset.py", line 336, in process_confgf
        data, slices = self.collate(data_list)
      File "/nethome/tfu42/.conda/envs/dmcg2/lib/python3.7/site-packages/torch_geometric/data/in_memory_dataset.py", line 116, in collate
        add_batch=False,
      File "/nethome/tfu42/.conda/envs/dmcg2/lib/python3.7/site-packages/torch_geometric/data/collate.py", line 86, in collate
        increment)
      File "/nethome/tfu42/.conda/envs/dmcg2/lib/python3.7/site-packages/torch_geometric/data/collate.py", line 128, in _collate
        cat_dim = data_list[0].__cat_dim__(key, elem, stores[0])
    TypeError: __cat_dim__() takes 3 positional arguments but 4 were given