REINS-SJTU / EGAT-pytorch

Code of model EGAT

RuntimeError: Problem with tensor shapes #3

Status: Open. Akulen opened this issue 4 months ago.

Akulen commented 4 months ago

Running the code fails almost immediately, during Lightning's validation sanity check, with the following error:

Traceback (most recent call last):
  File "/home/akulen/EGAT-pytorch/main.py", line 28, in <module>
    trainer.fit(model)
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 987, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1031, in _run_stage
    self._run_sanity_check()
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1060, in _run_sanity_check
    val_loop.run()
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/pytorch_lightning/strategies/strategy.py", line 412, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/trainer/amlsim.py", line 68, in validation_step
    y_pred = self.model(data)
             ^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/model/net.py", line 293, in forward
    (x, e), (x_o, e_o), (x_hidden, e_hidden) = layer(x, e_idx, e, x_hidden, e_hidden)
                                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/model/net.py", line 102, in forward
    h, h_o, x_hidden = self.update_vertex(x, edge_index, edge_attr, x_hidden=x_hidden)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/model/net.py", line 85, in update_vertex
    h = self._vertex_module(h, edge_index, e)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/model/node.py", line 44, in forward
    h = self.propagate(edge_index, size=(x.size(0), x.size(0)), x=h, e=e, x_o=x, e_o=edge_attr)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/.cache/pyg/message_passing/model.node_EGAT_split_propagate.py", line 304, in propagate
    out = self.aggregate(
          ^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch_geometric/nn/conv/message_passing.py", line 612, in aggregate
    return self.aggr_module(inputs, index, ptr=ptr, dim_size=dim_size,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch_geometric/experimental.py", line 117, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch_geometric/nn/aggr/base.py", line 136, in __call__
    raise e
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch_geometric/nn/aggr/base.py", line 128, in __call__
    return super().__call__(x, index=index, ptr=ptr, dim_size=dim_size,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch_geometric/nn/aggr/basic.py", line 22, in forward
    return self.reduce(x, index, ptr, dim_size, dim, reduce='sum')
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch_geometric/nn/aggr/base.py", line 182, in reduce
    return scatter(x, index, dim, dim_size, reduce)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch_geometric/utils/_scatter.py", line 74, in scatter
    index = broadcast(index, src, dim)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/akulen/EGAT-pytorch/venv/lib/python3.11/site-packages/torch_geometric/utils/_scatter.py", line 179, in broadcast
    return src.view(size).expand_as(ref)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The expanded size of the tensor (8) must match the existing size (79668) at non-singleton dimension 1.  Target sizes: [79668, 8, 48].  Tensor sizes: [1, 79668, 1]
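The last frames show the failure inside PyG's `broadcast` helper, which reshapes the 1-D edge index so that it can be expanded to the shape of the messages. The torch-only sketch below re-implements that step for illustration. `broadcast_like_pyg` is a hypothetical name, not PyG's actual source, and the sizes are small stand-ins for 79668, 8 and 48. It raises the same kind of error when the scatter dimension is -2 and succeeds when it is 0.

```python
import torch


def broadcast_like_pyg(index: torch.Tensor, ref: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical re-implementation of the step seen in the traceback
    (`src.view(size).expand_as(ref)`), for illustration only."""
    dim = ref.dim() + dim if dim < 0 else dim
    # Place the index along `dim`, with singleton sizes everywhere else.
    size = [1] * ref.dim()
    size[dim] = -1
    return index.view(tuple(size)).expand_as(ref)


num_edges, heads, channels = 6, 8, 4              # stand-ins for 79668, 8, 48
messages = torch.randn(num_edges, heads, channels)  # [E, H, C]
index = torch.randint(0, 3, (num_edges,))           # target node of each edge

# dim=-2 (second-to-last) resolves to dim 1 for a 3-D tensor: the index is
# viewed as [1, E, 1] and cannot be expanded to [E, H, C].
try:
    broadcast_like_pyg(index, messages, dim=-2)
    failed = False
except RuntimeError as err:
    failed = True
    print(err)

# dim=0 (the edge dimension) works: the index is viewed as [E, 1, 1].
expanded = broadcast_like_pyg(index, messages, dim=0)
print(expanded.shape)
```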

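I've looked around a little, and it seems the default `add` aggregation of PyG's `MessagePassing` scatters along `node_dim`, which defaults to -2, i.e. the second-to-last dimension. Because the code uses 8 attention heads, the messages have shape [79668, 8, 48] (edges, heads, per-head features), so the second-to-last dimension is the head dimension (8) rather than the edge dimension (79668). The edge index is therefore reshaped to [1, 79668, 1] and cannot be expanded to the message shape.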
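If this diagnosis is correct, the messages should be summed along dim 0 (the edge dimension) into one row per target node, leaving the head and feature dimensions intact. In PyG the scatter dimension is set by the `node_dim` argument of `MessagePassing.__init__`, and PyG's own multi-head `GATConv` passes `node_dim=0`; doing the same in the message-passing module in `model/node.py` looks like a plausible fix, though it is untested here. The torch-only sketch below shows the intended aggregation, using illustrative sizes and a made-up `target` index (both assumptions, not taken from the repository):

```python
import torch

num_nodes, num_edges, heads, channels = 4, 6, 8, 3      # illustrative sizes

messages = torch.randn(num_edges, heads, channels)        # [E, H, C], one message per edge
target = torch.tensor([0, 1, 1, 2, 3, 3])                 # destination node of each edge

# Sum messages along dim 0 (edges) into one row per node, i.e. an "add"
# aggregation with node_dim=0. Heads and channels are carried through.
out = torch.zeros(num_nodes, heads, channels)
out.index_add_(0, target, messages)

# The same result via scatter_add_ with the index broadcast along dim 0,
# which is the [E, 1, 1] -> [E, H, C] expansion that succeeds.
out_scatter = torch.zeros(num_nodes, heads, channels).scatter_add_(
    0, target.view(-1, 1, 1).expand_as(messages), messages
)

print(out.shape)
```

Node 1 receives edges 1 and 2, so `out[1]` equals `messages[1] + messages[2]` across all 8 heads.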