mfinzi / equivariant-MLP

A library for programmatically generating equivariant layers through constraint solving
MIT License

Parity for O(3) scalars #18

Closed: wtianyi closed this issue 2 years ago.

wtianyi commented 2 years ago

Hi! First, I'd like to say it's a nice library!

Coming from e3nn, I noticed that the scalar rep of O(3) in EMLP always has even parity, i.e. its sign doesn't flip under reflection, whereas e3nn distinguishes order-0 (scalar) representations with even and odd parity.

How can I get a scalar representation that distinguishes the two connected components of O(3)? Or is that not a valid question?

mfinzi commented 2 years ago

Hi @wtianyi,

The Scalar object in EMLP is a proper scalar: it is invariant under all group transformations. We also provide a PseudoScalar, which is also one-dimensional but transforms by the sign of the determinant of the transformation. It therefore flips sign under a reflection, so if I understand you correctly, it is the odd-parity order-0 representation you mention.
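To make the even/odd distinction concrete, here is a small standalone JAX sketch. It does not use EMLP, and the helper names `dot_invariant` and `triple_product` are hypothetical, chosen only for illustration. The dot product of two vectors is a proper scalar, while the scalar triple product a · (b × c) picks up a factor of det(g) under any g in O(3), so it behaves like a PseudoScalar: unchanged by a rotation, sign-flipped by a reflection.

```python
import jax.numpy as jnp

def dot_invariant(a, b):
    # Proper (even) scalar: (g a) . (g b) = a . b for every orthogonal g
    return jnp.dot(a, b)

def triple_product(a, b, c):
    # Pseudoscalar (odd): (g a) . ((g b) x (g c)) = det(g) * a . (b x c)
    return jnp.dot(a, jnp.cross(b, c))

a = jnp.array([1.0, 0.0, 0.0])
b = jnp.array([0.0, 1.0, 0.0])
c = jnp.array([0.0, 0.0, 1.0])

# Rotation by 90 degrees about the z axis (det = +1)
rotation = jnp.array([[0.0, -1.0, 0.0],
                      [1.0,  0.0, 0.0],
                      [0.0,  0.0, 1.0]])
# Parity / inversion -I (det = -1 in 3D)
reflection = -jnp.eye(3)

results = {}
for name, g in [("rotation", rotation), ("reflection", reflection)]:
    s = dot_invariant(g @ a, g @ a)
    p = triple_product(g @ a, g @ b, g @ c)
    results[name] = (float(s), float(p))
    print(name, results[name])
# rotation (1.0, 1.0)
# reflection (1.0, -1.0)
```

The scalar stays at 1.0 for both transformations, while the triple product flips to -1.0 only under the reflection, which is exactly the behaviour of the `P` representation.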

Here is the code for the PseudoScalar:

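```python
import jax.numpy as jnp
from emlp.reps import Rep
from emlp.groups import O

class PseudoScalar(Rep):
    """One-dimensional representation rho(g) = sign(det g)."""
    def __init__(self,G=None):
        self.G=G
    def __str__(self):
        return "P"
    def rho(self,M):
        # M@eye(n) turns M into a dense array (e.g. if M is a lazy matrix);
        # the first output of slogdet is the sign of det(M), i.e. +1 or -1
        sign = jnp.linalg.slogdet(M@jnp.eye(M.shape[0]))[0]
        return sign*jnp.eye(1)
    def __call__(self,G):
        # PseudoScalar()(G) binds the representation to the group G
        return PseudoScalar(G)

P = PseudoScalar(O(3))

I = jnp.eye(3)
reflection = -I  # parity transformation, det(-I) = -1 in 3D

print(P.rho(I))
print(P.rho(reflection))
```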

which outputs

```
[[1.]]
[[-1.]]
```
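As a further check, here is a standalone sketch that reimplements the same sign-of-determinant rule as `PseudoScalar.rho` in plain JAX (no EMLP import; `pseudoscalar_rho` and `random_orthogonal` are hypothetical helper names) and verifies that it behaves as a representation of O(3): elements of the identity component (rotations) map to +1, reflections map to -1, and rho(AB) = rho(A) rho(B). Random orthogonal matrices are drawn via a QR decomposition of a Gaussian matrix, which is an assumption of this sketch rather than anything EMLP does.

```python
import jax
import jax.numpy as jnp

def pseudoscalar_rho(M):
    # Same rule as PseudoScalar.rho: the 1x1 matrix [[sign(det M)]].
    # M is already a dense array here, so no M @ eye(n) step is needed.
    sign = jnp.linalg.slogdet(M)[0]
    return sign * jnp.eye(1)

def random_orthogonal(key, n=3):
    # The Q factor of a Gaussian matrix is orthogonal with det(Q) = +1 or -1
    q, _ = jnp.linalg.qr(jax.random.normal(key, (n, n)))
    return q

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
A = random_orthogonal(k1)
B = random_orthogonal(k2)

# In odd dimension det(-M) = -det(M), so flipping the overall sign moves an
# orthogonal matrix between the two connected components of O(3).
rotation = A * jnp.sign(jnp.linalg.det(A))        # det = +1
reflection = -(B * jnp.sign(jnp.linalg.det(B)))   # det = -1

print(pseudoscalar_rho(rotation))    # [[1.]]
print(pseudoscalar_rho(reflection))  # [[-1.]]

# Homomorphism property: rho(A B) == rho(A) @ rho(B)
lhs = pseudoscalar_rho(rotation @ reflection)
rhs = pseudoscalar_rho(rotation) @ pseudoscalar_rho(reflection)
print(bool(jnp.all(lhs == rhs)))     # True
```

Because rho only depends on which connected component an element lies in, it is constant (+1) on all of SO(3) and only detects reflections.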

Right now this only appears as an example on the docs page for constructing new representations, but I will add a new page to the docs listing some of the commonly used (non-tensor) representations that are already implemented.

I hope this answers your question. Cheers, Marc

wtianyi commented 2 years ago

Thanks for the detailed explanation! Also, sorry that I missed the example in the documentation. I'm looking forward to the new page of commonly used non-tensor representations. It might also help to link to this example from the O(3) group example.