I am trying to split a regular field into all of its field components. According to GeometricTensor.split this can be done via x.split(None), which however errors out:
import torch
import e2cnn
r2_act = e2cnn.gspaces.Rot2dOnR2(N=12)
feat_type_in = e2cnn.nn.FieldType(r2_act, [r2_act.trivial_repr])
feat_type_out = e2cnn.nn.FieldType(r2_act, [r2_act.regular_repr])
conv = e2cnn.nn.R2Conv(feat_type_in, feat_type_out, kernel_size=3, padding=1)
x = e2cnn.nn.GeometricTensor(torch.randn((1,1,32,32)), feat_type_in)
x = conv(x)
print(x.tensor.shape)
x.split(None)
>>> AssertionError: Error! "breaks" must be an increasing list of positive indexes
Thanks a lot for spotting this bug! The documentation and the behaviour were indeed not really clear.
It should work now.
I will also upload the updated docs soon.
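For anyone stuck on a version without the fix, here is a rough sketch of how the per-field channel ranges could be computed by hand. It is not the library's implementation: field_channel_ranges is a hypothetical helper, and the field sizes in the example are made up for illustration.

```python
from itertools import accumulate


def field_channel_ranges(field_sizes):
    """Return one (start, stop) channel range per field.

    field_sizes: the number of channels each field occupies, in order.
    This is a hypothetical helper, not part of e2cnn.
    """
    stops = list(accumulate(field_sizes))  # running end index of each field
    starts = [0] + stops[:-1]              # each field starts where the previous one ended
    return list(zip(starts, stops))


# Illustrative sizes only: two regular fields of a 12-element rotation group
# followed by one trivial field.
ranges = field_channel_ranges([12, 12, 1])
print(ranges)  # [(0, 12), (12, 24), (24, 25)]
```

Each range could then be used to slice the channel dimension of the underlying tensor, e.g. x.tensor[:, start:stop]. Note that this gives plain torch tensors rather than GeometricTensor objects. For the snippet in the report there is a single regular field of size 12, so the only range would be (0, 12).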
Hi,
First off, thanks a lot for this nice package!
I am trying to split a regular field into all of its field components. According to GeometricTensor.split this can be done via x.split(None), which however errors out with the AssertionError shown above.
Am I doing something wrong here?