mariogeiger / nequip-jax


E3NN equivalent TP in Nequip Jax? #3

Open ipcamit opened 4 months ago

ipcamit commented 4 months ago

I am trying to recreate my nequip-torch model in nequip-jax. The number of parameters matches exactly when I create the model, but the numerical values do not, so I am assembling the layers myself one by one. One of the issues is that I could not figure out how to express the equivalent tensor product in nequip-jax. In torch I have the following tensor product:

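import e3nn.o3

tp = e3nn.o3.TensorProduct(
    "32x0e",                 # irreps_in1
    "1x0e+1x1o",             # irreps_in2: edge spherical harmonics, l = 0 and l = 1
    "32x0e+32x1o",           # irreps_out
    # instructions: (i_in1, i_in2, i_out, connection_mode, has_weight)
    [(0, 0, 0, 'uvu', True), (0, 1, 1, 'uvu', True)],
    shared_weights=False,    # weights are not shared across the batch (one set per edge)
    internal_weights=False,  # weights are passed at call time: tp(x, y, weight)
)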
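For reference, here is a minimal NumPy sketch of what the two 'uvu' paths above compute. It assumes e3nn's normalization constants can be dropped and that the flat weight vector is laid out path by path (32 weights for path 0, then 32 for path 1); neither is taken from the e3nn source. The function name uvu_tp_sketch is hypothetical.

```python
import numpy as np

def uvu_tp_sketch(x, y, weights):
    """Hypothetical re-implementation of the two 'uvu' paths (normalization omitted).

    x:       (N, 32)  irreps "32x0e"
    y:       (N, 4)   irreps "1x0e+1x1o" (1 scalar, then 3 vector components)
    weights: (N, 64)  one weight per channel u per path, per batch element
    Returns (N, 128): irreps "32x0e+32x1o" as 32 scalars followed by 32 vectors.
    """
    mul = x.shape[1]
    y0, y1 = y[:, :1], y[:, 1:]                   # split "1x0e" and "1x1o"
    w0, w1 = weights[:, :mul], weights[:, mul:]   # assumed layout: path 0, then path 1
    out0 = w0 * x * y0                            # path (0, 0, 0): 0e x 0e -> 0e, (N, 32)
    out1 = (w1 * x)[:, :, None] * y1[:, None, :]  # path (0, 1, 1): 0e x 1o -> 1o, (N, 32, 3)
    # Channel-major layout: vector u occupies components 3u, 3u + 1, 3u + 2.
    return np.concatenate([out0, out1.reshape(len(x), -1)], axis=1)
```

In e3nn itself the call would be tp(x, y, weight) with weight of shape (N, tp.weight_numel); each 'uvu' path here has 32 x 1 weights, so weight_numel is 64.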

Here I provide the weights manually, to weight each path (?). In nequip-jax, however, I found the following code, which seems to be the equivalent:

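# Excerpt from nequip-jax (e3nn is e3nn_jax); self.max_ell, vectors and
# output_irreps are defined by the surrounding layer.
messages = e3nn.concatenate(
    [
        # the messages themselves, restricted to the irreps in output_irreps + "0e"
        messages.filter(output_irreps + "0e"),
        # tensor product of the messages with the l = 1..max_ell spherical harmonics
        e3nn.tensor_product(
            messages,
            e3nn.spherical_harmonics(
                [l for l in range(1, self.max_ell + 1)],
                vectors,
                normalize=True,
                normalization="component",
            ),
            filter_ir_out=output_irreps + "0e",
        ),
    ]
).regroup()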

I don't understand:

  1. How do I give weights to e3nn_jax.tensor_product?
  2. Should it be [l for l in range(0, self.max_ell + 1)] instead? In the original NequIP, the edge attributes go from l = 0 to l_max.

Thanks

mariogeiger commented 4 months ago
  1. When using e3nn-jax, instead of giving weights to the tensor product, you simply follow the tensor product with a linear layer (e3nn_jax.flax.Linear). This is equivalent.
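  2. Yes, you can. The code you pasted is doing a slight optimization: it concatenates the messages with the tensor product of the messages and the spherical harmonics for l > 0. This is equivalent to the tensor product of the messages with the spherical harmonics for l >= 0, because the l = 0 harmonic is a constant, so its tensor product with the messages just reproduces the messages (up to a constant factor).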
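A minimal NumPy sketch of point 1, for the case where the weights are shared rather than per edge: a 'uvu' path that scales each channel u by a weight w[u] gives the same result as a weightless tensor product followed by a linear map over the channel index whose matrix is diag(w). A general linear layer learns a full channel-mixing matrix per irrep, so it contains the 'uvu' weighting as a special case. The arrays are random placeholders; this illustrates the math, it is not nequip-jax code.

```python
import numpy as np

rng = np.random.default_rng(0)
N, mul = 4, 32
x = rng.normal(size=(N, mul))   # node features, "32x0e"
y1 = rng.normal(size=(N, 3))    # l = 1 edge harmonics, "1x1o"
w = rng.normal(size=mul)        # one shared weight per channel u

# Weighted 'uvu' path 0e x 1o -> 1o: out[n, u, m] = w[u] * x[n, u] * y1[n, m]
weighted_tp = w[None, :, None] * x[:, :, None] * y1[:, None, :]

# Weightless tensor product, then a linear map acting only on the channel index u.
plain_tp = x[:, :, None] * y1[:, None, :]   # (N, mul, 3)
W = np.diag(w)                               # diagonal special case of a linear layer
tp_then_linear = np.einsum("num,uv->nvm", plain_tp, W)

print(np.allclose(weighted_tp, tp_then_linear))  # True
```

Because the linear map never touches the component index m, it commutes with rotations, which is why mixing channels after the tensor product keeps the layer equivariant.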
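A NumPy sketch of point 2 for scalar messages, assuming component-normalized spherical harmonics, where Y_0 = 1 and the l = 1 harmonic is sqrt(3) times the unit vector (the component order is not matched to e3nn here). Since Y_0 is the constant 1, the 0e x 0e -> 0e path only reproduces the messages, so "messages + product with l > 0" equals "product with l >= 0". The helpers sh_l0_l1 and tp_scalar are hypothetical, and e3nn's coupling coefficients are replaced by plain products.

```python
import numpy as np

def sh_l0_l1(vectors):
    """Component-normalized spherical harmonics for l = 0 and l = 1 (hypothetical helper)."""
    unit = vectors / np.linalg.norm(vectors, axis=-1, keepdims=True)
    y0 = np.ones(vectors.shape[:-1] + (1,))  # Y_0 = 1
    y1 = np.sqrt(3.0) * unit                 # sum of squares = 2l + 1 = 3
    return y0, y1

def tp_scalar(x, y):
    """Product of scalar messages (N, mul) with one irrep (N, 2l+1) -> (N, mul, 2l+1)."""
    return x[:, :, None] * y[:, None, :]

rng = np.random.default_rng(0)
messages = rng.normal(size=(6, 8))  # "8x0e" messages
vectors = rng.normal(size=(6, 3))   # edge vectors
y0, y1 = sh_l0_l1(vectors)

# Full version: tensor product with the harmonics for l >= 0.
full_l0, full_l1 = tp_scalar(messages, y0), tp_scalar(messages, y1)

# Optimized version: the messages themselves plus the product with l > 0.
opt_l0, opt_l1 = messages[:, :, None], tp_scalar(messages, y1)

print(np.allclose(full_l0, opt_l0), np.allclose(full_l1, opt_l1))  # True True
```

The optimization simply skips computing a product with a constant, which is why the range in the nequip-jax code can start at l = 1.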