RobDHess / Steerable-E3-GNN

E(3) Steerable Graph Neural Network
MIT License

Error when running the QM9 dataset #10

Open wangxiaoyunNV opened 2 months ago

wangxiaoyunNV commented 2 months ago

Hi, I got the following error when running this command:

python3 main.py --dataset=qm9 --epochs=1000 --target=alpha --radius=2 --model=segnn --lmax_h=2 --lmax_attr=3 --layers=7 --subspace_type=weightbalanced --norm=instance --batch_size=128 --gpu=1 --weight_decay=1e-8 --pool=avg


Starting training on a single gpu...
/home/nfs/xiaoyunw/miniconda3/envs/segnn/lib/python3.8/site-packages/torch_geometric/deprecation.py:12: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead
  warnings.warn(out)
Adjusting learning rate of group 0 to 5.0000e-04.
Training: segnn_qm9_alpha_93972
/home/nfs/xiaoyunw/miniconda3/envs/segnn/lib/python3.8/site-packages/torch_geometric/deprecation.py:12: UserWarning: 'contains_isolated_nodes' is deprecated, use 'has_isolated_nodes' instead
  warnings.warn(out)
Traceback (most recent call last):
  File "main.py", line 183, in <module>
    train(0, model, args)
  File "/home/nfs/xiaoyunw/az/Steerable-E3-GNN/qm9/train.py", line 67, in train
    out = model(graph).squeeze()
  File "/home/nfs/xiaoyunw/miniconda3/envs/segnn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nfs/xiaoyunw/az/Steerable-E3-GNN/models/segnn/segnn.py", line 125, in forward
    x = layer(
  File "/home/nfs/xiaoyunw/miniconda3/envs/segnn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nfs/xiaoyunw/az/Steerable-E3-GNN/models/segnn/segnn.py", line 211, in forward
    x = self.feature_norm(x, batch)
  File "/home/nfs/xiaoyunw/miniconda3/envs/segnn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nfs/xiaoyunw/az/Steerable-E3-GNN/models/segnn/instance_norm.py", line 85, in forward
    field_mean = global_mean_pool(field, batch).reshape(-1, mul, 1)  # [batch, mul, 1]]
  File "/home/nfs/xiaoyunw/miniconda3/envs/segnn/lib/python3.8/site-packages/torch_geometric/nn/pool/glob.py", line 58, in global_mean_pool
    return scatter(x, batch, dim=-2, dim_size=size, reduce='mean')
  File "/home/nfs/xiaoyunw/miniconda3/envs/segnn/lib/python3.8/site-packages/torch_scatter/scatter.py", line 156, in scatter
    return scatter_mean(src, index, dim, out, dim_size)
  File "/home/nfs/xiaoyunw/miniconda3/envs/segnn/lib/python3.8/site-packages/torch_scatter/scatter.py", line 41, in scatter_mean
    out = scatter_sum(src, index, dim, out, dim_size)
  File "/home/nfs/xiaoyunw/miniconda3/envs/segnn/lib/python3.8/site-packages/torch_scatter/scatter.py", line 11, in scatter_sum
    index = broadcast(index, src, dim)
  File "/home/nfs/xiaoyunw/miniconda3/envs/segnn/lib/python3.8/site-packages/torch_scatter/utils.py", line 12, in broadcast
    src = src.expand(other.size())
RuntimeError: The expanded size of the tensor (36) must match the existing size (2331) at non-singleton dimension 1.  Target sizes: [2331, 36, 1].  Tensor sizes: [1, 2331, 1]

wangxiaoyunNV commented 2 months ago

The problem comes from the instance norm. Switching to batch norm works.
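
For anyone who still wants instance norm: the traceback shows global_mean_pool calling scatter with dim=-2. For the 3-D tensor here ([2331, 36, 1]), dim=-2 is dimension 1 (size 36), not the node dimension (size 2331), which matches the size mismatch in the error. A possible workaround is to pool over dimension 0 explicitly. Below is a minimal sketch in plain PyTorch. The helper name mean_pool_nodes is hypothetical and not part of this repository or of torch_geometric, and this has not been checked against the repository's instance_norm.py.

```python
import torch


def mean_pool_nodes(field, batch, num_graphs=None):
    """Average node features per graph along dim 0 (hypothetical helper).

    field: [num_nodes, ...] tensor of node features, e.g. [num_nodes, mul, 1]
    batch: [num_nodes] long tensor mapping each node to its graph index
    """
    if num_graphs is None:
        num_graphs = int(batch.max()) + 1

    # Sum the features of all nodes that belong to the same graph.
    sums = torch.zeros(
        (num_graphs,) + tuple(field.shape[1:]),
        dtype=field.dtype,
        device=field.device,
    )
    sums.index_add_(0, batch, field)

    # Count nodes per graph; clamp so that empty graphs do not divide by zero.
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    counts = counts.to(field.dtype).view(-1, *([1] * (field.dim() - 1)))
    return sums / counts


# Small example: 4 nodes, 3 channels, 2 graphs.
field = torch.arange(12.0).reshape(4, 3, 1)
batch = torch.tensor([0, 0, 1, 1])
pooled = mean_pool_nodes(field, batch)  # shape [2, 3, 1]
```

With the shapes from the error message, this returns a [batch, 36, 1] tensor, which is the shape the comment on the failing line in instance_norm.py expects.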