based-robotics / jaxadi

Transforms your CasADi functions into batchable JAX-compatible functions. By combining the power of CasADi with the flexibility of JAX, JAXADi enables the creation of efficient code that runs smoothly on CPUs, GPUs, and TPUs.
https://based-robotics.github.io/jaxadi/
MIT License

The reverse too? #12

Closed: mawbray closed this issue 1 month ago

mawbray commented 1 month ago

Love the translation from CasADi to JAX, thanks! Are there plans to support the reverse direction (i.e. from JAX to CasADi)? This would be very useful for embedding JAX ML models into, e.g., an NLP. In my opinion, doing this currently takes a fair amount of effort through callbacks...

traversaro commented 1 month ago

In case this is useful to anyone, I did some experiments in that direction some time ago in https://github.com/traversaro/experimental-jax-casadi/blob/main/exploration.ipynb. I quickly stopped, as the JAX primitives basically mirror the XLA ones, so there are complex primitives like gather that do not have direct CasADi counterparts, and I was not able to emulate them properly in CasADi. However, I would be really glad if someone were able to solve this problem.

lvjonok commented 1 month ago

Thanks @traversaro for pointing to some experiments you have done before!

@mawbray, I do not think the reverse is in the scope of our library, and we do not plan to implement it for individual use cases. If you need to embed a neural network trained with JAX in a CasADi pipeline, you could follow an approach similar to the one used in l4casadi.

Thank you for the interest!