google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Support serialization/export of compiled computations #476

Open apark263 opened 5 years ago

apark263 commented 5 years ago

Context:

I would like to use JAX to express a bunch of tensor computations in NumPy-like syntax but delay the actual execution until later, ideally by registering the compiled function so that it can be looked up from a shared library. (The function would need to be called from a C++ program or library.)

My initial idea was to:

Assuming this approach makes sense (Please let me know if there is a better way), could you let me know how I could extract the XLA HLO during that second step?

hawkinsp commented 5 years ago

Thanks for your interest in JAX!

Yes, I think something like this would make a lot of sense for, say, inference use cases that want to get Python out of the way. We've discussed things along these lines, but haven't done anything concrete yet.

One idea would be to add a new Python API jax.aot_compile (probably not that exact name), which, rather than running the computation immediately as JIT does, writes a .so file and .h file to disk that you can link into your code (or whatever language headers/wrappers seem appropriate). I think we could definitely improve on the ergonomics of tensorflow/compiler/aot!

If you'd like to try prototyping something along these lines, you might start from the undocumented function jax.xla_computation (https://github.com/google/jax/blob/master/jax/api.py#L155) which returns a Computation object from the XLA client. In particular, it has a method GetSerializedProto() that returns an xla.HloModule proto containing the computation (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client.py#L720)

PRs welcome!
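
A minimal sketch of pulling the lowered program out of a jitted function. It assumes a recent JAX release that exposes `jax.jit(...).lower(...)`, which is a different entry point from the `jax.xla_computation` function named above. The function `f` is a hypothetical stand-in for the real computation.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Toy computation standing in for the real tensor program.
    return jnp.sum(jnp.square(x))


# Trace and lower without executing; `x` only supplies shape and dtype.
x = jnp.zeros(3, dtype=jnp.float32)
lowered = jax.jit(f).lower(x)

# Textual IR of the lowered module (the dialect depends on the JAX version).
ir_text = lowered.as_text()
print(ir_text)

# Compile ahead of time, then call the executable directly.
compiled = lowered.compile()
result = compiled(jnp.arange(3, dtype=jnp.float32))
```

The text returned by `as_text()` is the piece that could be handed to an external toolchain; the compiled object itself stays inside the Python process.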

2sin18 commented 3 years ago

Any updates?

nrontsis commented 3 years ago

Hey, I would be interested to contribute to this.

My use case is slightly different, as I simply want to be able to persistently save jitted functions. I can provide more details about why I think this is important.

As far as I understand, there are two parts to achieve this:

@hawkinsp do these sound sensible to you?

Serialisation example:

```python
from copy import copy

from jax import tree_flatten
from jax.linear_util import WrappedFun, Store
import jaxlib.xla_extension
import dill as pickle
import jax.numpy as np
from jax.api import jit, vmap
from jax.interpreters import xla
import jax._src.util
from jax.lazy import ArrayVar

# Hack to avoid pickling error: : it's not found as jax._src.util.ArrayVar
jax._src.util.ArrayVar = ArrayVar


# Hacks to Serialise PyTrees
class PythonTree:
    def __init__(self, definition):
        self.definition = definition


class PyTreeStar:
    __repr__ = lambda _: "*"


def pytree_to_serialisable(obj):
    if isinstance(obj, jaxlib.xla_extension.PyTreeDef):
        return PythonTree(obj.unflatten(obj.num_leaves * [PyTreeStar(), ]))
    else:
        return obj


def serialisable_to_pytree(obj):
    if hasattr(obj, "definition"):
        return tree_flatten(obj.definition)[1]
    else:
        return obj


STATIC_COMPILATION_IDENTIFIER = "_compiled_statically_"
original_xla_callable = copy(xla._xla_callable)


def xla_callable(fun: WrappedFun, device, backend, name, donated_invars, *arg_specs):
    if name[:len(STATIC_COMPILATION_IDENTIFIER)] != STATIC_COMPILATION_IDENTIFIER:
        return original_xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    filename = name + ".dill"
    try:
        compiled_function, store_values = pickle.load(open(filename, "rb"))
        stores = tuple([Store() for _ in store_values])
        for store, value in zip(stores, store_values):
            store.store(serialisable_to_pytree(value))
        fun.populate_stores(stores)
    except IOError:
        compiled_function = original_xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
        store_values = [pytree_to_serialisable(store.val) for store in fun.stores]
        pickle.dump((compiled_function, store_values), open(filename, "wb"))
    return compiled_function


xla._xla_callable = xla_callable
xla._xla_callable.most_recent_entry = lambda: None


def persistent_jit(function, unique_name):
    f = copy(function)
    f.__name__ = STATIC_COMPILATION_IDENTIFIER + unique_name
    jitted_function = jit(f)
    return jitted_function


def my_function(x):
    print("tracing")
    return x  # This works
    # return np.sum(np.square(x))  # This doesn't


compiled_function = persistent_jit(vmap(my_function), unique_name="my_first_aot_compiled_function")
print(compiled_function(np.zeros(3)))
print(compiled_function(np.ones(3)))
```

tetterl commented 2 years ago

@skye @hawkinsp any update on the comment from @nrontsis ? Is https://github.com/google/jax/pull/7207 actually part of the solution or will this only support TPUs?

skye commented 2 years ago

Yes, my summer intern implemented a persistent on-disk cache for jit'd functions! (Including pmap, xmap, and pjit as well.) Unfortunately this is TPU-only for now, because XLA only implements de/serialization methods for TPU executables. @nrontsis notes that the HLO module of an executable can be serialized, but the cache serializes the compiled executable itself, so that no compilation is required when loading from the cache.

I'm guessing that you would benefit from GPU support? We've raised GPU executable serialization as a feature request with the XLA team, and requests like this help bump the prioritization :)

tetterl commented 2 years ago

Nice work by the intern! @colemanliyah

My jit times are slowly getting out of hand (especially on CPU). The current jit time is ~40s for a vmap over a Jacobian.T * Jacobian * v operator (a Gauss-Newton Hessian), built with a single linearize and linear_transpose. This is still to be investigated, and I couldn't find a small reproducible example at the moment; I'll share one if I can reproduce something. I think there is also quite a lot of data being constant-folded, which slows things down (jitting only the function f, not the J^T J v operator, also takes 8s).

There are two use cases from which I could benefit at the moment. The first would give me the bigger benefit right now; the second might become more critical in the medium term.

  1. persistent caching for CPUs for prototyping/developing on a non-GPU machine
  2. persistent caching for GPUs for production runs where the python code will be called multiple times. This could be resolved by not leaving the python code though.

PS: according to https://github.com/google/jax/issues/1566#issuecomment-546155141 CPU serialization is supported by XLA. With the ground work done for TPUs this might not be too much work?

dmaniry commented 2 years ago

I'm watching this issue for the use case raised in the original issue:

compile the XLA into executable functions and link into an .so using the approach in tensorflow/compiler/aot

This would be an awesome path for deployment on cpu without any python dependencies.

yunlongxu-numagic commented 2 years ago

@froystig saw you mentioning this in #7733, so I thought I'd ask here.

Regarding:

Discussion in #476 has focused on extracting some form of serialized executable after compilation. That's out of scope here

I'm wondering if the following combination might provide users with requirements mentioned in this issue a feasible workaround (for CPU only) until it is natively supported by jax: 1) use jax2tf to convert to standard tf function 2) use tfcompile to create the executable

Would this work? Are there issues I'm overlooking other than:

skye commented 2 years ago

cc @gnecula. I'm not super familiar with tfcompile, but this sounds plausible!

gnecula commented 2 years ago

The jax2tf path does allow you to serialize the computation so that later you can reload it and recompile. You do not need to use TF to compile, you can use jax2tf.call_tf. Essentially, the path is jax2tf.convert -> saved_model.save to save a serialized form of the computation. Then saved_model.load -> jax2tf.call_tf to load it back, compile and execute. Under the hood, jax2tf.call_tf will use TF to do the compilation. This path should work for all devices, CPU/GPU/TPU.

However, this will not save time because jax2tf.call_tf will essentially cost as much as compilation.

yunlongxu-numagic commented 2 years ago

@gnecula thanks for your inputs.

My main target (which I believe is shared by some other users who commented here) is to reduce the jit overhead for a function that is going to be executed repeatedly on different inputs, by compiling it and caching the executable ahead of time. This would be particularly important for programs whose jit compilation time is non-negligible compared to their execution time (I think this can be the case for many physics-based models, which might have lots of unary ops).

So instead of using jax2tf.call_tf, I was more thinking of: 1) using jax2tf to convert to tf function 2) using Tensorflow to export the graph from tf function 3) using tfcompile to compile the graph into executable.

Would this avoid incurring compilation cost in the future to execute the same graph?

ps: I also stumbled upon another potential workaround, which is jax -> tf function -> onnx, which could potentially also remove the compilation cost, but it would then require something like an onnx runtime (probably a price I'm willing to pay)

froystig commented 2 years ago

If the goal is only to compile ahead of time, why is serialization/export needed?

ncoish commented 2 years ago

If the goal is only to compile ahead of time, why is serialization/export needed?

A concrete example that I have: when running Alphafold 2 to generate a structure for a protein, the first step in the prediction is to use JAX to compile a model specialized to the length of the input sequence of the protein. In the case of very large proteins, the compilation step is negligible compared to inference. However, in my use case, I only care about input sequences that are quite small, but I want to process a ton of them, and I don't know the full list that I want to process ahead of time. In this case, for each individual input, the compilation time is as much as ~15x the inference time.

I would like to compile a specialized JAX executable for each of the short lengths that I'm interested in, cache those to the filesystem, and then load them at inference time, instead of having to recompile them every time I start up a fresh run of Alphafold 2.
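
The in-process half of this workflow can be sketched as follows, assuming a recent JAX release with the `jax.jit(...).lower(...).compile()` ahead-of-time API. The names `model` and `compile_for_lengths` are hypothetical, and the toy `model` only stands in for a real network.

```python
import jax
import jax.numpy as jnp


def model(x):
    # Hypothetical stand-in for a model specialized to the sequence length.
    return jnp.tanh(x) * 2.0


def compile_for_lengths(fn, lengths):
    """Compile `fn` once per input length and return {length: executable}."""
    executables = {}
    for n in lengths:
        # Only shape and dtype are needed to lower and compile.
        spec = jax.ShapeDtypeStruct((n,), jnp.float32)
        executables[n] = jax.jit(fn).lower(spec).compile()
    return executables


executables = compile_for_lengths(model, [8, 16, 32])

# At inference time, pick the executable matching the input length.
x = jnp.ones(16, dtype=jnp.float32)
y = executables[x.shape[0]](x)
```

The dictionary of executables lives only in the current process, which is exactly the gap described above: without a way to write these executables to the filesystem, a fresh run has to compile every length again.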

gnecula commented 2 years ago

I expect that tfcompile should be able to work on the output of jax2tf, but I have never tried it.

It seems that it would generate a library that packages the generated code with its runtime dependencies (e.g., kernels). One possible wrinkle is if you want to load multiple such libraries into the same executable, there may be symbol clashes. But it should be easy to try it out. I hope it works for your use case.

xloem commented 2 years ago

@skye, is the TPU work by your intern available for public development anywhere?

skye commented 2 years ago

Hi @xloem, if you wish to use the persistent compilation cache on a TPU VM, you can enable it like this:

from jax.experimental.compilation_cache import compilation_cache as cc

cc.initialize_cache("/path/name/here", max_cache_size_bytes=32 * 2**30)

Put that somewhere in your program before you start running JAX operations.

Is this what you meant by public development? (This reminds me that we still need to publish public documentation on it!)
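
For comparison, a sketch of the same setup through configuration options. It assumes a more recent JAX release in which the cache directory and thresholds are exposed as the config options `jax_compilation_cache_dir`, `jax_persistent_cache_min_compile_time_secs` and `jax_persistent_cache_min_entry_size_bytes`; which backends actually write cache entries depends on the installed version. The temporary directory and the `step` function are hypothetical.

```python
import tempfile

import jax
import jax.numpy as jnp

# Hypothetical cache location; any writable directory would do.
cache_dir = tempfile.mkdtemp(prefix="jax_cache_")

# Set the cache directory before running any JAX operations.
jax.config.update("jax_compilation_cache_dir", cache_dir)
# Lower the thresholds so that small, fast-compiling functions are cached too.
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)


@jax.jit
def step(x):
    return jnp.sin(x) + 1.0


# The first call compiles; a later process pointed at the same directory
# may be able to reuse the cached executable instead of recompiling.
out = step(jnp.zeros(4))
```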

xloem commented 2 years ago

Thanks so much!

I wasn't sure if it was to the point of usability or not, it's great to learn it is.

yunlongxu-numagic commented 2 years ago

Hi @xloem, if you wish to use the persistent compilation cache on a TPU VM, you can enable it like this:

from jax.experimental.compilation_cache import compilation_cache as cc

cc.initialize_cache("/path/name/here", max_cache_size_bytes=32 * 2**30)

Put that somewhere in your program before you start running JAX operations.

Is this what you meant by public development? (This reminds me that we still need to publish public documentation on it!)

Hi @skye, if I understand correctly, this is only available for TPU right? (also just did a quick experiment to validate this on CPU). Are there any plans / is it even possible to make it work similarly for GPU and CPU?

skye commented 2 years ago

That's correct, it's TPU-only for now. It's possible to make it work for GPU and CPU -- it requires that XLA provides a way to de/serialize compiled GPU and CPU executables. I think this will eventually be supported at least on GPU, but there's no timeline for it yet.

samuela commented 2 years ago

Any updates on the compile cache for GPU? Is it still TPU-only?

xloem commented 2 years ago

It seems additional dev effort is needed here. Another missing component is colab notebooks, which use a different TPU interface without support for this.

skye commented 2 years ago

No updates yet unfortunately. There are internal discussions happening around both the GPU and TPU Colab runtimes, which both need to be updated to new interfaces before we can support executable serialization (so it goes), but still no timelines.

EelcoHoogendoorn commented 1 year ago

The jax2tf path does allow you to serialize the computation so that later you can reload it and recompile. You do not need to use TF to compile, you can use jax2tf.call_tf. Essentially, the path is jax2tf.convert -> saved_model.save to save a serialized form of the computation. Then saved_model.load -> jax2tf.call_tf to load it back, compile and execute. Under the hood, jax2tf.call_tf will use TF to do the compilation. This path should work for all devices, CPU/GPU/TPU.

However, this will not save time because jax2tf.call_tf will essentially cost as much as compilation.

This looks relevant to my use case. I would also love the compilation caching features that others are after, but my primary concern at the moment is rapid iteration on a codebase: the ability to run and evaluate model variants and to perform student-teacher dynamics on them. Thus, I want to save my weights and my compute graphs without any reference to the code I'm editing.

It sounds like this jax2tf route is currently the easiest way to achieve that, but I thought I'd throw it out there for others to reflect upon. Rather than coupling to TF and whatever limits it may impose, being able to convert some intermediate JAX representation to/from JSON would be a lot cleaner, I suppose. Probably not a weekend project, but it does not sound too complicated either. Curious how others judge the complexity/utility of this.
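
One TF-free shape this could take is sketched below. It assumes a recent JAX release that ships the `jax.export` module, which is not mentioned anywhere in this thread, and it produces a binary blob rather than JSON. The function `f` is a hypothetical stand-in.

```python
import jax
import jax.numpy as jnp
from jax import export


def f(x):
    return jnp.sum(jnp.square(x))


# Export for a concrete input shape and dtype, then serialize to bytes.
spec = jax.ShapeDtypeStruct((3,), jnp.float32)
exported = export.export(jax.jit(f))(spec)
blob = exported.serialize()

# Later, possibly in another process without access to the source of `f`:
restored = export.deserialize(blob)
result = restored.call(jnp.arange(3, dtype=jnp.float32))
```

The blob captures the lowered computation, not a compiled executable, so calling the restored function still pays the compilation cost on first use.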

dawsonc commented 1 year ago

Hello! I'm working on an optimization-through-differentiable-simulation use case where the compilation time is taking up a significant fraction of the overall runtime (~300 seconds for compilation, then another ~500 s for optimization). In this setting, the simulator doesn't change in between experiments, but we might want to run the optimizer with different initial conditions, hyperparameters, etc. It would be really great if we had some way to compile either the simulator or optimizer step function and have that compiled function persist between program executions. This would really help speed up our development process, so I just wanted to check in and ask if there's been any progress on this feature or any plan to add it to JAX?

marcelroed commented 1 year ago

I depend on a large codebase of standard Jax code (in Brax), and I'm waiting for several minutes to compile functions with identical inputs to what was compiled in previous runs. Is there as of today any way to do this kind of caching on GPU? Is this still blocked by XLA? I would be interested in contributing to landing this feature.

Also, could this possibly be achieved by just serializing whatever code is in memory in Python?

blurgyy commented 1 year ago

It seems that serialization of compiled functions is now working on GPU: https://github.com/google/jax/discussions/13736#discussioncomment-5887985

emchristiansen commented 8 months ago

Any updates for CPU?

As far as I can tell, the current version of JAX doesn't support serialization for CPU functions, but maybe there's some magic flag I need to pass?

emchristiansen commented 8 months ago

Actually, on the CPU front, is there even a hacky solution? E.g. one that might have an inconsistent serialization format?

I'm at the point where JAX CPU compile time is a large fraction of my workflow, and I'm considering switching to a client / server model just to keep alive a persistent process that owns the compiled JAX functions. If there's any way to avoid this, even a brittle solution, I'd be relieved.

Thanks!

EelcoHoogendoorn commented 8 months ago

Actually, on the CPU front, is there even a hacky solution? E.g. one that might have an inconsistent serialization format?

I'm at the point where JAX CPU compile time is a large fraction of my workflow, and I'm considering switching to a client / server model just to keep alive a persistent process that owns the compiled JAX functions. If there's any way to avoid this, even a brittle solution, I'd be relieved.

Thanks!

Seconded. Compilation times and the lack of caching support are really the number one usability issue with JAX for me. It might be somewhat specific to my JAX use case: when working on DL applications I don't really expect results until after a few epochs of training anyway, so a few minutes of compilation overhead don't make much of a difference; plus you tend to be working with fairly plug-and-play pieces of code, so there is little need to go back and forth to debug individual JAX expressions. But I imagine that for pretty much any other use case, minutes of compile time really ruin your day.

gboehl commented 8 months ago

Thirded. It seems to me that in numerical computing, compilation times are the one big disadvantage of jax vs julia.

emchristiansen commented 8 months ago

FYI, I tried the jax2tf approach mentioned earlier in this thread, and that route appears to be broken. E.g., the usage example in the TF docs crashes:

[Image: screenshot of the crash, 2023-10-18]

BTW, I get the same error with this simple example:

import jax
import jax.numpy as jnp
from jax.experimental import jax2tf
import tensorflow as tf

def add_jax(x, y):
    return jnp.add(x, y)

add_tf = jax2tf.convert(add_jax)

x = tf.constant(1.0)
y = tf.constant(2.0)

result = add_tf(x, y)

gnecula commented 8 months ago

The error you are seeing with jax2tf is because the current version depends on not-yet-released features in TF. You can work around this by using tf-nightly.

However, as discussed earlier in this thread jax2tf will only save the tracing and HLO generation time. It won't reduce the compilation time.
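
To see how the two costs split for a given function, the stages can be timed separately. This is a rough sketch, assuming a recent JAX release with the `jax.jit(...).lower(...).compile()` ahead-of-time API; the function `f` is a hypothetical toy workload and the timings it prints are not representative of real models.

```python
import time

import jax
import jax.numpy as jnp


def f(x):
    # Toy workload; real models spend far longer in both stages.
    for _ in range(20):
        x = jnp.tanh(x @ x.T) + x
    return x


x = jnp.ones((64, 64), dtype=jnp.float32)

t0 = time.perf_counter()
lowered = jax.jit(f).lower(x)   # tracing and IR generation
t1 = time.perf_counter()
compiled = lowered.compile()    # XLA compilation
t2 = time.perf_counter()

print(f"trace + lower: {t1 - t0:.3f} s")
print(f"compile:       {t2 - t1:.3f} s")
```

A serialization path that stores only the lowered computation removes the first interval on reload; the second interval is the part that needs executable serialization or a persistent compilation cache.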