Thanks for your interest in JAX!
Yes, I think something like this would make a lot of sense for, say, inference use cases that want to get Python out of the way. We've discussed things along these lines, but haven't done anything concrete yet.
One idea would be to add a new Python API `jax.aot_compile` (probably not that exact name) which, rather than running the computation immediately as JIT does, writes a `.so` file and a `.h` file to disk that you can link into your code (or whatever language headers/wrappers seem appropriate). I think we could definitely improve on the ergonomics of `tensorflow/compiler/aot`!
If you'd like to try prototyping something along these lines, you might start from the undocumented function `jax.xla_computation` (https://github.com/google/jax/blob/master/jax/api.py#L155), which returns a `Computation` object from the XLA client. In particular, it has a method `GetSerializedProto()` that returns an `xla.HloModule` proto containing the computation (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client.py#L720).
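For anyone who wants to try this, here is a minimal sketch of that starting point. One caveat: the serialization method name has varied across jaxlib versions (`GetSerializedProto()` is the one mentioned above, while later clients expose `as_serialized_hlo_module_proto()`), so treat the exact calls as assumptions to check against your version:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2

# Trace f into an XLA computation without executing it.
c = jax.xla_computation(f)(jnp.ones((3,), jnp.float32))

# Human-readable HLO, useful for inspection.
print(c.as_hlo_text())

# Serialized HloModule proto that an AOT pipeline could consume.
with open("f_hlo.pb", "wb") as out:
    out.write(c.as_serialized_hlo_module_proto())
```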
PRs welcome!
Any updates?
Hey, I would be interested to contribute to this.
My use case is slightly different, as I simply want to be able to persistently save `jit`ted functions. I can provide more details about why I think this is important.
As far as I understand, there are two parts to achieving this:
1) A `python` part, where we have to replace the cache of `_xla_callable`s with a persistent one. I believe this is easy to prototype. I attach below a duct-taped example that saves the simplest `jax` function. If you run the script twice, no tracing will happen in the second run. Naturally, however, when trying to serialise any non-trivial function, the following part becomes crucial.
2) An `XLA` part that includes serialisation and deserialisation of `jaxlib.xla_extension.Executable`s. Such an `Executable` appears to be a `PjRtStreamExecutorExecutable`, which is created from a list of `LocalExecutable`s that in turn can be created from `HloModule`s. `HloModule`s can be serialised and deserialised via these methods, so it appears that we have everything we need(?)

@hawkinsp do these sound sensible to you?
@skye @hawkinsp any update on the comment from @nrontsis? Is https://github.com/google/jax/pull/7207 actually part of the solution, or will this only support TPUs?
Yes, my summer intern implemented a persistent on-disk cache for jit'd functions! (Including pmap, xmap, and pjit as well.) Unfortunately this is TPU-only for now, because XLA only implements de/serialization methods for TPU executables. @nrontsis notes that the HLO module of an executable can be serialized, but the cache serializes the compiled executable itself, so that no compilation is required when loading from the cache.
I'm guessing that you would benefit from GPU support? We've raised GPU executable serialization as a feature request with the XLA team, and requests like this help bump the prioritization :)
Nice work by the intern! @colemanliyah
My jit times are slowly getting out of hand (especially on CPU). The current jit time is ~40 s for a vmap over a `Jacobian.T * Jacobian * v` operator (a Gauss-Newton Hessian, built with one `linearize` and a `linear_transpose`). Still to be investigated; I couldn't find a small reproducible example at the moment. I'll share one if I can reproduce something, but I think there is also quite some data that gets constant-folded, which slows things down (jitting only the function `f`, not the `J^T J v` operator, also takes 8 s).
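For reference, the kind of operator I mean looks roughly like the following sketch, with a toy residual `f` standing in for the real model and shapes chosen arbitrarily:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy residual function R^3 -> R^2; the real model is much larger.
    return jnp.tanh(x[:2] * x[1:])

def gn_hvp(x, v):
    # J v via a single linearize, then J^T (J v) via linear_transpose.
    _, jvp = jax.linearize(f, x)
    vjp = jax.linear_transpose(jvp, x)
    (out,) = vjp(jvp(v))
    return out

x = jnp.arange(3.0)
vs = jnp.eye(3)
# vmap over the tangent vectors, as in the setup described above.
print(jax.jit(jax.vmap(gn_hvp, in_axes=(None, 0)))(x, vs))
```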
There are two use cases from which I could benefit at the moment. The first would give me the bigger benefit right now; the second might be more critical in the medium-term future.
PS: according to https://github.com/google/jax/issues/1566#issuecomment-546155141, CPU serialization is supported by XLA. With the groundwork done for TPUs, this might not be too much work?
I'm watching this issue for the use case raised in the original issue:
> compile the XLA into executable functions and link into an `.so` using the approach in `tensorflow/compiler/aot`

This would be an awesome path for deployment on CPU without any Python dependencies.
@froystig saw you mentioning this in #7733, so I thought I'd ask here.
Regarding:
> Discussion in #476 has focused on extracting some form of serialized executable after compilation. That's out of scope here.
I'm wondering if the following combination might give users with the requirements mentioned in this issue a feasible workaround (for CPU only) until it is natively supported by JAX:
1) use `jax2tf` to convert to a standard tf function
2) use `tfcompile` to create the executable
Would this work? Are there issues I'm overlooking other than:
cc @gnecula. I'm not super familiar with `tfcompile`, but this sounds plausible!
The `jax2tf` path does allow you to serialize the computation so that later you can reload it and recompile. You do not need to use TF to compile; you can use `jax2tf.call_tf`. Essentially, the path is `jax2tf.convert` -> `saved_model.save` to save a serialized form of the computation, then `saved_model.load` -> `jax2tf.call_tf` to load it back, compile, and execute. Under the hood, `jax2tf.call_tf` will use TF to do the compilation. This path should work for all devices, CPU/GPU/TPU.
However, this will not save time, because `jax2tf.call_tf` will essentially cost as much as compilation.
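To make that concrete, here is a minimal sketch of the round trip (the `tf.Module` wrapper, shapes, and paths are illustrative choices, not required structure):

```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def f(x):
    return jnp.sin(x) * 2.0

# JAX -> TF, then save a serialized form of the computation.
module = tf.Module()
module.f = tf.function(
    jax2tf.convert(f),
    autograph=False,
    input_signature=[tf.TensorSpec([3], tf.float32)],
)
tf.saved_model.save(module, "/tmp/f_saved_model")

# Later: load it back and call it from JAX again (recompiles under the hood).
restored = tf.saved_model.load("/tmp/f_saved_model")
f_again = jax2tf.call_tf(restored.f)
print(f_again(jnp.ones(3, jnp.float32)))
```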
@gnecula thanks for your input.
My main target (which I believe is also shared by some other users who commented here) is to reduce the jit overhead for the same function that is going to be executed repeatedly for different inputs, by compiling it and caching the executable ahead of time. This would be particularly important for programs whose jit compilation time is non-negligible compared to their execution time (I think this can be the case for many physics-based models, which might have lots of unary ops).
So instead of using `jax2tf.call_tf`, I was thinking more of:
1) using `jax2tf` to convert to a tf function
2) using TensorFlow to export the graph from the tf function
3) using `tfcompile` to compile the graph into an executable
Would this avoid incurring compilation cost in the future to execute the same graph?
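A rough sketch of steps 1) and 2), under the assumption that the TF-internal freezing helper is acceptable (the actual `tfcompile` invocation on the resulting graph is left out):

```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)

def f(x):
    return jnp.sin(x) * 2.0

# Step 1: JAX -> tf.function with a fixed input signature.
f_tf = tf.function(
    jax2tf.convert(f),
    autograph=False,
    input_signature=[tf.TensorSpec([8], tf.float32)],
)

# Step 2: export a frozen GraphDef that tfcompile could consume.
frozen = convert_variables_to_constants_v2(f_tf.get_concrete_function())
tf.io.write_graph(frozen.graph.as_graph_def(), "/tmp", "f_graph.pbtxt")

# Step 3 (not shown): point tfcompile at this graph plus a config proto
# naming the feeds/fetches, and link the generated code into your binary.
```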
ps: I also stumbled upon another potential workaround, which is `jax` -> tf function -> `onnx`, which could also remove the compilation cost, but it would then require something like an ONNX runtime (probably a price I'm willing to pay)
If the goal is only to compile ahead of time, why is serialization/export needed?
A concrete example that I have: when running AlphaFold 2 to generate a structure for a protein, the first step in the prediction is to use JAX to compile a model specialized to the length of the protein's input sequence. For very large proteins, the compilation step is negligible compared to inference. However, in my use case I only care about input sequences that are quite small, but I want to process a ton of them, and I don't know the full list I want to process ahead of time. In this case, for each individual input, the compilation time is as much as ~15x the inference time.
I would like to compile a specialized JAX executable for each of the short lengths that I'm interested in, cache those to the filesystem, and then load them at inference time, instead of having to recompile them every time I start up a fresh run of Alphafold 2.
I expect that `tfcompile` should be able to work on the output of `jax2tf`, but I have never tried it.
It seems that it would generate a library that packages the generated code with its runtime dependencies (e.g., kernels). One possible wrinkle is if you want to load multiple such libraries into the same executable, there may be symbol clashes. But it should be easy to try it out. I hope it works for your use case.
@skye, is the TPU work by your intern available for public development anywhere?
Hi @xloem, if you wish to use the persistent compilation cache on a TPU VM, you can enable it like this:
from jax.experimental.compilation_cache import compilation_cache as cc
cc.initialize_cache("/path/name/here", max_cache_size_bytes=32 * 2**30)
Put that somewhere in your program before you start running JAX operations.
Is this what you meant by public development? (This reminds me that we still need to publish public documentation on it!)
Thanks so much!
I wasn't sure if it was to the point of usability or not; it's great to learn that it is.
Hi @skye, if I understand correctly, this is only available for TPU, right? (I also just did a quick experiment to confirm this on CPU.) Are there any plans / is it even possible to make it work similarly for GPU and CPU?
That's correct, it's TPU-only for now. It's possible to make it work for GPU and CPU -- it requires that XLA provides a way to de/serialize compiled GPU and CPU executables. I think this will eventually be supported at least on GPU, but there's no timeline for it yet.
Any updates on the compile cache for GPU? Is it still TPU-only?
It seems additional dev effort is needed here. Another missing component is Colab notebooks, which use a different TPU interface that does not support this.
No updates yet unfortunately. There are internal discussions happening around both the GPU and TPU Colab runtimes, which both need to be updated to new interfaces before we can support executable serialization (so it goes), but still no timelines.
> The `jax2tf` path does allow you to serialize the computation so that later you can reload it and recompile. You do not need to use TF to compile; you can use `jax2tf.call_tf`. Essentially, the path is `jax2tf.convert` -> `saved_model.save` to save a serialized form of the computation, then `saved_model.load` -> `jax2tf.call_tf` to load it back, compile, and execute. Under the hood, `jax2tf.call_tf` will use TF to do the compilation. This path should work for all devices, CPU/GPU/TPU.
> However, this will not save time, because `jax2tf.call_tf` will essentially cost as much as compilation.
This looks relevant to my use case; I would also love the compilation caching features that others are after, but my primary concern at the moment is rapid iteration on a codebase: the ability to run and evaluate, and to perform student-teacher dynamics on these model variants. Thus, I want to save my weights and my compute graphs without any reference to the code I'm editing.
It sounds like this jax2tf route is currently the easiest method of achieving that, but I thought I'd throw it out there for others to reflect upon. Rather than this coupling to TF and whatever limits it may impose, being able to convert some type of intermediate JAX representation to/from JSON would be a lot cleaner, I suppose. Not a weekend project probably, but it also does not sound too complicated either? Curious how others judge the complexity/utility of this.
Hello! I'm working on an optimization-through-differentiable-simulation use case where the compilation time takes up a significant fraction of the overall runtime (~300 s for compilation, then another ~500 s for optimization). In this setting, the simulator doesn't change between experiments, but we might want to run the optimizer with different initial conditions, hyperparameters, etc. It would be really great if we had some way to compile either the simulator or the optimizer step function and have that compiled function persist between program executions. This would really help speed up our development process, so I just wanted to check in and ask whether there's been any progress on this feature or any plan to add it to JAX?
I depend on a large codebase of standard Jax code (in Brax), and I'm waiting for several minutes to compile functions with identical inputs to what was compiled in previous runs. Is there as of today any way to do this kind of caching on GPU? Is this still blocked by XLA? I would be interested in contributing to landing this feature.
Also, could this possibly be achieved by just serializing whatever code is in memory in Python?
It seems that serialization of compiled functions is now working on GPU: https://github.com/google/jax/discussions/13736#discussioncomment-5887985
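For anyone trying this, a rough sketch of that path uses JAX's AOT `lower`/`compile` API; the `jax.experimental.serialize_executable` module and the exact signatures below are assumptions that may differ between JAX versions:

```python
import jax
import jax.numpy as jnp
from jax.experimental import serialize_executable  # assumed module name

def f(x):
    return jnp.dot(x, x.T)

x = jnp.ones((128, 128))

# Ahead-of-time lowering and compilation (the stable part of this sketch).
compiled = jax.jit(f).lower(x).compile()

# Assumed de/serialization helpers; persist `payload` between processes.
payload, in_tree, out_tree = serialize_executable.serialize(compiled)
reloaded = serialize_executable.deserialize_and_load(payload, in_tree, out_tree)
print(reloaded(x))
```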
Any updates for CPU?
As far as I can tell, the current version of JAX doesn't support serialization for CPU functions, but maybe there's some magic flag I need to pass?
Actually, on the CPU front, is there even a hacky solution? E.g. one that might have an inconsistent serialization format?
I'm at the point where JAX CPU compile time is a large fraction of my workflow, and I'm considering switching to a client / server model just to keep alive a persistent process that owns the compiled JAX functions. If there's any way to avoid this, even a brittle solution, I'd be relieved.
Thanks!
Seconded. Compilation times and the lacking caching support are really the number one usability issue with JAX for me. It might be that this is somewhat specific to my JAX use case; when working on DL applications I don't really expect results after a few epochs of training anyway, so a few minutes of compilation overhead don't make much of a difference; plus you tend to be working with fairly plug-and-play pieces of code, so there is little need to go back and forth to debug individual JAX expressions. But I imagine for pretty much any other use case, minutes of compile time really ruins your day.
Thirded. It seems to me that in numerical computing, compilation times are the one big disadvantage of JAX vs. Julia.
FYI, I tried the `jax2tf` approach mentioned earlier in this thread, and that route appears to be broken.
E.g., the usage example in the TF docs crashes:
BTW, I get the same error with this simple example:
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf
import tensorflow as tf
def add_jax(x, y):
    return jnp.add(x, y)
add_tf = jax2tf.convert(add_jax)
x = tf.constant(1.0)
y = tf.constant(2.0)
result = add_tf(x, y)
The error you are seeing with jax2tf is because the current version depends on not-yet-released features in TF. You can work around this by using tf-nightly.
However, as discussed earlier in this thread, jax2tf will only save the tracing and HLO generation time. It won't reduce the compilation time.
Export is now supported natively in JAX; refer to the `jax.export` module, with examples in the Exporting and serialization guide.
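A minimal sketch of that native path (the function, shapes, and dtypes here are just for illustration):

```python
import jax
import jax.numpy as jnp
from jax import export

def f(x):
    return jnp.sin(x) * 2.0

# Export a jitted function for concrete avals, then serialize it to bytes.
exp = export.export(jax.jit(f))(jax.ShapeDtypeStruct((8,), jnp.float32))
blob = exp.serialize()

# Later (possibly in another process): deserialize and call.
restored = export.deserialize(blob)
print(restored.call(jnp.arange(8.0, dtype=jnp.float32)))
```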
Context:
I would like to use JAX to express a bunch of tensor computations in numpy-ish syntax, but delay the actual execution of the computation until later -- ideally by registering the compiled function as a function that could be looked up from a shared lib (the function would need to be called from a C++ program / library).
My initial idea was to:
1) express the tensor computations with JAX
2) compile the XLA into executable functions and link them into an `.so` using the approach in `tensorflow/compiler/aot`
Assuming this approach makes sense (Please let me know if there is a better way), could you let me know how I could extract the XLA HLO during that second step?