atomicarchitects / equiformer_v2

[ICLR 2024] EquiformerV2: Improved Equivariant Transformer for Scaling to Higher-Degree Representations
https://arxiv.org/abs/2306.12059
MIT License

e3nn tensors compatibility issue #9

Open liyy2 opened 8 months ago

liyy2 commented 8 months ago

Hi, I am trying to integrate this with the e3nn package.

For the SO3Embedding class, how can I convert it to irreps that are compatible with the e3nn convention? Here is my implementation (I am not sure whether it is right):

    def to_e3nn_embeddings(self):
        from e3nn.io import SphericalTensor
        from e3nn.o3 import Irreps
        # flatten everything after the first dimension into one vector per entry
        embedding = self.embedding.reshape(self.length, -1)

        # multiple channels: swap the multiplicity '1x' for f'{num_channels}x'
        l = Irreps(str(SphericalTensor(self.lmax_list[-1], 1, -1)).replace('1x', f'{self.num_channels}x'))
        return l, embedding

yilunliao commented 8 months ago

Hi @liyy2

I am not familiar with SphericalTensor.

But tensors in e3nn are typically in the form C_0x0e+C_1x1e+... (e.g., 128x0e+128x1e+...). (Let me know if this is not clear.)

For EquiformerV2, the tensors are in the form of (0e+1e+..., C) and have shape ((1+L_{max})**2, C). We require the number of channels for each degree to be the same here. (Let me know if that is not clear)
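
To make the two layouts concrete, here is a minimal sketch of where each degree l sits in both of them. The helper name `degree_slices` is hypothetical (it is not part of e3nn or EquiformerV2). It assumes the same number of channels for every degree, as stated above, and ignores any leading batch dimension.

```python
def degree_slices(lmax, num_channels):
    """For each degree l, return (l, e3nn_start, e3nn_stop, row_start, row_stop).

    All stop indices are exclusive.
    e3nn flat layout: degree l occupies num_channels * (2l + 1) consecutive entries.
    EquiformerV2 layout: degree l occupies rows l**2 up to (l + 1)**2 of the
    ((lmax + 1)**2, num_channels) array.
    """
    slices = []
    e3nn_start = 0
    for l in range(lmax + 1):
        e3nn_stop = e3nn_start + num_channels * (2 * l + 1)
        slices.append((l, e3nn_start, e3nn_stop, l ** 2, (l + 1) ** 2))
        e3nn_start = e3nn_stop
    return slices


for entry in degree_slices(lmax=2, num_channels=128):
    print(entry)
```

Both layouts hold the same number of values: the last e3nn stop index equals num_channels * (lmax + 1) ** 2, and the last row stop equals (lmax + 1) ** 2.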

So to convert between these two formats, we can extract all the channels for each degree, flatten them and concatenate all the flattened tensors. Here is an example of converting e3nn tensors to tensors in EquiformerV2:

    import torch
    from e3nn import o3

    lmax = 2
    num_channels = 128
    irreps = o3.Irreps('128x0e+128x1e+128x2e')
    tensor_e3nn = irreps.randn(1, -1)  # shape: (1, 128 * (1 + 2) ** 2)

    out = []
    start_idx = 0
    for l in range(lmax + 1):
        length = (2 * l + 1) * num_channels
        feature = tensor_e3nn.narrow(1, start_idx, length)  # extract all the channels corresponding to degree l
        feature = feature.view(-1, num_channels, (2 * l + 1))  # shape: (1, num_channels, 2l + 1)
        feature = feature.transpose(1, 2).contiguous()  # shape: (1, 2l + 1, num_channels)
        out.append(feature)
        start_idx = start_idx + length
    tensor_equiformer_v2 = torch.cat(out, dim=1)  # shape: (1, (1 + lmax) ** 2, num_channels)

You can follow the above example to do the reverse.
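
For readers who want the reverse direction spelled out, below is a minimal sketch. It is not code from the EquiformerV2 repository. The function names `e3nn_to_equiformer_v2` and `equiformer_v2_to_e3nn` are hypothetical, and the sketch only rearranges the memory layout. It assumes that the ordering of the 2l + 1 components within each degree is the same on both sides, as in the example above. A random torch tensor stands in for `irreps.randn(...)` so that the block depends on torch alone.

```python
import torch


def e3nn_to_equiformer_v2(tensor_e3nn, lmax, num_channels):
    # Forward direction: the same steps as the example above, wrapped in a function.
    out = []
    start_idx = 0
    for l in range(lmax + 1):
        length = (2 * l + 1) * num_channels
        feature = tensor_e3nn.narrow(1, start_idx, length)
        feature = feature.view(-1, num_channels, 2 * l + 1)
        out.append(feature.transpose(1, 2).contiguous())
        start_idx += length
    return torch.cat(out, dim=1)


def equiformer_v2_to_e3nn(tensor_equiformer_v2, lmax):
    # Reverse direction. Input shape: (N, (lmax + 1) ** 2, num_channels).
    num_nodes, _, num_channels = tensor_equiformer_v2.shape
    out = []
    for l in range(lmax + 1):
        # Rows l**2 .. (l + 1)**2 - 1 hold degree l: shape (N, 2l + 1, C).
        feature = tensor_equiformer_v2[:, l ** 2:(l + 1) ** 2, :]
        # Put channels before components, then flatten: shape (N, C * (2l + 1)).
        feature = feature.transpose(1, 2)
        out.append(feature.reshape(num_nodes, num_channels * (2 * l + 1)))
    return torch.cat(out, dim=1)


lmax, num_channels = 2, 128
x = torch.randn(3, num_channels * (lmax + 1) ** 2)  # stand-in for an e3nn tensor
y = e3nn_to_equiformer_v2(x, lmax, num_channels)
x_back = equiformer_v2_to_e3nn(y, lmax)
print(y.shape, torch.equal(x, x_back))
```

A round trip like this is a quick sanity check that the two functions are inverses of each other. It does not verify equivariance, which would need a rotation test against the model itself.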

liyy2 commented 8 months ago

Hi, thank you for the detailed response. My question is: does parity impact the model here? Should I use o3.Irreps('128x0e+128x1e+128x2e') or o3.Irreps('128x0e+128x1o+128x2e')?

yilunliao commented 8 months ago

For EquiformerV2, we currently use SE(3) equivariance, which does not include inversion, and therefore we should use '128x0e+128x1e+128x2e'.
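
Following this answer, one way to build the all-even irreps string for an arbitrary lmax is sketched below. The helper name `se3_irreps_string` is hypothetical. As far as I understand, `SphericalTensor(lmax, 1, -1)` from the first post alternates parity (0e+1o+2e+...), so it would not produce the string recommended here. The result is a plain string that can be passed to `o3.Irreps(...)`.

```python
def se3_irreps_string(lmax, num_channels):
    """Build e.g. '128x0e+128x1e+128x2e', with every degree marked even ('e')."""
    return "+".join(f"{num_channels}x{l}e" for l in range(lmax + 1))


irreps_str = se3_irreps_string(lmax=2, num_channels=128)
print(irreps_str)
```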