I am trying to port my nequip-torch model to nequip-jax. While the number of parameters matches exactly when I create the model, the numerical values do not. So I am trying to assemble the layers myself, one by one. One of the issues is that I could not figure out how to express the equivalent tensor product in nequip-jax.
In torch I have the following tensor product:
When using e3nn-jax, instead of giving weights to the tensor product (tp), you simply follow the tp with an `e3nn_jax.flax.Linear` layer. This is equivalent.
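Here is a minimal NumPy sketch of why the two formulations match. It uses only scalar (0e) channels so the Clebsch-Gordan coefficients are trivial. The sizes and variable names are hypothetical, and it ignores the normalization factors that e3nn applies. A weighted tensor product is bilinear in its inputs, so its weights can equally be applied afterwards by a linear layer acting on the unweighted products:

```python
import numpy as np

rng = np.random.default_rng(0)
n_u, n_v, n_w = 4, 3, 5  # hypothetical multiplicities: input 1, input 2, output

x = rng.normal(size=n_u)               # first input: n_u scalar (0e) channels
y = rng.normal(size=n_v)               # second input: n_v scalar (0e) channels
W = rng.normal(size=(n_u, n_v, n_w))   # one weight per (u, v, w) triple

# Weighted ("fully connected") tensor product: weights are applied inside the tp.
weighted_tp = np.einsum("u,v,uvw->w", x, y, W)

# e3nn-jax style: the unweighted tp keeps every product x_u * y_v as its own channel...
unweighted_tp = np.outer(x, y).reshape(-1)  # shape (n_u * n_v,)
# ...and a linear layer then mixes those channels into n_w outputs.
linear_weights = W.reshape(n_u * n_v, n_w)
tp_then_linear = unweighted_tp @ linear_weights

print(np.allclose(weighted_tp, tp_then_linear))  # True
```

For higher-l irreps, each path also contracts with Clebsch-Gordan coefficients. That contraction is linear too, so the same argument applies path by path. `e3nn_jax.flax.Linear` only mixes channels that carry the same irrep, which is exactly the freedom the weighted tp has.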
Yes, you can. Note, though, that the code you pasted does a slight optimization. It concatenates the messages with the tp of the messages and the spherical harmonics with l > 0. This is equivalent to taking the tp of the messages with the spherical harmonics with l >= 0.
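Below is a toy NumPy check of this equivalence. The tensor-product paths are written by hand with illustrative normalization constants, not e3nn's exact conventions. The messages are assumed to hold only scalar (0e) and vector (1o) channels. The key fact is that the l = 0 spherical harmonic is a constant (1 under component normalization), so the paths that involve it just reproduce the messages:

```python
import numpy as np

rng = np.random.default_rng(0)
n_edges, mul = 6, 4

# Toy per-edge messages: `mul` scalar (0e) channels and `mul` vector (1o) channels.
s = rng.normal(size=(n_edges, mul))
v = rng.normal(size=(n_edges, mul, 3))

# Edge directions and their spherical harmonics (component normalization).
r = rng.normal(size=(n_edges, 3))
r_hat = r / np.linalg.norm(r, axis=-1, keepdims=True)
Y0 = np.ones((n_edges, 1))      # l = 0: a constant, independent of direction
Y1 = np.sqrt(3.0) * r_hat       # l = 1: proportional to the unit vector

def tp_with_Y0(s, v, Y0):
    # Paths 0e x 0e -> 0e and 1o x 0e -> 1o: multiply by the constant Y0.
    return [s * Y0, v * Y0[..., None]]

def tp_with_Y1(s, v, Y1):
    # Paths 0e x 1o -> 1o, 1o x 1o -> 0e, 1o x 1o -> 1e (illustrative normalizations).
    return [
        s[..., None] * Y1[:, None, :],
        np.einsum("eui,ei->eu", v, Y1) / np.sqrt(3.0),
        np.cross(v, Y1[:, None, :]) / np.sqrt(2.0),
    ]

full = tp_with_Y0(s, v, Y0) + tp_with_Y1(s, v, Y1)  # tp with sh for l >= 0
shortcut = [s, v] + tp_with_Y1(s, v, Y1)            # messages concatenated with tp(sh, l > 0)

print(all(np.allclose(a, b) for a, b in zip(full, shortcut)))  # True
```

Dropping l = 0 from the spherical harmonics therefore loses no information, as long as the raw messages are concatenated back in.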
In the torch version I provide weights manually to weight each path (?). But in nequip-jax I found the following code, which is supposed to be equivalent:
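I don't understand:

1. how the `e3nn_jax.tensor_product` call corresponds to the weighted torch tensor product;
2. `[l for l in range(0, self.max_ell + 1)]`: is this because, as in the original NequIP, the edge attributes go from l = 0 up to `max_ell`?

Thanks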
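On the per-path weights of the torch version, here is a minimal NumPy sketch. It assumes the torch layer is set up the way NequIP's interaction block is, as far as I know. That means "uvu" instructions with edge-dependent weights supplied from outside, for example by a radial MLP. The edge attributes have multiplicity 1, so such weights amount to one scalar per (edge, path, channel). They can therefore be applied as an elementwise rescaling after an unweighted tensor product. The shapes and the two toy "paths" are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)
n_edges, mul, n_paths = 5, 4, 2

x = rng.normal(size=(n_edges, mul))             # toy per-edge messages: `mul` scalar channels
sh = rng.normal(size=(n_edges, n_paths))        # toy edge attributes, multiplicity 1, one per path
w = rng.normal(size=(n_edges, n_paths, mul))    # assumed per-edge weights, e.g. from a radial MLP

# "uvu"-style weighted tp: each (edge, path, channel) gets its own weight inside the product.
weighted = np.einsum("eu,ep,epu->epu", x, sh, w)

# Same result: unweighted tp first, then an elementwise per-edge rescaling.
unweighted = np.einsum("eu,ep->epu", x, sh)
print(np.allclose(weighted, unweighted * w))  # True
```

Because these weights depend on the edge, they enter as a per-edge multiplication of the tp output. They are not a `Linear` layer shared across edges.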