QUVA-Lab / e2cnn

E(2)-Equivariant CNNs Library for Pytorch
https://quva-lab.github.io/e2cnn/

Module export #55

Closed danushv07 closed 2 years ago

danushv07 commented 2 years ago

I'm working on creating an equivariant network for a classification task. Based on the example available here, it seems that only EquivariantModules can be exported to PyTorch. As shown below, an assertion error occurs because torch.nn.Linear is not an EquivariantModule.

s = e2cnn.gspaces.Rot2dOnR2(8)
in_ = e2cnn.nn.FieldType(s, [s.trivial_repr])
out_ = e2cnn.nn.FieldType(s, [s.regular_repr]*16)

net = SequentialModule(
    R2Conv(in_, out_, 3, bias=False),
    ReLU(out_, inplace=True),
    PointwiseMaxPool(out_, kernel_size=2, stride=2),
    GroupPooling(out_),
    torch.nn.Linear(in_channels, out_channels)
)

Is there another way in which, the entire net can be exported to pytorch after training?

Gabri95 commented 2 years ago

Hi @danushv07

Unfortunately, SequentialModule can only contain EquivariantModules, you can not add a torch.nn.Linear module in it.

If I understand correctly, you want to append a final linear classifier after the pooling layers while still being able to export the whole model. You can achieve something similar this way:


net = SequentialModule(
    R2Conv(in_, out_, 3, bias=False),
    ReLU(out_, inplace=True),
    PointwiseMaxPool(out_, kernel_size=2, stride=2),
    GroupPooling(out_),
)

# this is the out_type of the last GroupPooling
final_feature_type = net.out_type    

# `out_channels` invariant outputs
output_type = e2cnn.nn.FieldType(s, [s.trivial_repr]*out_channels)

# add the final linear layer as a 1x1 convolution
net.add_module('classifier', R2Conv(final_feature_type, output_type, kernel_size=1))

The final R2Conv will be a 1x1 convolution which behaves just like your torch.nn.Linear, assuming the output of PointwiseMaxPool is a 1x1 feature map. You should now be able to export() your model. Note, however, that the output tensor will have shape B x out_channels x 1 x 1 rather than B x out_channels, so you may need to do a manual reshaping.

Is my understanding correct? Does this help?

Best, Gabriele

danushv07 commented 2 years ago

Thank you for the prompt reply @Gabri95. The aforementioned solution works well. However, if torch.nn.Linear is required, the entire SequentialModule along with the linear layer can be wrapped in a torch.nn.Module, and then methods such as .modules() or .children() can be used to export the required layers.

Gabri95 commented 2 years ago

Hi @danushv07

I am not sure I understood what you mean exactly. Could you share a simple code snippet to illustrate your example?

Thanks, Gabriele