QUVA-Lab / e2cnn

E(2)-Equivariant CNNs Library for PyTorch
https://quva-lab.github.io/e2cnn/

GeometricTensor.split(None) raises exception #29

Closed: maweigert closed this issue 3 years ago

maweigert commented 3 years ago

Hi,

First off, thanks a lot for this nice package!

I am trying to split a regular field into all of its field components. According to the documentation of GeometricTensor.split, this can be done via x.split(None), but that raises an error:

import torch
import e2cnn

r2_act = e2cnn.gspaces.Rot2dOnR2(N=12)

feat_type_in = e2cnn.nn.FieldType(r2_act, [r2_act.trivial_repr])
feat_type_out = e2cnn.nn.FieldType(r2_act, [r2_act.regular_repr])

conv = e2cnn.nn.R2Conv(feat_type_in, feat_type_out, kernel_size=3, padding=1)

x = e2cnn.nn.GeometricTensor(torch.randn((1,1,32,32)), feat_type_in)
x = conv(x)

print(x.tensor.shape)

x.split(None)

>>> AssertionError: Error! "breaks" must be an increasing list of positive indexes

Am I doing something wrong here?
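The sketch below illustrates the field-level semantics that `split(None)` is meant to have: one output piece per field. It uses plain Python and is not e2cnn's implementation. The helper `channel_slices` and its validation rules are hypothetical. It assumes that each field occupies a contiguous block of channels whose size equals the dimension of that field's representation, and that `None` means "break before every field except the first".

```python
from typing import List, Optional, Tuple


def channel_slices(field_sizes: List[int],
                   breaks: Optional[List[int]] = None) -> List[Tuple[int, int]]:
    """Hypothetical helper: map field-level break indexes to channel ranges.

    field_sizes[i] is the number of channels used by the i-th field.
    breaks=None is interpreted as "split at every field", i.e. one slice per field.
    """
    n_fields = len(field_sizes)
    if breaks is None:
        # One break before each field except the first: [1, 2, ..., n_fields - 1].
        # Starting at 0 instead would violate the "positive indexes" check below.
        breaks = list(range(1, n_fields))

    # Validation in the spirit of the error message in this issue (assumed rules).
    assert all(0 < b < n_fields for b in breaks), "breaks must be positive field indexes"
    assert all(a < b for a, b in zip(breaks, breaks[1:])), "breaks must be increasing"

    # Cumulative channel offset at the start of each field (plus the total at the end).
    offsets = [0]
    for size in field_sizes:
        offsets.append(offsets[-1] + size)

    # Consecutive field boundaries become (start_channel, stop_channel) pairs.
    bounds = [0] + breaks + [n_fields]
    return [(offsets[s], offsets[e]) for s, e in zip(bounds, bounds[1:])]
```

With this helper, a tensor holding a single 12-dimensional regular field (as in the example above) gives `channel_slices([12]) == [(0, 12)]`. Because splitting works at the level of fields, one regular field yields a single piece covering all 12 channels. Several fields of sizes `[1, 12, 2]` give one slice per field.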

Gabri95 commented 3 years ago

Hi @maweigert

Thanks a lot for spotting this bug! The documentation and the behaviour were indeed unclear. It should work now. I will also upload the updated docs soon.

Thanks again! Gabriele