RobDHess / Steerable-E3-GNN

E(3) Steerable Graph Neural Network
MIT License

InstanceNorm not working #9

Open pimdh opened 11 months ago

pimdh commented 11 months ago

Hi! Thanks for this library :) I'm trying to use InstanceNorm and it appears there's a bug. When I run the following:

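import torch
from segnn.segnn.instance_norm import InstanceNorm  # module path inferred from the traceback below
# BalancedIrreps also comes from this repository; import it from wherever it lives in your install

irreps = BalancedIrreps(3, 20)
norm = InstanceNorm(irreps)
x = torch.randn(9, 20)                      # 9 nodes
batch = torch.zeros(9, dtype=torch.long)    # all nodes belong to graph 0
norm(x, batch)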

I get the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[7], line 8
      6 x = torch.randn(9, 10)
      7 batch = torch.zeros(9, dtype=torch.long)
----> 8 norm(x, batch)

File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/local/lib/python3.8/dist-packages/segnn/segnn/instance_norm.py:85, in InstanceNorm.forward(self, input, batch)
     82 # For scalars first compute and subtract the mean
     83 if ir.l == 0:
     84     # Compute the mean
---> 85     field_mean = global_mean_pool(field, batch).reshape(-1, mul, 1)  # [batch, mul, 1]]
     86     # Subtract the mean
     87     field = field - field_mean[batch]

File /usr/local/lib/python3.8/dist-packages/torch_geometric/nn/pool/glob.py:63, in global_mean_pool(x, batch, size)
     61     return x.mean(dim=dim, keepdim=x.dim() <= 2)
     62 size = int(batch.max().item() + 1) if size is None else size
---> 63 return scatter(x, batch, dim=dim, dim_size=size, reduce='mean')

File /usr/local/lib/python3.8/dist-packages/torch_geometric/utils/scatter.py:81, in scatter(src, index, dim, dim_size, reduce)
     78 count.scatter_add_(0, index, src.new_ones(src.size(dim)))
     79 count = count.clamp(min=1)
---> 81 index = broadcast(index, src, dim)
     82 out = src.new_zeros(size).scatter_add_(dim, index, src)
     84 return out / broadcast(count, out, dim)

File /usr/local/lib/python3.8/dist-packages/torch_geometric/utils/scatter.py:21, in broadcast(src, ref, dim)
     19 size = [1] * ref.dim()
     20 size[dim] = -1
---> 21 return src.view(size).expand_as(ref)

RuntimeError: The expanded size of the tensor (10) must match the existing size (9) at non-singleton dimension 1.  Target sizes: [9, 10, 1].  Tensor sizes: [1, 9, 1]

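It appears this is because global_mean_pool from PyTorch Geometric is meant for 2D [num_nodes, channels] input: it pools over dimension -2, which for the 3D scalar field of shape [num_nodes, mul, 1] is the mul axis rather than the node axis. That is why the batch index is viewed as [1, 9, 1] and cannot be expanded to [9, 10, 1]. The solution could be to replace global_mean_pool(field, batch) with global_mean_pool(field.view(-1, mul), batch). Since scalar fields have a trailing dimension of 1, this pools a [num_nodes, mul] tensor over nodes, and the existing .reshape(-1, mul, 1) then restores the [batch, mul, 1] shape.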
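The shape problem and the suggested fix can be checked without PyTorch Geometric. The sketch below copies the logic of the `broadcast()` helper shown in the traceback and stands in for `global_mean_pool` with a small hypothetical helper, `mean_over_nodes` (plain `index_add_` plus `bincount`). It assumes `global_mean_pool` reduces over dim -2, which is what the `[1, 9, 1]` index shape in the error suggests. It is an illustration, not the library's code.

```python
import torch


def broadcast_like_pyg(index: torch.Tensor, ref: torch.Tensor, dim: int) -> torch.Tensor:
    # Same logic as the broadcast() helper in the traceback:
    # place the node index on axis `dim` and expand it to ref's shape.
    size = [1] * ref.dim()
    size[dim] = -1
    return index.view(size).expand_as(ref)


def mean_over_nodes(x: torch.Tensor, batch: torch.Tensor, num_graphs: int) -> torch.Tensor:
    # Hypothetical stand-in for global_mean_pool on a 2D [num_nodes, mul] input:
    # sum node rows into their graph slot, then divide by the node count per graph.
    sums = x.new_zeros((num_graphs, x.size(1))).index_add_(0, batch, x)
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1).to(x.dtype)
    return sums / counts.unsqueeze(1)


num_nodes, mul = 9, 10
field = torch.randn(num_nodes, mul, 1)            # scalar (l = 0) field: [nodes, mul, 1]
batch = torch.zeros(num_nodes, dtype=torch.long)  # a single graph

# With dim=-2 on a 3D field, the 9-entry index lands on the mul axis (size 10),
# so expanding it to [9, 10, 1] fails just like in the reported error.
failed_3d = False
try:
    broadcast_like_pyg(batch, field, dim=-2)
except RuntimeError as err:
    failed_3d = True
    print("3D input fails:", err)

# Suggested fix: flatten to [nodes, mul] first, pool over nodes,
# then restore the trailing axis as the existing reshape(-1, mul, 1) does.
field_2d = field.view(-1, mul)
broadcast_like_pyg(batch, field_2d, dim=-2)       # index becomes [9, 1], expands fine
field_mean = mean_over_nodes(field_2d, batch, num_graphs=1).reshape(-1, mul, 1)
centered = field - field_mean[batch]              # mean-subtracted field
print(centered.shape)                             # torch.Size([9, 10, 1])
```

Flattening with view(-1, mul) is harmless here because the scalar field's last axis has size 1, so nothing is merged across the multiplicity channels.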

Cheers, Pim

RobDHess commented 10 months ago

Hi,

Thanks for alerting us to this; I will fix it in the near future.

Cheers,

Rob