When running the code, it fails almost immediately with the following error:
Traceback (most recent call last):
File "/home/akulen/EGAT-pytorch/main.py", line 28, in <module>
trainer.fit(model)
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
call._call_and_handle_interrupt(
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 987, in _run
results = self._run_stage()
^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1031, in _run_stage
self._run_sanity_check()
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1060, in _run_sanity_check
val_loop.run()
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/loops/utilities.py", line 182, in _decorator
return loop_run(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 135, in run
self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 396, in _evaluation_step
output = call._call_strategy_hook(trainer, hook_name, *step_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
output = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/strategies/strategy.py", line 412, in validation_step
return self.lightning_module.validation_step(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/trainer/amlsim.py", line 68, in validation_step
y_pred = self.model(data)
^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/model/net.py", line 293, in forward
(x, e), (x_o, e_o), (x_hidden, e_hidden) = layer(x, e_idx, e, x_hidden, e_hidden)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/model/net.py", line 102, in forward
h, h_o, x_hidden = self.update_vertex(x, edge_index, edge_attr, x_hidden=x_hidden)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/model/net.py", line 85, in update_vertex
h = self._vertex_module(h, edge_index, e)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/model/node.py", line 44, in forward
h = self.propagate(edge_index, size=(x.size(0), x.size(0)), x=h, e=e, x_o=x, e_o=edge_attr)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/.cache/pyg/message_passing/model.node_EGAT_split_propagate.py", line 304, in propagate
out = self.aggregate(
^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch_geometric/nn/conv/message_passing.py", line 612, in aggregate
return self.aggr_module(inputs, index, ptr=ptr, dim_size=dim_size,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch_geometric/experimental.py", line 117, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch_geometric/nn/aggr/base.py", line 136, in __call__
raise e
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch_geometric/nn/aggr/base.py", line 128, in __call__
return super().__call__(x, index=index, ptr=ptr, dim_size=dim_size,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch_geometric/nn/aggr/basic.py", line 22, in forward
return self.reduce(x, index, ptr, dim_size, dim, reduce='sum')
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch_geometric/nn/aggr/base.py", line 182, in reduce
return scatter(x, index, dim, dim_size, reduce)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch_geometric/utils/_scatter.py", line 74, in scatter
index = broadcast(index, src, dim)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch_geometric/utils/_scatter.py", line 179, in broadcast
return src.view(size).expand_as(ref)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The expanded size of the tensor (8) must match the existing size (79668) at non-singleton dimension 1. Target sizes: [79668, 8, 48]. Tensor sizes: [1, 79668, 1]
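After looking into it, it seems that the default `add` aggregation of PyG's `MessagePassing` scatters along the second-to-last dimension (`node_dim=-2` by default). Because the code uses 8 attention heads, the messages have shape `[79668, 8, 48]`, so the second-to-last dimension is the head dimension (8) rather than the edge dimension (79668). Passing `node_dim=0` to `MessagePassing.__init__` in the layer's constructor should make the aggregation scatter over the edge dimension instead.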
When running the code, it fails almost immediately with the following error:
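After looking into it, it seems that the default `add` aggregation of PyG's `MessagePassing` scatters along the second-to-last dimension (`node_dim=-2` by default). Because the code uses 8 attention heads, the messages have shape `[79668, 8, 48]`, so the second-to-last dimension is the head dimension (8) rather than the edge dimension (79668). Passing `node_dim=0` to `MessagePassing.__init__` in the layer's constructor should make the aggregation scatter over the edge dimension instead.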