QUVA-Lab / e2cnn

E(2)-Equivariant CNNs Library for Pytorch
https://quva-lab.github.io/e2cnn/

Feature request: e2cnn equivalent to torch.nn.ModuleList #32

Closed: drewm1980 closed this issue 3 years ago

drewm1980 commented 3 years ago

As a minor convenience, it would be nice to have an equivariant equivalent to torch.nn.ModuleList.

The main use cases:

It would not subclass EquivariantModule, since it can't meaningfully conform to most of that API, but it would probably be used mostly inside implementations of EquivariantModule subclasses.

I'm currently trying out the following pattern in my code:

# Inside a bigger class that itself implements .export(), where self.up_path is also a ModuleList...
def export(self):
    self.eval()
    up_path_exported = torch.nn.ModuleList()
    for module in self.up_path:
        up_path_exported.append(module.export())

If the EquivariantModules were already living in an e2cnn.nn.ModuleList, it would just be:

def export(self):
    self.eval()
    up_path_exported = self.up_path.export()

This would only save a couple of lines of boilerplate, but the discoverability is probably worth more: "just build everything using e2cnn 1:1 equivalents, train, then .export()".

Gabri95 commented 3 years ago

Hi @drewm1980,

I see your point here. I usually define a separate export method in the class that defines the full model, in which I loop over self.named_modules() to find all the equivariant modules to export. An EquivariantModuleList would mostly automate this.

I think it is a nice idea, thanks for suggesting it! I will add it later today or tomorrow.

Thanks again! Gabriele
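A rough sketch of the model-level export Gabriele describes (walking the module tree and swapping each equivariant submodule for its exported version) might look like the following. It is not his code, only one way it could be written: export_submodules, ToyEquivariant and Model are hypothetical names, and EquivariantModule is a minimal stand-in for e2cnn.nn.EquivariantModule so the snippet runs with only PyTorch installed. It recurses through named_children() instead of iterating named_modules() directly, because replacing modules while named_modules() is still traversing the tree is fragile.

```python
import copy

import torch


class EquivariantModule(torch.nn.Module):
    # Hypothetical stand-in for e2cnn.nn.EquivariantModule so this sketch
    # runs without e2cnn; only the export() hook matters here.
    def export(self) -> torch.nn.Module:
        raise NotImplementedError


def export_submodules(module: torch.nn.Module) -> torch.nn.Module:
    """Replace, in place, every EquivariantModule found under `module`
    with the plain PyTorch module returned by its export()."""
    module.eval()
    # named_children() is copied into a list because setattr() below
    # mutates the container while we walk it.
    for name, child in list(module.named_children()):
        if isinstance(child, EquivariantModule):
            # Export the whole equivariant subtree; don't recurse into it.
            setattr(module, name, child.export())
        else:
            # Plain containers (ModuleList, Sequential, ...) may hold
            # equivariant modules further down.
            export_submodules(child)
    return module


# Usage with a toy equivariant module (hypothetical, for illustration only).
class ToyEquivariant(EquivariantModule):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)

    def export(self):
        # A real e2cnn module would build an equivalent torch.nn layer here.
        return copy.deepcopy(self.linear)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.up_path = torch.nn.ModuleList([ToyEquivariant(), ToyEquivariant()])


model = export_submodules(Model())
print([type(m).__name__ for m in model.up_path])  # ['Linear', 'Linear']
```

Not recursing into an EquivariantModule is deliberate: its export() is assumed to take care of its own children.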

Gabri95 commented 3 years ago

I implemented it as a subclass of torch.nn.ModuleList to reuse its functionality. The only difference is that it only accepts EquivariantModules.

Let me know if this works well for you.

P.S.: I have not pushed the new code to PyPI yet, so you should install it using pip install git+https://github.com/QUVA-Lab/e2cnn

Best, Gabriele
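For reference, a container along the lines Gabriele describes (a torch.nn.ModuleList subclass that rejects anything other than an EquivariantModule, plus the export() from the original request) might look like the following. This is a guess, not the code that was pushed to e2cnn: EquivariantModuleList, ToyEquivariant and _check are hypothetical names, and EquivariantModule is a minimal stand-in for e2cnn.nn.EquivariantModule so the snippet runs with only PyTorch installed.

```python
import copy
from typing import Optional

import torch


class EquivariantModule(torch.nn.Module):
    # Hypothetical stand-in for e2cnn.nn.EquivariantModule.
    def export(self) -> torch.nn.Module:
        raise NotImplementedError


def _check(module) -> None:
    if not isinstance(module, EquivariantModule):
        raise TypeError(f"expected an EquivariantModule, got {type(module).__name__}")


class EquivariantModuleList(torch.nn.ModuleList):
    """A torch.nn.ModuleList that only accepts EquivariantModules."""

    # In recent PyTorch versions the constructor, append(), extend() and +=
    # all add elements through add_module().
    def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
        _check(module)
        super().add_module(name, module)

    # Indexed assignment and insert() do not go through add_module().
    def __setitem__(self, idx: int, module: torch.nn.Module) -> None:
        _check(module)
        super().__setitem__(idx, module)

    def insert(self, index: int, module: torch.nn.Module) -> None:
        _check(module)
        super().insert(index, module)

    def export(self) -> torch.nn.ModuleList:
        """Export every element into a plain torch.nn.ModuleList."""
        self.eval()
        return torch.nn.ModuleList([m.export() for m in self])


# Usage with a toy equivariant module (hypothetical, for illustration only).
class ToyEquivariant(EquivariantModule):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)

    def export(self):
        return copy.deepcopy(self.linear)


up_path = EquivariantModuleList([ToyEquivariant(), ToyEquivariant()])
up_path_exported = up_path.export()  # plain torch.nn.ModuleList of Linear layers

try:
    up_path.append(torch.nn.ReLU())  # rejected: not an EquivariantModule
except TypeError as err:
    print(err)  # expected an EquivariantModule, got ReLU
```

With this in place, the export() from the original request reduces to the single call up_path.export().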