ott-jax / ott

Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.
https://ott-jax.readthedocs.io
Apache License 2.0
524 stars, 80 forks

ONNX support for the ICNNs #592

Open malteal opened 2 weeks ago

malteal commented 2 weeks ago

Hi,

I have a question about ONNX support for the JAX version of ICNNs.

Does anyone know whether this is supported, or has anyone tried to export the JAX version of the ICNN to ONNX with the gradient of the network as the output of the ONNX model?

I have a PyTorch version of ICNNs that can be exported to ONNX using onnxruntime-training; however, I have not been able to do the same with the ott framework.
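For reference, the function the question asks to export is the gradient map x ↦ ∇f(x) of a convex potential f, which JAX builds with `jax.grad` and batches with `jax.vmap`. The sketch below uses a hypothetical two-layer input-convex network written in plain JAX (`init_params`, `potential` and `transport_map` are illustrative names, not ott's `ICNN` API). Convexity comes from softplus activations and hidden-to-output weights kept non-negative by a softplus reparameterization. It does not perform the ONNX export itself: a converter route such as `jax.experimental.jax2tf` followed by `tf2onnx` would still have to be tried on the jitted gradient function, and whether that works for ott's networks is the open question here.

```python
import jax
import jax.numpy as jnp


def init_params(key, dim, hidden):
    """Hypothetical parameters for a tiny two-layer ICNN (not ott's ICNN class)."""
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "W0": jax.random.normal(k1, (hidden, dim)) * 0.1,  # input -> hidden, unconstrained
        "b0": jnp.zeros(hidden),
        "Wz": jax.random.normal(k2, (1, hidden)) * 0.1,    # hidden -> output, made non-negative below
        "Wx": jax.random.normal(k3, (1, dim)) * 0.1,       # affine skip connection from the input
        "b1": jnp.zeros(1),
    }


def potential(params, x):
    """Convex scalar potential f(x) for a single input vector x of shape (dim,)."""
    z = jax.nn.softplus(params["W0"] @ x + params["b0"])  # softplus of an affine map is convex
    wz = jax.nn.softplus(params["Wz"])                     # non-negative weights preserve convexity
    out = wz @ z + params["Wx"] @ x + params["b1"]         # shape (1,)
    return out[0]                                          # scalar, as jax.grad requires


# Gradient of the potential w.r.t. x, batched over rows: this is the output
# the issue wants in the exported model.
transport_map = jax.jit(jax.vmap(jax.grad(potential, argnums=1), in_axes=(None, 0)))

params = init_params(jax.random.PRNGKey(0), dim=2, hidden=16)
xs = jnp.ones((4, 2))
ys = transport_map(params, xs)
print(ys.shape)  # (4, 2)
```

With ott's own ICNN, the same `jax.grad`/`jax.vmap` wrapper would presumably be applied to the network's forward pass instead of `potential`; that substitution is an assumption, not something confirmed in this thread.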

Best regards, Malte