mfinzi / equivariant-MLP

A library for programmatically generating equivariant layers through constraint solving
MIT License

Point Clouds Example #10

whitead closed this issue 3 years ago

whitead commented 3 years ago

I'm trying to work with molecules and would like to combine a permutation group with SO(3). This is discussed a little in the docs, but I'm having trouble implementing it correctly. I would like 5 input points (atoms), each with one feature and one set of 3D coordinates, and I would like to output a single coordinate vector. An example would be computing the dipole moment of a molecule with 5 atoms. I tried writing it like this:

import numpy as np
import emlp
from emlp.reps import V
from emlp.groups import SO, S
# make product group
G = SO(3) * S(5)
#     element + coordinates
Vin = V(S(5)) + V(G)
Vout = V(SO(3))
print(Vin.size(), Vout.size())
# make model
model = emlp.nn.EMLP(Vin, Vout, group=G)
input_point = np.random.randn(Vin.size())
model(input_point).shape

but the output of the model has shape (15,) instead of the (3,) I expected. Thank you!
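As a concrete target, the dipole example is a function with exactly the symmetry being asked for. The plain-numpy sketch below (it does not use emlp) assumes that a group element g = (R, P) acts on the 15 coordinate entries as np.kron(R, P), i.e. the coordinates are stored as a 3x5 array (axis by atom) flattened row by row; on the 5 per-atom features as P; and on the 3-dimensional output as R. Reading the per-atom feature as a partial charge is an illustrative assumption.

```python
import numpy as np

rng = np.random.default_rng(0)

def random_rotation(rng):
    """Random 3x3 rotation matrix (orthogonal, det +1) from a QR decomposition."""
    Q, _ = np.linalg.qr(rng.standard_normal((3, 3)))
    if np.linalg.det(Q) < 0:
        Q[:, 0] *= -1  # flip one axis so the matrix lies in SO(3)
    return Q

def dipole(charges, coords_flat):
    """Dipole moment sum_i q_i r_i for 5 atoms.

    coords_flat holds a 3x5 array X (row a = Cartesian axis a, column i = atom i),
    flattened row by row, so a pair (R, P) acts on it as np.kron(R, P).
    """
    X = coords_flat.reshape(3, 5)
    return X @ charges

q = rng.standard_normal(5)          # one scalar feature per atom (e.g. a partial charge)
x = rng.standard_normal(15)         # 5 atoms x 3 coordinates, axis-major
R = random_rotation(rng)
P = np.eye(5)[rng.permutation(5)]   # permutation matrix relabelling the atoms

# Equivariance: transforming the inputs by (R, P) rotates the output by R
lhs = dipole(P @ q, np.kron(R, P) @ x)
rhs = R @ dipole(q, x)
print(lhs.shape, np.allclose(lhs, rhs))  # (3,) True
```

The permutation cancels because np.kron(R, P) maps X to R X P^T, and P^T P is the identity, which is why the output only feels the rotation.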

mfinzi commented 3 years ago

Hi Andrew,

This is a great question. There is a way to do it in the existing library, but you have to write your own representations of V(S(5)) and V(SO(3)) for the group SO(3) * S(5), which is a bit cumbersome. I'll post an example of this approach tomorrow.

However, as you've noticed from your syntax and from the examples in the Combining Representations from Different Groups tutorial, there is a more streamlined (and computationally efficient) way of getting the equivariant linear maps for product groups like this. It is not yet supported by the emlp.nn.EMLP network (it's on the to-do list), but it shouldn't be too hard to implement and I plan to add it soon (within 1-2 weeks).

whitead commented 3 years ago

Hi @mfinzi,

Thanks for the answer! Looking forward to your example.

mfinzi commented 3 years ago

Here we go. With some effort, one can implement restricted representations and other structures on product groups by defining a new representation class, without using the specialized functionality described in the Combining Representations tutorial.
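The reason a restricted representation can simply pick out one factor is that the projections G1 x G2 -> G1 and G1 x G2 -> G2 are group homomorphisms, so composing them with the defining representation of each factor is again a representation of G. A plain-numpy sketch of that fact, assuming a group element is the pair (R, P) and that V(G) acts as np.kron(R, P); rho_full, rho_S5 and rho_SO3 are hypothetical stand-ins, not emlp functions.

```python
import numpy as np

rng = np.random.default_rng(1)

def random_rotation(rng):
    """Random 3x3 rotation matrix (orthogonal, det +1) from a QR decomposition."""
    Q, _ = np.linalg.qr(rng.standard_normal((3, 3)))
    if np.linalg.det(Q) < 0:
        Q[:, 0] *= -1  # flip one axis so the matrix lies in SO(3)
    return Q

def random_permutation(n, rng):
    """Random n x n permutation matrix."""
    return np.eye(n)[rng.permutation(n)]

# Hypothetical stand-ins for three representations of G = SO(3) x S(5),
# each taking a group element given as the pair (R, P)
def rho_full(R, P):  # like V(G): tensor product, acts on 3*5 = 15 coordinates
    return np.kron(R, P)

def rho_S5(R, P):    # like VS5: only the permutation factor acts
    return P

def rho_SO3(R, P):   # like VSO3: only the rotation factor acts
    return R

g1 = (random_rotation(rng), random_permutation(5, rng))
g2 = (random_rotation(rng), random_permutation(5, rng))
g12 = (g1[0] @ g2[0], g1[1] @ g2[1])  # the product group multiplies componentwise

for rho in (rho_full, rho_S5, rho_SO3):
    # homomorphism property: rho(g1 g2) == rho(g1) rho(g2)
    assert np.allclose(rho(*g12), rho(*g1) @ rho(*g2))
print("all three maps are representations")
```

For rho_full the check is the mixed-product property of the Kronecker product, kron(R1 R2, P1 P2) = kron(R1, P1) kron(R2, P2).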

import emlp
from emlp.reps import V,T,Rep
from emlp.groups import Z,S,SO,Group
import numpy as np

class ProductSubRep(Rep):
    def __init__(self,G,subgroup_id,size):
        """   Produces the representation of the subgroup of G = G1 x G2
              with the index subgroup_id in {0,1} specifying G1 or G2.
              Also requires specifying the size of the representation given by G1.d or G2.d """
        self.G = G
        self.index = subgroup_id
        self._size = size
    def __str__(self):
        return "V_"+str(self.G).split('x')[self.index]
    def size(self):
        return self._size
    def rho(self,M): 
        # Given that M is a LazyKron object, we can just get the argument
        return M.Ms[self.index]
    def drho(self,A):
        return A.Ms[self.index]
    def __call__(self,G):
        # adding this will probably not be necessary in a future release,
        # necessary now because rep is __call__ed in nn.EMLP constructor
        assert self.G==G
        return self
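
# SO(3) rotates the 3D coordinates; S(5) permutes the 5 atoms
G1,G2 = SO(3),S(5)
G = G1 * G2

# representations of G in which only one factor acts
VSO3 = ProductSubRep(G,0,G1.d)  # e.g. the output vector: rotated, ignores permutations
VS5 = ProductSubRep(G,1,G2.d)   # per-atom features: permuted, ignore rotations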

Vin = VS5 + V(G)
Vout = VSO3
print(f"Vin: {Vin} of size {Vin.size()}")
print(f"Vout: {Vout} of size {Vout.size()}")
# make model
model = emlp.nn.EMLP(Vin, Vout, group=G)
input_point = np.random.randn(Vin.size())*10
print(f"Output shape: {model(input_point).shape}")

Vin: V_S(5)+V of size 20
Vout: V_SO(3) of size 3
Output shape: (3,)
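ProductSubRep.rho can return M.Ms[self.index] because the group sample is kept as a list of Kronecker factors rather than as a dense matrix. The sketch below uses a hypothetical kron_matvec helper (not emlp's LazyKron code) to show the standard identity that makes such lazy Kronecker products cheap to apply: for n atoms the dense matrix has 9n^2 entries, while the two factors have only 9 + n^2.

```python
import numpy as np

def kron_matvec(A, B, x):
    """Apply np.kron(A, B) to x without building the Kronecker product.

    For a row-major flattening, (A kron B) vec(X) = vec(A X B^T),
    where X = x reshaped to (A.shape[1], B.shape[1]).
    """
    X = x.reshape(A.shape[1], B.shape[1])
    return (A @ X @ B.T).reshape(-1)

rng = np.random.default_rng(2)
R = rng.standard_normal((3, 3))     # stands in for an SO(3) factor (any 3x3 works here)
P = np.eye(5)[rng.permutation(5)]   # an S(5) factor as a permutation matrix
x = rng.standard_normal(15)

dense = np.kron(R, P) @ x           # builds a 15x15 matrix
lazy = kron_matvec(R, P, x)         # only touches the 3x3 and 5x5 factors
print(np.allclose(dense, lazy))     # True
```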

And we can double-check that the model is equivariant:

# Test the equivariance
def rel_err(a,b):
    return np.sqrt(((a-b)**2).sum())/(np.sqrt((a**2).sum())+np.sqrt((b**2).sum()))

from emlp.reps.linear_operators import LazyKron
lazy_G_sample = LazyKron([G1.sample(),G2.sample()])

out1 = model(Vin.rho(lazy_G_sample)@input_point)
out2 = Vout.rho(lazy_G_sample)@model(input_point)
print(f"Equivariance Error: {rel_err(out1,out2)}")

Equivariance Error: 5.094278208161995e-07
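To read a number like this, it helps to see what rel_err returns for a map that is not equivariant. A plain-numpy sketch, using the hand-written dipole map as a hypothetical stand-in for the trained model and a random linear map W as a counterexample, under the same assumed layout (coordinates as a 3x5 array acted on by np.kron(R, P)). By the triangle inequality rel_err is at most 1, so an arbitrary map lands near order one while an equivariant one sits at roundoff level. The 5e-7 above is presumably single-precision roundoff, since JAX, which EMLP is built on, defaults to float32.

```python
import numpy as np

def rel_err(a, b):
    # the relative error used in the thread
    return np.sqrt(((a - b) ** 2).sum()) / (np.sqrt((a ** 2).sum()) + np.sqrt((b ** 2).sum()))

rng = np.random.default_rng(3)
R, _ = np.linalg.qr(rng.standard_normal((3, 3)))
if np.linalg.det(R) < 0:
    R[:, 0] *= -1                     # make R a proper rotation
P = np.eye(5)[rng.permutation(5)]     # permutation of the 5 atoms
q, x = rng.standard_normal(5), rng.standard_normal(15)
W = rng.standard_normal((3, 20))      # arbitrary linear map from the 20 inputs to R^3

def dipole(q, x):
    # hand-written equivariant map (stand-in for a trained EMLP)
    return x.reshape(3, 5) @ q

def generic(q, x):
    # unconstrained linear map: no reason to be equivariant
    return W @ np.concatenate([q, x])

def equivariance_error(f):
    out1 = f(P @ q, np.kron(R, P) @ x)  # model(rho_in(g) @ input)
    out2 = R @ f(q, x)                  # rho_out(g) @ model(input)
    return rel_err(out1, out2)

print(equivariance_error(dipole))   # float64 roundoff, far below 1e-7
print(equivariance_error(generic))  # much larger; rel_err is at most 1
```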

whitead commented 3 years ago

Wow, that is quite satisfying! That solves my problem completely. Thank you!