fabiannagel / schnax

An implementation of SchNet in JAX and JAX-MD.


schnax: SchNet in JAX and JAX-MD

This is a re-implementation of the SchNet neural network architecture in JAX, Haiku, and JAX-MD. schnax is intended as a drop-in replacement for the original PyTorch implementation and can use trained weights obtained with SchNetPack.
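To illustrate what is being re-implemented, here is a minimal sketch of SchNet's central building block, the continuous-filter convolution, written directly in jax.numpy. This is not schnax's code or API. The function names (shifted_softplus, gaussian_rbf, cfconv), the illustrative defaults (25 Gaussians, cutoff 5.0, gamma 10.0), the bias-free filter network and the hard cutoff are all assumptions made for brevity; SchNetPack's actual layers differ in detail (for example biases and a smooth cosine cutoff).

```python
import jax
import jax.numpy as jnp


def shifted_softplus(x):
    # ssp(x) = ln(0.5 * e^x + 0.5) = softplus(x) - ln(2), so ssp(0) = 0.
    return jax.nn.softplus(x) - jnp.log(2.0)


def gaussian_rbf(distances, n_rbf=25, cutoff=5.0, gamma=10.0):
    # Expand each distance onto n_rbf Gaussians with centers evenly spaced in [0, cutoff].
    centers = jnp.linspace(0.0, cutoff, n_rbf)
    return jnp.exp(-gamma * (distances[..., None] - centers) ** 2)


def cfconv(features, positions, w1, w2, cutoff=5.0):
    """Continuous-filter convolution over all atom pairs (dense, no neighbor list).

    features:  (n_atoms, n_features) atom-wise representations
    positions: (n_atoms, 3) Cartesian coordinates
    w1, w2:    weights of a hypothetical bias-free two-layer filter-generating network
    """
    n_atoms = positions.shape[0]
    # Pairwise distances; the small epsilon keeps sqrt differentiable on the diagonal.
    diff = positions[:, None, :] - positions[None, :, :]
    dist = jnp.sqrt(jnp.sum(diff ** 2, axis=-1) + 1e-12)        # (n, n)
    # Filters depend only on distances, so the layer is rotation/translation invariant.
    rbf = gaussian_rbf(dist, n_rbf=w1.shape[0], cutoff=cutoff)  # (n, n, n_rbf)
    filters = shifted_softplus(rbf @ w1) @ w2                   # (n, n, n_features)
    # Drop self-interactions and pairs beyond the (hard) cutoff.
    mask = (1.0 - jnp.eye(n_atoms)) * (dist < cutoff)           # (n, n)
    # x_i' = sum_j x_j * W(r_ij), an element-wise product per feature.
    return jnp.sum(features[None, :, :] * filters * mask[..., None], axis=1)


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
n_atoms, n_features, n_rbf = 4, 8, 25
positions = jax.random.normal(k1, (n_atoms, 3))
features = jax.random.normal(k2, (n_atoms, n_features))
w1 = 0.1 * jax.random.normal(k3, (n_rbf, n_features))
w2 = 0.1 * jax.random.normal(k4, (n_features, n_features))

out = cfconv(features, positions, w1, w2)
print(out.shape)  # (4, 8)
```

The sketch computes all O(n²) atom pairs for clarity. For larger systems, an implementation built on JAX-MD would more likely restrict the sum to neighbors within the cutoff using a neighbor list.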

References