jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Allow Jaxprs to be cloudpickle-able #9444

Open chaserileyroberts opened 2 years ago

chaserileyroberts commented 2 years ago

Right now, if you try to pickle a jaxpr, you get an error:

>>> import jax
>>> f = lambda x: x + 1
>>> j = jax.make_jaxpr(f)(1)
>>> j
{ lambda ; a:i32[]. let b:i32[] = add a 1 in (b,) }
>>> import cloudpickle
>>> cloudpickle.dumps(j)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/chase/anaconda3/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py", line 73, in dumps
    cp.dump(obj)
  File "/home/chase/anaconda3/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py", line 563, in dump
    return Pickler.dump(self, obj)
TypeError: cannot pickle 'jaxlib.xla_extension.Traceback' object

However, not much needs to change. Simply rebuilding jaxpr.eqns with new_jaxpr_eqn calls that omit the source info makes the jaxpr pickleable:

j.jaxpr.eqns = [jax.core.new_jaxpr_eqn(x.invars, x.outvars, x.primitive, x.params) for x in j.jaxpr.eqns]
cloudpickle.dumps(j) # Works like a charm!
shoyer commented 2 years ago

Would you mind commenting on the underlying use-case here?

chaserileyroberts commented 2 years ago

Basically, what I'm experimenting with is a way to decompose a JAX program into multiple JAX+MPI programs that run in a distributed fashion.

big_jaxpr = jax.make_jaxpr(some_func)(*args, **kwargs)
many_smaller_jaxprs = decompose_problem(big_jaxpr)
my_methods = [jax.jit(jax.core.jaxpr_as_fun(j)) for j in many_smaller_jaxprs]
for i, jaxpr in enumerate(many_smaller_jaxprs):
  client.submit(my_methods[i], *args, **kwargs, worker=i)

Since all of the functions in my_methods have the jaxprs as closure variables, those get pickled when sent off to the remote workers. Right now I'm using the above hack on all of the smaller jaxprs, which works fine for now, but it would be nice to have a tested and supported method.

Also, one thing that surprised me very much was that jitted methods were always correctly cached on the remote workers. I had no issues with recompilation, which I really wasn't expecting given that I'm mixing JAX, MPI, and Dask.

shoyer commented 2 years ago

OK, interesting. I'll let the core JAX team comment on the feasibility here, but I guess it could make sense to either make Traceback pickleable or make including tracebacks in jaxprs optional.

chaserileyroberts commented 2 years ago

Tracebacks in general are sadly not pickleable, since they reference the full frame stack.

https://stackoverflow.com/questions/6132469/why-cant-i-pickle-an-errors-traceback-in-python
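A quick way to see this with a plain Python traceback and standard pickle (a minimal illustration, nothing JAX-specific):

import pickle
import sys

try:
  raise ValueError("boom")
except ValueError:
  tb = sys.exc_info()[2]  # a plain Python traceback object

pickle.dumps(tb)  # TypeError: cannot pickle 'traceback' object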

I believe your other suggestion, making tracebacks in jaxprs optional, is already effectively the case: if you don't pass a source_info argument when calling new_jaxpr_eqn, no traceback is included. We could perhaps build on that and add an include_tracebacks=False option to make_jaxpr that forces all of the eqns to omit tracebacks.
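For illustration, a rough sketch of what that option could look like as a user-side wrapper (the helper name is hypothetical, and new_jaxpr_eqn's exact signature and whether eqns can be reassigned vary across JAX versions):

import jax

def make_jaxpr_no_tracebacks(fun):
  # Trace as usual, then rebuild each equation without its source_info,
  # mirroring the hack from the first comment in this thread.
  def wrapped(*args, **kwargs):
    closed = jax.make_jaxpr(fun)(*args, **kwargs)
    closed.jaxpr.eqns = [
        jax.core.new_jaxpr_eqn(e.invars, e.outvars, e.primitive, e.params)
        for e in closed.jaxpr.eqns
    ]
    return closed
  return wrapped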

shoyer commented 2 years ago

Traceback here is a custom XLA/JAX thing, so in principle we could override it: https://github.com/tensorflow/tensorflow/blob/fb91f402331605db55f1cda9603e18835245e6d1/tensorflow/compiler/xla/python/traceback.cc

(that said, it does currently include Python stack frames)

chaserileyroberts commented 2 years ago

Actually, upon further experimentation, I'm not sure pickling is viable for the use case I described above.

>>> import jax
>>> j = jax.make_jaxpr(lambda a, b: a + b)(1.0, 2.0)
>>> import cloudpickle
>>> j.jaxpr.eqns = [jax.core.new_jaxpr_eqn(x.invars, x.outvars, x.primitive, x.params) for x in j.jaxpr.eqns]
>>> s = cloudpickle.dumps(j)
>>> j2 = cloudpickle.loads(s)
>>> jax.core.jaxpr_as_fun(j2)(1.0, 2.0)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'add' not found for platform cpu

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/core.py", line 147, in jaxpr_as_fun
    return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/core.py", line 330, in eval_jaxpr
    ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/core.py", line 272, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/core.py", line 275, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/core.py", line 591, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/dispatch.py", line 92, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/util.py", line 202, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/util.py", line 195, in cached
    return f(*args, **kwargs)
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/dispatch.py", line 111, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/dispatch.py", line 169, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars,
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/dispatch.py", line 258, in lower_xla_callable
    module = mlir.lower_jaxpr_to_module(
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/interpreters/mlir.py", line 409, in lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/interpreters/mlir.py", line 549, in lower_jaxpr_to_fun
    out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/interpreters/mlir.py", line 628, in jaxpr_subcomp
    raise NotImplementedError(
NotImplementedError: MLIR translation rule for primitive 'add' not found for platform cpu

Since the primitives are actual Python objects, and the translation rules are looked up in dictionaries keyed by those objects (i.e. by reference), the references aren't maintained across a pickle dump/load. I'm not sure this can be supported without substantial structural changes to JAX.
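To illustrate the problem in isolation (a toy example, not JAX's actual lowering tables, just the same pattern of keying a dict by the primitive object):

import copy

class Primitive:
  def __init__(self, name):
    self.name = name

add_p = Primitive("add")
lowering_rules = {add_p: "lower_add"}  # keyed by the object itself (default identity hash/eq)

add_copy = copy.deepcopy(add_p)    # stands in for a pickle dump/load round-trip
print(add_p in lowering_rules)     # True
print(add_copy in lowering_rules)  # False -> "translation rule ... not found"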

hawkinsp commented 2 years ago

Well: that would suggest that you'd want to pickle primitives by name not by object identity...

PhilipVinc commented 2 years ago

I think that is the difference between pickle and cloudpickle. Using the former should avoid the problem.

EDIT: yeah, but this won't work....

chaserileyroberts commented 2 years ago

I think that is the difference between pickle and cloudpickle. Using the former should avoid the problem.

I wish it were that simple. Normal pickle has its own struggles:

>>> j2 = pickle.dumps(j)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
_pickle.PicklingError: Can't pickle <function <lambda> at 0x7f55e2bf6550>: attribute lookup <lambda> on jax._src.lax.utils failed
hawkinsp commented 2 years ago

I meant: add logic to core.Primitive to have it pickle by name, not by identity.

hawkinsp commented 2 years ago

To add to the comments about Traceback above: yes, Stephan is right that the Traceback objects in jaxprs are JAX-internal traceback objects, not Python tracebacks. They exist for one reason only: they are optimized to be fast to collect.

e.g.:

In [1]: import jax

In [3]: Traceback = jax._src.lib.xla_client.Traceback

In [7]: %timeit x = Traceback.get_traceback()
798 ns ± 11 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [8]: import inspect

In [10]: %timeit y = inspect.stack()
5.36 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [11]: import traceback

In [13]: %timeit x = traceback.extract_stack()
99.2 µs ± 1.04 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

i.e., collecting the JAX traceback is at least 3 orders of magnitude faster than inspect.stack() for the handful of stack frames in my IPython session, and about 2 orders of magnitude faster than traceback.extract_stack().

The tracebacks are simply a list of pairs of (code, lasti) values, where code is a types.CodeType and lasti is an int. We store code values because they are extremely cheap to collect from the interpreter stack, and we do not attempt to interpret them in any way when they are collected. We defer turning the code and lasti values into values meaningful to the user until we need them.
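As a rough illustration of why raw (code, lasti) pairs are enough and the human-readable mapping can be deferred (this is not JAX's implementation, just a sketch using the standard library):

import dis
import types

def decode_frame(code: types.CodeType, lasti: int):
  # Walk the code object's line table to find the source line for a bytecode offset.
  filename = code.co_filename
  lineno = code.co_firstlineno
  for offset, line in dis.findlinestarts(code):
    if offset > lasti:
      break
    if line is not None:
      lineno = line
  return filename, lineno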

Now, if I were to serialize the traceback values: the relevant method in the JAX tree is mainly source_info_util.user_frame(). You could in essence call that at serialization time and replace the source info value with a new alternative value that only contains the single frame of interest.

EelcoHoogendoorn commented 1 year ago

Just throwing my hat in the ring here for a similar feature request.

My use case is model cloning across variants of an evolving codebase; i.e., making some architecture tweaks to my RL model without having to learn from scratch. Making every little thing configurable does not really scale, and it does not really work either, since you constantly break backwards compatibility as you evolve your model with progressive insight. One way to do this is to commit every little code change separately, run each code version in a separate Python process, and clone by exchanging data across processes; but obviously that would also be a horrible pain compared to the elegance of simply being able to read/write pure JAX model.apply functions to disk.

I don't know how hard it will be to obtain a serializable representation, but I do think being able to do so would allow leveraging JAX's functional abilities in a nice way compared to other frameworks.

chaserileyroberts commented 1 year ago

I work on jax full time now so I'm going to try and lead this.

The two issues are the tracebacks and the global dictionary lookups. The traceback issue can be solved, I think, by just implementing the encoding/decoding APIs. The global dictionary one is unfortunately harder to solve; I'm not sure how to make cloudpickle treat the primitives like "singletons".

chaserileyroberts commented 1 year ago

OK, so hacking this change into core.py:

# Table that stores all primitive definitions.
# Needed so that primitives are treated as singletons
# when using cloudpickle.
_PRIMITIVES_TABLE_: dict[tuple[str, str], "Primitive"] = {}

def primitives_table_get(namespace: str, name: str):
  if (namespace, name) in _PRIMITIVES_TABLE_:
    return _PRIMITIVES_TABLE_[(namespace, name)]
  raise NotImplementedError(
    f"Op {(namespace, name)} not found in primitives table. (Did you import it yet?).")

def primitives_table_set(namespace: str, name: str, primitive: Primitive):
  assert (namespace, name) not in _PRIMITIVES_TABLE_, (
    f"The op name {(namespace, name)} is already taken. "
    "Try changing the namespace with Primitive(..., namespace='<YOUR_PROJECT_NAME>')."
  )
  _PRIMITIVES_TABLE_[(namespace, name)] = primitive

class Primitive:
  name: str
  namespace: str
  # set for multi-output primitives.
  multiple_results: bool = False
  # set for call primitives processed in final style.
  call_primitive: bool = False
  # set for map primitives processed in final style.
  map_primitive: bool = False

  def __init__(self, name: str, namespace: str='__jax_internal__'):
    self.name = name
    self.namespace = namespace
    primitives_table_set(namespace, name, self)

  def __reduce__(self):
    return primitives_table_get, (self.namespace, self.name)

And this change applied anywhere before pickling:

import jaxlib
jaxlib.xla_extension.Traceback.__reduce__ = lambda a: (lambda: None, ())

And now pickle works like a charm.

jxpr = jax.make_jaxpr(lambda a: a + a)(1)
res = cloudpickle.dumps(jxpr)
new_jxpr = cloudpickle.loads(res)
jax.core.jaxpr_as_fun(new_jxpr)(2)
# [Array(4, dtype=int32, weak_type=True)]

Nothing broke in vanilla jax, so there are no obvious name collisions. I imagine, however, that something in google3 will collide. When that happens, using the namespace='...' trick should fix most of the issues quickly.

Will raise a PR tomorrow.

jjyyxx commented 5 months ago

https://gist.github.com/jjyyxx/f64e28f6ccc37c24af9fd17649710b26

I created a demo that successfully saves a traced JAX function using pickle only. This is important to me because during quick research, I often make many small tweaks without version control. Saving the traced JAX function provides a minimal interface to reproduce results later, even if the source code has changed.

Previously, I used ONNX, but it's slow, TensorFlow-dependent, and has many unsupported operations. Saving Jaxpr is much lighter and allows for separating the computation graph from data. I can save the Jaxpr once and only save parameters periodically during training.

Additionally, saving Jaxpr instead of compiled binaries or lowered IR gives me almost full control over further inference. It can be jitted, used on CPU/GPU, vmap-transformed, and exported to ONNX.

While it works for me, there are some unhandled edge cases:

  1. The partial-eval function for custom_* primitives (notably jax.nn.relu) is dropped; I lack a deep enough understanding of JAX internals to handle this.
  2. Custom primitives are not handled.
  3. Host callbacks are not allowed.

The hackiest part is mapping primitive names to primitives, which currently involves scanning through modules.
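For reference, a rough sketch of that module-scanning approach (a hypothetical helper; a real implementation would need to handle name collisions and the exact location of Primitive, which has moved between JAX versions):

import sys
import jax

def build_primitive_table():
  # Scan already-imported modules for Primitive instances and index them by name.
  table = {}
  for module in list(sys.modules.values()):
    for value in list(getattr(module, "__dict__", {}).values()):
      if isinstance(value, jax.core.Primitive):
        table.setdefault(value.name, value)
  return table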

bionicles commented 3 months ago

Hey, I agree that the jaxpr makes sense to pickle, since it's a portable IR for JAX-jitted functions. However, if we just had "fn_from_jaxpr", then we could save jaxprs as strings instead of pickles, which would make serialized jitted functions human-readable.

One idea to do this, would be

  1. Test a bijection between PjitFunction and Jaxpr
  2. Rename make_jaxpr to jaxpr_from_fn and ideally remove the requirement to pass input
  3. Define fn_from_jaxpr to pass the test (how? a sketch of the pieces that exist today follows below)
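For context, part of this already exists for ClosedJaxpr objects (as far as I know there is no parser for the printed string form yet); a minimal sketch:

import jax

def f(x):
  return x * 2 + 1

closed = jax.make_jaxpr(f)(1.0)    # roughly the proposed jaxpr_from_fn (still needs example inputs)
g = jax.core.jaxpr_as_fun(closed)  # roughly the proposed fn_from_jaxpr
print(g(3.0))                      # [Array(7., dtype=float32, ...)] (exact repr varies by JAX version)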

Created issue in the cloudpipe/cloudpickle repo: "jax jitted functions cloudpickled work but include some error messages" (#537): https://github.com/cloudpipe/cloudpickle/issues/537
