Install directly from GitHub with:
```bash
pip install git+https://github.com/mariogeiger/nequip-jax
```
Then import the layer for the framework you use:

```python
from nequip_jax import NEQUIPLayerFlax  # Flax version
from nequip_jax import NEQUIPLayerHaiku  # Haiku version
```
Look at test.py for an example of how to stack the layers.
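Stacking follows the usual message-passing pattern: the graph connectivity stays fixed while node features flow from one layer into the next. The plain-Python sketch below shows only that pattern. `toy_layer` and `stack` are hypothetical stand-ins invented for illustration, not part of nequip_jax; the real layers' constructor arguments and call signature are the ones used in test.py.

```python
from typing import Callable, List, Sequence

Features = List[List[float]]  # one feature vector per node
Layer = Callable[[Features, Sequence[int], Sequence[int]], Features]


def toy_layer(scale: float) -> Layer:
    """Hypothetical stand-in for a message-passing layer (NOT the nequip_jax API).

    Each receiver adds `scale` times the features of its senders to its own
    features (a residual update).
    """

    def apply(node_feats: Features, senders: Sequence[int], receivers: Sequence[int]) -> Features:
        out = [list(f) for f in node_feats]  # start from a copy: residual connection
        for s, r in zip(senders, receivers):
            for k, v in enumerate(node_feats[s]):  # messages use this layer's input features
                out[r][k] += scale * v
        return out

    return apply


def stack(layers: Sequence[Layer], node_feats: Features,
          senders: Sequence[int], receivers: Sequence[int]) -> Features:
    """Apply layers in sequence; the graph is shared, only node features change."""
    for layer in layers:
        node_feats = layer(node_feats, senders, receivers)
    return node_feats


# A 3-node chain with directed edges 0 -> 1 and 1 -> 2.
senders, receivers = [0, 1], [1, 2]
feats = [[1.0], [0.0], [0.0]]
print(stack([toy_layer(1.0), toy_layer(1.0)], feats, senders, receivers))
# [[1.0], [2.0], [1.0]]
```

With two layers, node 2 picks up the signal that started at node 0, two hops away; each additional layer extends the receptive field by one hop.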
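Optimization for large L

The eSCN layers apply the method of https://arxiv.org/pdf/2302.03655.pdf, which reduces SO(3) convolutions to SO(2) convolutions, to lower the cost at large L (a high maximum spherical-harmonic degree). This implementation also adds support for parity.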
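The linked paper relies on rotating each edge onto the polar axis first. In that frame the edge's spherical harmonics vanish for every m except m = 0, which is what lets the SO(3) tensor product collapse into cheaper SO(2) operations. The pure-Python sketch below checks that property numerically using textbook real spherical harmonics (z as the polar axis, Condon-Shortley phase). It illustrates the idea only; it is not code from nequip_jax, and libraries such as e3nn may use a different polar axis and normalization.

```python
import math


def assoc_legendre(l: int, m: int, x: float) -> float:
    """Associated Legendre P_l^m(x) for 0 <= m <= l, via the standard upward recurrence."""
    pmm = 1.0
    if m > 0:
        somx2 = math.sqrt((1.0 - x) * (1.0 + x))  # sin(theta); exactly 0 on the pole
        fact = 1.0
        for _ in range(m):
            pmm *= -fact * somx2
            fact += 2.0
    if l == m:
        return pmm
    pmmp1 = x * (2 * m + 1) * pmm
    if l == m + 1:
        return pmmp1
    for ll in range(m + 2, l + 1):
        pll = (x * (2 * ll - 1) * pmmp1 - (ll + m - 1) * pmm) / (ll - m)
        pmm, pmmp1 = pmmp1, pll
    return pmmp1


def real_sph_harm(l: int, theta: float, phi: float) -> list:
    """Real spherical harmonics Y_l^m for m = -l..l (index l holds m = 0)."""
    x = math.cos(theta)
    out = []
    for m in range(-l, l + 1):
        am = abs(m)
        k = math.sqrt((2 * l + 1) / (4 * math.pi)
                      * math.factorial(l - am) / math.factorial(l + am))
        p = assoc_legendre(l, am, x)
        if m == 0:
            out.append(k * p)
        elif m > 0:
            out.append(math.sqrt(2) * k * p * math.cos(m * phi))
        else:
            out.append(math.sqrt(2) * k * p * math.sin(am * phi))
    return out


# Edge aligned with +z (theta = 0): only the m = 0 component survives, for every l.
for l in range(5):
    ys = real_sph_harm(l, theta=0.0, phi=0.7)
    print(l, [m for m, v in zip(range(-l, l + 1), ys) if abs(v) > 1e-12])
# each line lists only m = 0, e.g. "3 [0]"
```

For a generic edge direction all 2l + 1 components are typically nonzero, so the tensor product has to mix every m; after the rotation only m = 0 enters, which is the source of the savings at large L.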
```python
from nequip_jax import NEQUIPESCNLayerFlax  # Flax version
from nequip_jax import NEQUIPESCNLayerHaiku  # Haiku version
```
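Parity support means each irrep carries an even/odd label, as in the `0e`, `1o`, `2e` notation used by e3nn, and spherical harmonics of degree l have parity (-1)^l. An equivariant tensor product can then only produce degrees satisfying the triangle rule |l1 - l2| <= l <= l1 + l2, with parity equal to the product of the input parities. The sketch below enumerates those allowed outputs in plain Python; the helper names are hypothetical, and this is the general O(3) selection rule rather than code taken from nequip_jax.

```python
from typing import List, Tuple

Irrep = Tuple[int, int]  # (degree l, parity p) with p = +1 (even) or -1 (odd)


def tensor_product_outputs(a: Irrep, b: Irrep) -> List[Irrep]:
    """All irreps allowed in a ⊗ b under O(3): triangle rule on l, parities multiply."""
    (l1, p1), (l2, p2) = a, b
    return [(l, p1 * p2) for l in range(abs(l1 - l2), l1 + l2 + 1)]


def sh_irrep(l: int) -> Irrep:
    """Spherical harmonics of degree l have parity (-1)**l."""
    return (l, (-1) ** l)


def fmt(ir: Irrep) -> str:
    """Format as e3nn-style text, e.g. (1, -1) -> '1o'."""
    l, p = ir
    return f"{l}{'e' if p == 1 else 'o'}"


# A vector feature (1o) combined with the degree-1 edge harmonic (also 1o):
print([fmt(ir) for ir in tensor_product_outputs((1, -1), sh_irrep(1))])
# ['0e', '1e', '2e']
```

Without parity labels, 1 ⊗ 1 would just give degrees 0, 1 and 2; with them, the cross-product output is identified as the pseudovector 1e rather than an ordinary vector 1o.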