numagic / lumos

scalable accelerated optimal control
MIT License

saving/loading compiled function with jax backend #100

Open · yunlongxu-numagic opened 1 year ago


With JAX-backend models we sometimes see large runtime savings, especially on GPUs. However, JIT compilation can be quite costly. It would be very useful if we could save the jitted function to disk and load it later, much as we do with compiled C++ binaries.

Currently this is not supported. What is supported is ahead-of-time (AOT) compilation and re-use of the compiled function within the same process, which is not quite what we want. See here.
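A minimal sketch of the in-process AOT workflow described above, using JAX's staged `jax.jit(...).lower(...).compile()` API. The function `step` and its input shapes are hypothetical stand-ins for a lumos model function. The resulting `compiled` object can be called repeatedly without retracing or recompiling, but it exists only in the current Python process, which is exactly the limitation this issue is about.

```python
import jax
import jax.numpy as jnp


def step(x, u):
    # Hypothetical toy dynamics standing in for a lumos model function.
    return x + 0.1 * jnp.tanh(u)


# Example inputs; AOT compilation specializes on their shapes and dtypes.
x0 = jnp.zeros(3)
u0 = jnp.ones(3)

# Stage 1: trace and lower to an intermediate representation.
lowered = jax.jit(step).lower(x0, u0)

# Stage 2: compile for the current backend (CPU/GPU). This is the costly step.
compiled = lowered.compile()

# The compiled executable can be reused in this process without recompiling,
# as long as the arguments match the shapes/dtypes used for lowering.
y = compiled(x0, u0)
```

Calling `compiled` with arguments of a different shape or dtype raises an error instead of triggering a recompile, so this approach works best when problem sizes are fixed. Nothing in this workflow writes the compiled executable to disk, so a fresh process still pays the full compilation cost.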