mfinzi / equivariant-MLP

A library for programmatically generating equivariant layers through constraint solving

2V>>Scalar representation is all zero? #14

Closed topoliu closed 2 years ago

topoliu commented 2 years ago

import jax
import jax.numpy as jnp
from emlp.reps import V,Scalar
from emlp.groups import SO
import numpy as np

W = V(SO(3))
rep = 2*W
P = (rep>>Scalar).equivariant_projector()
applyP = lambda v: P@v

P.to_dense()

DeviceArray([[0., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0.]], dtype=float32)

I am running a toy example: two vectors as input, a scalar as output. If I rotate the vectors, the output scalar should not change.

But the P matrix I get is all zero!

mfinzi commented 2 years ago

Hi @topoliu, This is as expected, since there are no rotation-equivariant linear maps directly from a vector (or a pair of them) to a scalar. You can think of a linear map from vectors to scalars as itself a vector, and there are no rotation-invariant vectors in 3D (except the zero vector).
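
You can verify this directly with the basis rather than the projector (a quick check, reusing the imports from your snippet):

(V(SO(3))>>Scalar).equivariant_basis().shape

which should report 0 equivariant basis elements (a shape like (3, 0)), consistent with the all-zero projector you found for 2*W>>Scalar.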

However, don't despair: there are many nonlinear equivariant maps, and EMLP will find them by mapping to higher-order feature tensors as well as by using equivariant nonlinearities.

To see how this works, suppose the first feature representation of EMLP is 30*Scalar+10*W+3*W**2+W**3. The equivariant linear layer from the input 2*W to this feature representation has 32 independent degrees of freedom:

(2*W>>30*Scalar+10*W+3*W**2+W**3).equivariant_basis().shape

Out: (684, 32)
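
The shape is just a dimension count (my arithmetic, not part of the original output): 2*W has dimension 6 and 30*Scalar+10*W+3*W**2+W**3 has dimension 30+30+27+27 = 114, so the space of all linear maps between them is 6*114 = 684-dimensional, of which a 32-dimensional subspace is equivariant. You can confirm the sizes with the reps' size() method:

repin = 2*W
repout = 30*Scalar + 10*W + 3*W**2 + W**3
repin.size(), repout.size()   # should give (6, 114)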

And calling vis to visualize this basis, the layer looks like this:

[image: visualization of the equivariant basis for 2*W >> 30*Scalar+10*W+3*W**2+W**3]
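
If you want to reproduce that picture yourself, there is a vis helper in emlp.reps (usage as I recall it from the emlp docs, so double-check the signature):

from emlp.reps import vis
vis(2*W, 30*Scalar + 10*W + 3*W**2 + W**3)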

To show that with these higher-order representations and nonlinearities EMLP can in fact fit your function, let's take as the example target f: 2V->Scalar, f(u,v) = u^T v - ||v||^3. EMLP fits this function without trouble.

import emlp
from emlp.reps import V,T,Scalar
from emlp.groups import SO
import objax

W = V(SO(3))
repin = 2*W        # input representation: two 3D vectors
repout = Scalar    # output representation: one invariant scalar

model = emlp.nn.EMLP(repin,repout,SO(3))

import numpy as np
import jax.numpy as jnp
from jax import device_put

u,v = np.random.randn(2,100,3)                               # 100 random pairs of 3D vectors
X = np.concatenate([u,v],-1)                                 # inputs of shape (100, 6)
Y = ((u*v).sum(-1) - jnp.linalg.norm(v,axis=1)**3)[:,None]   # targets f(u,v) = u^T v - ||v||^3
Y /= Y.std()
Y = device_put(Y)
X = device_put(X)
opt = objax.optimizer.Adam(model.vars())

@objax.Jit
@objax.Function.with_vars(model.vars()+opt.vars())
def loss(x, y):
    yhat = model(x)
    return ((yhat-y)**2).mean()

grad_and_val = objax.GradValues(loss, model.vars())

@objax.Jit
@objax.Function.with_vars(model.vars()+opt.vars())
def train_op(x, y, lr):
    g, v = grad_and_val(x, y)
    opt(lr=lr, grads=g)
    return v

import matplotlib.pyplot as plt
losses = [train_op(X,Y,3e-3) for _ in range(200)]
plt.plot(losses)
plt.yscale('log')
plt.ylabel("MSE")

[image: training loss curve, MSE decreasing over the 200 steps on a log scale]
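
As a sanity check on the invariance itself (my addition, not part of the original run), you can rotate both input vectors by the same rotation and confirm that the model's output is unchanged up to floating-point error, since EMLP is equivariant by construction:

theta = 0.7
R = np.array([[np.cos(theta), -np.sin(theta), 0.],
              [np.sin(theta),  np.cos(theta), 0.],
              [0., 0., 1.]])                          # rotation about the z-axis
Xrot = np.concatenate([u @ R.T, v @ R.T], -1)         # apply the same rotation to u and v
jnp.abs(model(device_put(Xrot)) - model(X)).max()     # should be ~0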

topoliu commented 2 years ago

Hello Finzi:

Thanks for your quick and detailed response. I am a little confused about the scenario: I have n 3-dimensional vectors as features and a scalar as the target. At the very least, the angle and the distance between two vectors should not change under SO(3).

In a more realistic scenario, molecular property prediction, the atom coordinates are 3-dimensional vectors and the property is a scalar, yet (nV>>Scalar) is always zero.

Thanks

mfinzi commented 2 years ago

Hi @topoliu, As you mention, the angle and the distance between two vectors are invariant under SO(3). EMLP can fit these functions just fine; to see for yourself, just rerun the above example with the targets Y = jnp.linalg.norm(u-v) or Y = (u*v).sum(-1)/np.sqrt((u*u).sum(-1)*(v*v).sum(-1)).
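
Spelled out with the same shapes and normalization as the original Y (a sketch; note the per-sample axis and the trailing [:,None], which the one-liners above leave implicit):

Y = jnp.linalg.norm(u - v, axis=-1)[:, None]                            # distance between u and v
# or, for the angle:
Y = ((u*v).sum(-1) / np.sqrt((u*u).sum(-1)*(v*v).sum(-1)))[:, None]     # cosine of the angle
Y /= Y.std()
Y = device_put(Y)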

Again, (nV>>Scalar) is zero, but EMLP(nV,Scalar) is not, since it heavily leverages the bilinear layer to compute inner products (which for O(3) and scalar outputs is essentially all you need, see Scalars are universal: Equivariant machine learning, structured like classical physics). You can form these maps with EMLP, although if you also want permutation equivariance between the vectors then E(n)-GNN is the practical way to go (even if it can be replicated with some effort via EMLP).
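
To set up the case you describe with EMLP (a sketch with a hypothetical n = 5; the channel widths are just the EMLP defaults):

n = 5
repin = n*V(SO(3))                            # n 3D vectors, flattened into a 3n-dim input
(repin>>Scalar).equivariant_basis().shape     # 0 equivariant linear maps, as you observed
model = emlp.nn.EMLP(repin, Scalar, SO(3))    # the nonlinear equivariant model is not trivial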

I hope this clears up the confusion. The key point is that there is a difference between the equivariant linear maps (found with (nV>>Scalar).equivariant_basis()) and the nonlinear equivariant multilayer perceptron you get with EMLP(nV,Scalar,SO(3)), even though EMLP makes use of the linear maps within its layers.