divelab / DIG

A library for graph deep learning research
https://diveintographs.readthedocs.io/
GNU General Public License v3.0

Errors in dimensions #194

Closed Breaddsmall closed 1 year ago

Breaddsmall commented 1 year ago

I tried to use spherenet and schnet on my custom material datasets, and I ran into the following errors during training.

Traceback (most recent call last):
  File "/home/ubuntu/baselines/scripts/run_sphere.py", line 214, in <module>
    main()
  File "/home/ubuntu/baselines/scripts/run_sphere.py", line 211, in main
    run_spherenet()
  File "/home/ubuntu/baselines/scripts/run_sphere.py", line 207, in run_spherenet
    run3d.run(torch.device('cuda:0'), train_dataset, valid_dataset, test_dataset, model, loss_func, evaluation,
  File "/home/ubuntu/miniconda3/envs/dig/lib/python3.10/site-packages/dig/threedgraph/method/run.py", line 71, in run
    train_mae = self.train(model, optimizer, train_loader, energy_and_force, p, loss_func, device)
  File "/home/ubuntu/miniconda3/envs/dig/lib/python3.10/site-packages/dig/threedgraph/method/run.py", line 136, in train
    out = model(batch_data)
  File "/home/ubuntu/miniconda3/envs/dig/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/miniconda3/envs/dig/lib/python3.10/site-packages/dig/threedgraph/method/spherenet/spherenet.py", line 299, in forward
    u = self.init_u(torch.zeros_like(scatter(v, batch, dim=0)), v, batch) #scatter(v, batch, dim=0)
  File "/home/ubuntu/miniconda3/envs/dig/lib/python3.10/site-packages/torch_scatter/scatter.py", line 152, in scatter
    return scatter_sum(src, index, dim, out, dim_size)
  File "/home/ubuntu/miniconda3/envs/dig/lib/python3.10/site-packages/torch_scatter/scatter.py", line 11, in scatter_sum
    index = broadcast(index, src, dim)
  File "/home/ubuntu/miniconda3/envs/dig/lib/python3.10/site-packages/torch_scatter/utils.py", line 12, in broadcast
    src = src.expand(other.size())
RuntimeError: The expanded size of the tensor (460) must match the existing size (461) at non-singleton dimension 0.  Target sizes: [460, 1].  Tensor sizes: [461, 1]

  File "/home/ubuntu/miniconda3/envs/dig/lib/python3.10/site-packages/dig/threedgraph/method/run.py", line 71, in run
    train_mae = self.train(model, optimizer, train_loader, energy_and_force, p, loss_func, device)
  File "/home/ubuntu/miniconda3/envs/dig/lib/python3.10/site-packages/dig/threedgraph/method/run.py", line 136, in train
    out = model(batch_data)
  File "/home/ubuntu/miniconda3/envs/dig/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/miniconda3/envs/dig/lib/python3.10/site-packages/dig/threedgraph/method/schnet/schnet.py", line 166, in forward
    v = update_v(v,e, edge_index)
  File "/home/ubuntu/miniconda3/envs/dig/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/miniconda3/envs/dig/lib/python3.10/site-packages/dig/threedgraph/method/schnet/schnet.py", line 59, in forward
    return v + out
RuntimeError: The size of tensor a (447) must match the size of tensor b (446) at non-singleton dimension 0

Does this kind of problem occur because some of the nodes are not connected to any other node within the given cutoff? If so, how should I fix this (other than setting a larger cutoff)?

Breaddsmall commented 1 year ago

I found that the problem is caused by the last node of a batch not being connected to any other node in the batch, which means its index doesn't appear in the adjacency (edge index). Therefore, it is not aggregated during the scatter() call, leading to the dimension mismatch.
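To make the failure mode concrete, here is a minimal, self-contained sketch (toy tensors, not the actual model data) of what happens when the highest-numbered node never appears in the index passed to scatter:

import torch
from torch_scatter import scatter

# 5 nodes, but node 4 has no incident edge, so it never shows up as a target index.
e = torch.randn(6, 8)                 # 6 per-edge features
i = torch.tensor([0, 0, 1, 2, 2, 3])  # target node of each edge; index 4 is missing

v = scatter(e, i, dim=0)              # output size along dim 0 is inferred as i.max() + 1 = 4
print(v.shape)                        # torch.Size([4, 8]) instead of the expected [5, 8]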

limei0307 commented 1 year ago

Hi @Breaddsmall, thanks for your interest in our work, and sorry for my late reply! Yes, isolated nodes may cause such problems. You can try to remove isolated nodes or add self-loops to solve this issue.
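For example, a rough preprocessing sketch using PyTorch Geometric utilities (the tensor names and cutoff value are placeholders; since the 3D models rebuild the radius graph from pos and the cutoff inside forward, dropping isolated atoms from the dataset is usually the more direct option):

import torch
from torch_geometric.nn import radius_graph
from torch_geometric.utils import add_self_loops, remove_isolated_nodes

pos = torch.randn(10, 3)              # atom positions of one sample (placeholder)
z = torch.randint(1, 20, (10,))       # atomic numbers (placeholder)
cutoff = 5.0                          # use the same cutoff as the model

edge_index = radius_graph(pos, r=cutoff)   # same radius-graph construction as the 3D models

# Option 1: drop atoms that have no neighbor within the cutoff
# (the returned mask filters every per-node attribute).
edge_index, _, mask = remove_isolated_nodes(edge_index, num_nodes=pos.size(0))
pos, z = pos[mask], z[mask]

# Option 2: keep every atom but add a self-loop so each index appears in edge_index.
# edge_index, _ = add_self_loops(edge_index, num_nodes=pos.size(0))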

In addition, in our implementation we only pass src, index, and dim to the scatter function; you can also try supplying the additional arguments out and dim_size. That should also solve this issue.
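Continuing the toy sketch above, either keyword fixes the output shape (num_nodes here stands for the number of atoms in the batch):

import torch
from torch_scatter import scatter

e = torch.randn(6, 8)                          # per-edge features, as in the sketch above
i = torch.tensor([0, 0, 1, 2, 2, 3])           # node 4 never appears as a target
num_nodes = 5

# dim_size pins the output size along dim 0, so the isolated node gets a zero row:
v = scatter(e, i, dim=0, dim_size=num_nodes)   # shape [5, 8]

# Alternatively, a preallocated `out` of the right shape has the same effect
# (scatter sums into it, so it must start as zeros):
out = e.new_zeros(num_nodes, e.size(-1))
v = scatter(e, i, dim=0, out=out)              # shape [5, 8]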

Best

pearl-rabbit commented 1 year ago

Hello, I am not sure how to use the above solution without changing the original dimensions. Can you provide some examples?

limei0307 commented 1 year ago

Hi @pearl-rabbit,

About the scatter function I mentioned previously, could you try replacing lines 203-205 with

def forward(self, e, i, num_nodes):
    _, e2 = e
    v = scatter(e2, i, dim=0, dim_size=num_nodes)

and changing line 301 to v = update_v(e, i, num_nodes), to see if that solves your problem?
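(Here num_nodes would be the total number of atoms in the batch, e.g. z.size(0) in the model's forward. The second traceback above suggests the same dim_size fix applies to the scatter call that produces out inside schnet.py's update_v.)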