QUVA-Lab / escnn

Equivariant Steerable CNNs Library for Pytorch https://quva-lab.github.io/escnn/

Inferring the flip in an O(2)-equivariant model #19

Closed colobas closed 1 year ago

colobas commented 1 year ago

Hi,

First of all, congrats on this great piece of work.

I'm working on a model whose inputs are shapes given by their spherical harmonic coefficients. The problem I'm interested in requires O(2)-equivariance in the XY-plane.

I'm essentially interested in learning O(2)-invariant embeddings, plus a function that recognizes the element of O(2) that transformed the input, in similar fashion to what is done here

I'm able to construct the O(2)-invariant part easily, but I'm having issues with the group-recognition function (which should be O(2)-equivariant).

My idea was to output two steerable vectors acted on by O2.irrep(1, 1), restricting the irreps to the subgroup (0., -1). I would use one vector to infer the flip and the other to infer the rotation. However, I don't quite understand how the flip axis works. I thought that restricting the irreps to (0., -1) would fix the flip axis to the X-axis, so I expected that flipping my input (along any axis) would translate into a corresponding flip (along the X-axis) of my output steerable vectors. This is not what I observe in practice.

Rotations without flips work perfectly (i.e. the output vectors are rotated by the appropriate amount), and the invariant part works for both flips and rotations. But when I flip the input, the output is not flipped along the X-axis, and I can't tell what the actual flip axis is or how to fix it.

What am I missing? Thanks in advance!

colobas commented 1 year ago

I should perhaps mention that I was thinking of inferring the flip by taking the cross product of a reference vector (say (1, 0)) and one of the output vectors, and using the sign of that cross product as the sign of the flip.
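A minimal NumPy sketch of the cross-product test described above. The reference vector (1, 0) comes from the comment; the function names are hypothetical, and the predicted vector is assumed to be available as a plain 2D array.

```python
import numpy as np

def cross2d(a, b):
    """z-component of the 3D cross product of two vectors lying in the XY plane."""
    return a[0] * b[1] - a[1] * b[0]

def flip_sign_from_reference(v, ref=(1.0, 0.0)):
    """Sign of ref x v: +1 if v is counter-clockwise from ref, -1 if clockwise, 0 if parallel."""
    return float(np.sign(cross2d(ref, v)))

v = np.array([0.4, 0.9])  # stand-in for one predicted O2.irrep(1, 1) vector
print(flip_sign_from_reference(v))  # 1.0
```

Note that with ref = (1, 0) this quantity reduces to the y-component of v, so a pure rotation of v can change its sign as well.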

Gabri95 commented 1 year ago

Hey!

Let me understand the problem you are considering a bit better: you are trying to build some kind of MLP taking as input O(3)- (or SO(3)-?) equivariant vectors (the coefficients of spherical harmonics), and you want to output an element of O(2) in an equivariant way, right?

If that is the case, you need to be careful about how you define O(2), since there are different subgroups of O(3) that are isomorphic to O(2). If we only want to consider 2D rotations in the XY plane, we have two "classes" of subgroups isomorphic to O(2):

- dihedral symmetry: the reflection is constructed as a 180-degree rotation about any axis in the XY plane. The effect is flipping the sign of the Z axis and applying a reflection in the XY plane. This is the (False, True, -1) subgroup of O(3) (see https://quva-lab.github.io/escnn/api/escnn.group.html#o-3) or the (True, -1) subgroup (or any (theta, -1) if you want to change the reflection axis) of SO(3) (see https://quva-lab.github.io/escnn/api/escnn.group.html#so-3).
- cone symmetry: this is a subgroup only of O(3), and a reflection in the XY plane corresponds to a mirroring in XYZ that doesn't affect the Z coordinate. This is the subgroup ('cone', -1) of O(3) (see https://quva-lab.github.io/escnn/api/escnn.group.html#o-3).

I understand that, so far, you are experimenting with the dihedral subgroup of SO(3) but you cannot find the right axis to define reflections. Are you sure what you need is not actually a cone symmetry?

Regarding the second problem of predicting an element of O(2), I'd follow this strategy. You can output two O2.irrep(1, 1) vectors, which you can interpret as the two columns of a 2x2 matrix. The output is guaranteed to be a 2x2 matrix equivariant to O(2), but it's not guaranteed to be an orthogonal matrix and, therefore, an element of O(2). You can then orthogonalize this matrix by either

- projecting it to an orthogonal matrix via SVD, or
- normalizing vector 1, then projecting vector 2 onto the orthogonal complement of vector 1 (and normalizing it).

I think the second option is probably the easiest one for optimization (I'm not sure how well backprop works through SVD).

Let me know if this answers your question!

I'm also very curious about what you are applying this to! So let me know when you publish your work :)

Best, Gabriele

colobas commented 1 year ago

Oh, you're absolutely right, the cone subgroup is what I'm after.

And the idea of predicting an element of O(2) makes sense too.

Thank you!
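The difference between the two O(2) subgroups in Gabriele's reply can be checked directly on 3x3 matrices. A minimal NumPy sketch, taking X as the reflection axis (an illustrative choice): the dihedral flip is the 180-degree rotation about X, diag(1, -1, -1), and the cone flip is the mirror y -> -y, diag(1, -1, 1). Both act identically on the XY plane and both invert rotations about Z (so each generates a copy of O(2) together with those rotations), but only the dihedral flip also negates Z, and only the cone flip has determinant -1, which is why it exists in O(3) but not in SO(3).

```python
import numpy as np

def rot_z(theta):
    """Rotation by theta about the Z axis."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

# Dihedral-type flip: 180-degree rotation about the X axis (an element of SO(3))
dihedral_flip = np.diag([1.0, -1.0, -1.0])
# Cone-type flip: mirror y -> -y that leaves Z untouched (in O(3) but not SO(3))
cone_flip = np.diag([1.0, -1.0, 1.0])

theta = 0.7
for name, F in [("dihedral", dihedral_flip), ("cone", cone_flip)]:
    # same reflection when restricted to the XY plane
    same_xy = np.allclose(F[:2, :2], np.diag([1.0, -1.0]))
    # F R(theta) F = R(-theta): the flip inverts rotations about Z, as in O(2)
    inverts_rotation = np.allclose(F @ rot_z(theta) @ F, rot_z(-theta))
    print(name, same_xy, inverts_rotation, "z-scale:", F[2, 2], "det:", np.linalg.det(F))
```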

Gabri95 commented 1 year ago

Nice, happy this helped! Sorry for the late reply :(

Let me know if this works out! 😄

colobas commented 1 year ago

If I just want to extract the flip component from the matrix you're proposing, I can also just take the determinant of it, right?

To answer the curiosity from your first reply: I'm exploring this as a representation-learning framework for the shapes of nuclear and cellular membranes. I might end up not using flip-equivariance in my work, because cells are believed to be chiral. But I wanted to get a good understanding of the geometry and math underlying this awesome package, and for the sake of completeness I was trying to implement O(2)-equivariance.

Thanks again!

Gabri95 commented 1 year ago

Yep, the determinant is exactly what you need! (Actually, you can probably relate the determinant of this matrix to the cross product of its columns, though I didn't think about it too much.)
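For a 2x2 matrix the two quantities coincide: det [[a, c], [b, d]] = ad - bc, which is exactly the z-component of (a, b) x (c, d). A minimal NumPy sketch of the whole recipe from this thread, assuming the two predicted O2.irrep(1, 1) vectors arrive as plain arrays v1 and v2 (hypothetical names): orthogonalize either with Gram-Schmidt (normalize v1, project v2 onto its orthogonal complement, normalize) or with the SVD projection U Vt, then read the flip off the sign of the determinant. Both procedures use only norms, dot products, or an SVD, so they commute with the O(2) action on the inputs, and both keep the sign of det([v1 v2]) whenever the vectors are linearly independent. In a PyTorch model you would use the torch.linalg equivalents instead.

```python
import numpy as np

def gram_schmidt_2x2(v1, v2, eps=1e-8):
    """Normalize v1, project v2 onto the orthogonal complement of v1, normalize; stack as columns."""
    e1 = v1 / (np.linalg.norm(v1) + eps)
    u2 = v2 - np.dot(v2, e1) * e1  # remove the component of v2 along e1
    e2 = u2 / (np.linalg.norm(u2) + eps)
    return np.stack([e1, e2], axis=1)

def svd_orthogonalize(M):
    """Closest orthogonal matrix to M in Frobenius norm: U @ Vt from the SVD."""
    U, _, Vt = np.linalg.svd(M)
    return U @ Vt

def cross2d(a, b):
    """z-component of the cross product of two 2D vectors."""
    return a[0] * b[1] - a[1] * b[0]

v1, v2 = np.array([2.0, 0.5]), np.array([0.3, 1.5])  # stand-ins for the two predicted vectors
Q = gram_schmidt_2x2(v1, v2)
# +1: pure rotation, -1: rotation composed with a reflection
flip = np.sign(np.linalg.det(Q))
print(flip, np.isclose(np.linalg.det(Q), cross2d(Q[:, 0], Q[:, 1])))
```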

Oh, that's very interesting! Looking forward to seeing the full work :)

Best, Gabriele

colobas commented 1 year ago

Hey, sorry to bother you again. I would email, but this might be useful to someone else, so I'll ask here (and forgive me if this is a naive/silly question): if the output of some module is an array of trivial-irrep fields of SO(2), are these expected to also be implicitly invariant to O(2)? I assumed that was not the case, but it turns out that when I flip my input I get the same output...

colobas commented 1 year ago

This is how I'm defining my input type:

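            from escnn import group, gspaces

            # (inside the model's __init__)
            SO3 = group.so3_group(max_spharm_band)
            self.G = group.so2_group(max_hidden_band)
            self.gspace = gspaces.no_base_space(self.G)
            # (False, -1) selects the SO(2) subgroup of rotations about the Z axis
            _sg_id = SO3._process_subgroup_id((False, -1))
            _hidden_irrep_ids = [(k,) for k in range(1, max_hidden_band + 1)]
            # one field per spherical-harmonic band l, restricted from SO(3) to SO(2)
            self.in_type = self.gspace.type(
                *[SO3.irrep(ix).restrict(_sg_id)
                  for ix in range(max_spharm_band + 1)]
            )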

And this is the output type:

            # out_dim invariant (trivial) fields plus one frequency-1 vector field
            _out_type = [self.G.trivial_representation] * out_dim + [self.G.irrep(1)]
            self.out_type = self.gspace.type(*_out_type)

The trivial part of that output type doesn't change when I flip my input. (Flipping is done by multiplying the spherical harmonic coefficients corresponding to negative m values by -1, and I've verified this works by plotting the resulting shapes.)
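A minimal sketch of the flip described above, under two assumptions not stated in the thread: real spherical harmonics in a convention where m > 0 carries cos(m phi) and m < 0 carries sin(|m| phi), and a hypothetical flat layout that stores band l at indices l^2, ..., l^2 + 2l in the order m = -l, ..., l. Negating the sine-type (m < 0) coefficients then corresponds to phi -> -phi, i.e. mirroring y -> -y while leaving Z and every m = 0 coefficient untouched.

```python
import numpy as np

def sh_index(l, m):
    # hypothetical flat layout: band l occupies indices l**2 .. l**2 + 2l, ordered m = -l..l
    return l * l + (m + l)

def flip_coeffs(coeffs, max_band):
    """Mirror y -> -y (phi -> -phi), assuming m < 0 coefficients multiply sin(|m| phi)."""
    out = np.array(coeffs, dtype=float)  # copy, so the input is left untouched
    for l in range(1, max_band + 1):
        for m in range(-l, 0):
            out[sh_index(l, m)] *= -1.0
    return out

coeffs = np.arange(1.0, 10.0)  # toy coefficients for bands l = 0, 1, 2
print(flip_coeffs(coeffs, max_band=2))  # entries at indices 1, 4, 5 (m < 0) are negated
```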

colobas commented 1 year ago

I could also specify the input by declaring each pair of ±m coefficients as an irrep(m) of SO(2) and each m = 0 coefficient as a trivial representation of SO(2), but I guess that's what the .restrict calls are doing for me under the hood, right? Not sure it would make a difference...


Gabri95 commented 1 year ago

Hey, that's a great question!

Let me understand a bit better what you are doing though. What kind of operations/layers are you applying on top of the spherical harmonics to map to the output type?

A problem I can see is that if you learn only a linear map from the spherical harmonics to the output SO(2) trivial representation, then, by Schur's lemma, only the trivial (m = 0) component of each spherical harmonic band is used to generate the output.

However, the action of the cone O(2) group you consider also happens to be trivial on those same components. As a result, the output stays invariant to O(2) too. This problem disappears when you add some intermediate layers that include non-linearities. Such layers also leverage information from the other, non-invariant components of the spherical harmonics, and the non-linearity mixes information between components transforming according to different irreps.

In other words, what you observe should happen only when you have a single linear layer, and should disappear when you consider a deeper neural network.
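The Schur's-lemma argument can be checked numerically. The NumPy sketch below assumes a hypothetical coefficient layout (band l at indices l^2, ..., l^2 + 2l, ordered m = -l, ..., l) and a real-harmonic convention in which each (m, -m) pair is a cosine/sine pair, so a rotation about Z by alpha rotates that pair by m * alpha and the mirror y -> -y negates the m < 0 entries. It projects a random linear map onto the SO(2)-invariant ones by averaging it over evenly spaced rotations; this averaging is only an illustration, not how escnn parametrizes its equivariant layers. Only the m = 0 columns survive, and since the flip leaves those coefficients unchanged, the resulting map is flip-invariant too.

```python
import numpy as np

def sh_index(l, m):
    # hypothetical flat layout: band l occupies indices l**2 .. l**2 + 2l, ordered m = -l..l
    return l * l + (m + l)

def z_rotation(alpha, max_band):
    """Assumed action of a rotation about Z: each (m, -m) pair rotates by m * alpha."""
    n = (max_band + 1) ** 2
    R = np.eye(n)
    for l in range(1, max_band + 1):
        for m in range(1, l + 1):
            i, j = sh_index(l, m), sh_index(l, -m)  # cosine-type, sine-type coefficient
            c, s = np.cos(m * alpha), np.sin(m * alpha)
            R[i, i], R[i, j] = c, -s
            R[j, i], R[j, j] = s, c
    return R

def flip_matrix(max_band):
    """Assumed mirror y -> -y: negate the sine-type (m < 0) coefficients."""
    d = np.ones((max_band + 1) ** 2)
    for l in range(1, max_band + 1):
        for m in range(1, l + 1):
            d[sh_index(l, -m)] = -1.0
    return np.diag(d)

max_band, out_dim, n_angles = 3, 4, 32
rng = np.random.default_rng(0)
W = rng.normal(size=(out_dim, (max_band + 1) ** 2))

# Averaging W @ R(alpha) over evenly spaced angles projects W onto the SO(2)-invariant maps
angles = 2 * np.pi * np.arange(n_angles) / n_angles
W_inv = np.mean([W @ z_rotation(a, max_band) for a in angles], axis=0)

m0_cols = [sh_index(l, 0) for l in range(max_band + 1)]
x = rng.normal(size=(max_band + 1) ** 2)
print(np.abs(np.delete(W_inv, m0_cols, axis=1)).max())  # tiny: only m = 0 columns survive
print(np.allclose(W_inv @ flip_matrix(max_band) @ x, W_inv @ x))  # the map is flip-invariant too
```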

Can you confirm that?

Best, Gabriele

colobas commented 1 year ago

That makes sense. But this mixing only happens for non-linearities that affect more than just the norm of the vector fields, right? I've been using plain ReLUs for the scalar fields and norm-ReLUs for the vector fields. I guess that would explain why, even with non-linearities, I'm seeing invariance to O(2)?

I'll play around with this, maybe later today. Thanks again!

Gabri95 commented 1 year ago

Yep, that is correct indeed!

This mixing happens only when you use point-wise / Fourier non-linearities. Norm non-linearities act on each field independently and, therefore, do not allow any mixing.
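A small sketch of this distinction, writing a frequency-m field (a, b) as the complex number z = a + ib and assuming a convention in which the flip maps z to its conjugate and a rotation by alpha multiplies it by exp(i m alpha). The norm |z|, which is what a norm non-linearity looks at, is unchanged by both, so on its own it carries no flip information. A product coupling a frequency-1 and a frequency-2 field, Im(z1^2 conj(z2)), is still rotation-invariant but changes sign under the flip; it stands in here for the kind of cross-frequency mixing a point-wise or Fourier non-linearity can introduce (it is not what escnn computes internally).

```python
import numpy as np

def act(z, m, alpha, flip):
    """Assumed O(2) action on a frequency-m field z = a + ib: optional flip (z -> conj z), then rotation."""
    if flip:
        z = np.conj(z)
    return np.exp(1j * m * alpha) * z

def norm_features(z1, z2):
    # what a norm non-linearity sees: the norm of each field, taken separately
    return np.array([abs(z1), abs(z2)])

def mixed_feature(z1, z2):
    # couples a frequency-1 and a frequency-2 field; the exp(2i alpha) factors cancel under rotation
    return np.imag(z1 ** 2 * np.conj(z2))

z1, z2, alpha = 0.3 + 1.2j, -0.7 + 0.9j, 0.9
for flip in (False, True):
    g1, g2 = act(z1, 1, alpha, flip), act(z2, 2, alpha, flip)
    print(flip, norm_features(g1, g2), mixed_feature(g1, g2))
```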

Best, Gabriele

Gabri95 commented 1 year ago

Hey @colobas , I will close this issue for the moment but feel free to reopen it if you have more questions :)

Best, Gabriele