With JAX-backend models, we sometimes see large runtime savings, especially on GPUs. However, the JIT compilation time can be quite costly. It would be very useful if we could save a jitted function and load it later, much as we do with compiled binaries in C++.
Currently this is not supported. What is supported is AOT (ahead-of-time) compilation with reuse of the compiled function within the same process, which is not quite what we want. See here.
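For reference, the in-process AOT workflow looks roughly like the sketch below. It assumes a recent JAX release that provides the `jax.jit(...).lower(...).compile()` API; `model` is a hypothetical stand-in for a real model's forward pass. The compiled executable skips recompilation on later calls, but it lives only in this process's memory, which is exactly the limitation described above.

```python
import jax
import jax.numpy as jnp

def model(x):
    # Hypothetical stand-in for a JAX-backend model's forward pass.
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((4, 4), dtype=jnp.float32)

# Ahead-of-time: trace and lower for this input shape/dtype, then compile once.
lowered = jax.jit(model).lower(x)
compiled = lowered.compile()

# The compiled executable can be called repeatedly in this process without
# recompiling, but it is not saved to disk and is lost when the process exits.
result = compiled(x)
```

Note that `compiled` is specialized to the shape and dtype it was lowered with, so inputs with a different shape need a fresh `lower(...).compile()` call.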