mfinzi / equivariant-MLP

A library for programmatically generating equivariant layers through constraint solving
MIT License
253 stars 21 forks

Regression in Example #16

Open whitead opened 2 years ago

whitead commented 2 years ago

I have an example in my book using emlp with the following syntax:

import emlp.nn
import numpy as np
from emlp.reps import V
from emlp.groups import SO, S

# make product group
G = SO(3) * S(5)
# direct sum of element + coordinates
Vin = V(S(5)) + V(G)
Vout = V(G)
print(Vin.size(), Vout.size())
# make model
model = emlp.nn.EMLP(Vin, Vout, group=G)
input_point = np.random.randn(Vin.size())
model(input_point)

This previously worked, and we discussed modifying its output a bit in #10. In version 1.0.3, this code no longer runs; it gives the following error:

TypeError: Sequential layer[0] <emlp.nn.objax.EMLPBlock object at 0x7ff44436e5e0> dot_general requires contracting dimensions to have the same shape, got [20] and [15].

I was wondering whether I need to update the syntax or whether this is a bug. Thanks!

mfinzi commented 2 years ago

Sorry just saw this!

Right now I'm not supporting multiple separate groups for nn.EMLP, except as restricted representations or other representations of the product group, as in #10. Previously, the above code may not have errored out, and may even have produced a properly equivariant network, though not necessarily a very sensible one; that behavior depended on internal workings that have since changed slightly.
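For reference, a minimal sketch of what the supported pathway looks like, where both the input and output representations are built from the product group G itself (the multiplicity 5 here is purely illustrative, not from this thread):

import emlp.nn
import numpy as np
from emlp.reps import V
from emlp.groups import SO, S

# Build both reps from the single product group G, which nn.EMLP supports.
G = SO(3) * S(5)
Vin = 5 * V(G)   # several copies of the product-group representation (illustrative choice)
Vout = V(G)
model = emlp.nn.EMLP(Vin, Vout, group=G)
x = np.random.randn(Vin.size())
model(x)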

I get where you're coming from, though. The documentation isn't very clear about which parts support using multiple groups (the equivariant bases and linear layers) and which do not, or only do so accidentally (the bilinear layer and the automatic internal representation choice in nn.EMLP). I also mentioned last time that extending this support was in the pipeline, but sadly I haven't yet gotten to it.
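To illustrate the distinction, here is a rough sketch (not code from this thread) of querying the equivariant linear-map basis directly for the mixed representation, which is the part described above as handling multiple groups, in contrast to the full nn.EMLP model:

from emlp.reps import V
from emlp.groups import SO, S

G = SO(3) * S(5)
Vin = V(S(5)) + V(G)
Vout = V(G)
# Space of equivariant linear maps Vin -> Vout; the basis machinery is the piece
# that works with representations of multiple groups.
basis = (Vin >> Vout).equivariant_basis()
print(basis.shape)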

For the book, if you want to show this kind of example, I would suggest using the ProductSubRep from #10. The syntax is not as clean, but it is at least an officially supported pathway for using nn.EMLP. What I can do to make it nicer, and what would probably be useful to others as well, is to start adding some of these kinds of derived representations to the repo itself (with some docs), so that you can just import them.

Cheers

mfinzi commented 2 years ago

I'll also start logging an error in nn.EMLP if it detects that multiple different groups are being used (rather than multiple derived representations of the same product group), so there's no possibility of unexpected errors or silent failures.
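A rough sketch (not the library's actual implementation) of the kind of check described: compare the group attached to the input and output representations against the group passed to nn.EMLP, and raise early if they disagree. The .G attribute on composite reps is an assumption here.

def check_single_group(rep_in, rep_out, group):
    # emlp Rep objects carry the group they were defined with as `.G`;
    # how composite (sum) reps expose this is assumed, not confirmed.
    seen = {getattr(rep, "G", None) for rep in (rep_in, rep_out)} | {group}
    seen.discard(None)
    if len(seen) > 1:
        raise ValueError(f"nn.EMLP expects representations of a single group, got {seen}")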

whitead commented 2 years ago

Got it, thanks. I'll try to figure out how to incorporate that code into the chapter.