patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Using an Equinox model outside Python - Deserialisation #664

Open issue, opened by stergiosba 8 months ago

stergiosba commented 8 months ago

Hello Patrick, and thank you again for the nice package.

I wanted to ask whether there is a way to deserialise an Equinox-trained model (saved in the eqx format, JSON + bytes) and use it for inference outside Python, for example when deploying it in a C++ project.

Thanks!

patrick-kidger commented 8 months ago

In this case there should be nothing special about Equinox: this works the same way as it does for any other JAX computation.
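To illustrate that point, here is a minimal sketch in plain JAX. It assumes a hypothetical two-layer MLP whose trained weights live in a dict (an Equinox module's array leaves play the same role); `init_params` and `predict` are illustrative names, not part of any library. `jax.jit(...).lower(...)` traces the function for fixed input shapes and dtypes and yields a StableHLO module, the kind of Python-independent representation that JAX's export tooling serialises.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for a trained model: a pytree of weights plus a pure
# function. An Equinox module's array leaves would play the same role.
def init_params(key):
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (4, 8)),
        "b1": jnp.zeros(8),
        "w2": jax.random.normal(k2, (8, 2)),
        "b2": jnp.zeros(2),
    }


def predict(params, x):
    # Two-layer MLP: x has shape (batch, 4), the output has shape (batch, 2).
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


params = init_params(jax.random.PRNGKey(0))
x = jnp.ones((3, 4))

# Trace and lower the computation for these concrete shapes/dtypes. The
# resulting module text describes the computation without any Python.
lowered = jax.jit(predict).lower(params, x)
hlo_text = lowered.as_text()
print(hlo_text[:300])
```

Lowering is specialised to the input shapes it was traced with, so a different batch size means lowering again.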

You have a couple of options here. The usual approach right now is to convert the model with jax2tf and then use TensorFlow's deployment tooling (for example, a SavedModel that can be loaded through TensorFlow's C++ API).

However, there is also the (not yet documented) jax.experimental.export, see here, which should offer a smoother experience. This discussion thread may also be helpful.
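A minimal sketch of the export route follows. It assumes a recent JAX release in which this API has been made public as `jax.export`; the `jax.experimental.export` module mentioned above predates that, and its names may differ. The model is again a hypothetical two-layer MLP, and the file name `model.jaxexport` is made up. Closing over the weights bakes them into the exported artifact as constants, so the serialised bytes are self-contained.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np
from jax import export  # assumes a JAX version where `jax.export` is public

# Hypothetical stand-in for a trained model: weights stored in a dict.
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
params = {
    "w1": jax.random.normal(key1, (4, 8)),
    "b1": jnp.zeros(8),
    "w2": jax.random.normal(key2, (8, 2)),
    "b2": jnp.zeros(2),
}


def predict(params, x):
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


def serving_fn(x):
    # Closing over `params` embeds the weights as constants in the export.
    return predict(params, x)


# Export for a fixed input signature: batch of 3, 4 features, float32.
x_spec = jax.ShapeDtypeStruct((3, 4), jnp.float32)
exported = export.export(jax.jit(serving_fn))(x_spec)

# Serialise to bytes and write them to a (temporary) file...
blob = exported.serialize()
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.jaxexport")
    with open(path, "wb") as f:
        f.write(blob)
    # ...then read them back and rehydrate the computation.
    with open(path, "rb") as f:
        restored = export.deserialize(bytearray(f.read()))

x = jnp.ones((3, 4), jnp.float32)
y = restored.call(x)
print(y.shape)  # (3, 2)
```

This round-trip happens inside Python; executing the same artifact from C++ would still need a runtime that can run StableHLO (for example XLA via PJRT), which this sketch does not cover.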

It seems this is actively being worked on in JAX :)