Open chaserileyroberts opened 2 years ago
Would you mind commenting on the underlying use-case here?
Basically what I'm experimenting with is a way to decompose a JAX program into multiple JAX+MPI programs that run distributed.
big_jaxpr = jax.make_jaxpr(some_func)(*args, **kwargs)
many_smaller_jaxprs = decompose_problem(big_jaxpr)
my_methods = [jax.jit(jax.core.jaxpr_as_fun(j)) for j in many_smaller_jaxprs]
# Submit each jitted sub-program to its own worker.
for i, method in enumerate(my_methods):
    client.submit(method, *args, **kwargs, worker=i)
Since all of the functions in my_methods have the jaxprs as local variables, those get pickled when sent off to the remote worker. Right now I'm using the above hack on all of the smaller jaxprs, which works fine for now, but it would be nice to have a tested and supported method.
Also, one thing that surprised me very much was that jitted methods were always correctly cached on the remote workers. I had no issue with recompilation, which I really wasn't expecting given that I'm mixing JAX + MPI and Dask.
OK, interesting. I'll let the core JAX team comment on the feasibility here, but I guess it could make sense to either make Traceback pickleable, or make including tracebacks in jaxprs optional.
Tracebacks in general are not pickleable, sadly, since they reference the full memory stack.
https://stackoverflow.com/questions/6132469/why-cant-i-pickle-an-errors-traceback-in-python
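To make the limitation concrete, here is a minimal stdlib-only sketch (plain Python tracebacks, not JAX's `Traceback` type). It tries to pickle a traceback object and then falls back to a plain-data summary. The helper name `summarize` is made up for illustration.

```python
import pickle
import sys
import traceback


def summarize(tb):
    """Reduce a traceback to plain (filename, lineno, function) tuples."""
    return [(f.filename, f.lineno, f.name) for f in traceback.extract_tb(tb)]


try:
    1 / 0
except ZeroDivisionError:
    tb = sys.exc_info()[2]

# The traceback object references live frame objects, so pickle refuses it.
try:
    pickle.dumps(tb)
    pickled_ok = True
except TypeError:
    pickled_ok = False

# A summary made only of strings and ints round-trips without trouble.
summary = summarize(tb)
restored = pickle.loads(pickle.dumps(summary))
```

The trade-off is that the summary keeps only what was extracted up front; local variables and the frames themselves are gone.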
Your other suggestion of making tracebacks in jaxprs optional is already the case, I believe. If you don't include a source_info argument when calling new_jaxpr_eqn, no traceback is included. We could perhaps utilize that and add an include_tracebacks=False option to make_jaxpr that forces all of the eqns to not have tracebacks.
Traceback here is a custom XLA/JAX thing, so in principle we could override it:
https://github.com/tensorflow/tensorflow/blob/fb91f402331605db55f1cda9603e18835245e6d1/tensorflow/compiler/xla/python/traceback.cc
(that said, it does currently include Python stack frames)
Actually, upon further experimenting, I'm not sure pickling is viable for the use case I described above.
>>> import jax
>>> j = jax.make_jaxpr(lambda a, b: a + b)(1.0, 2.0)
>>> import cloudpickle
>>> j.jaxpr.eqns = [jax.core.new_jaxpr_eqn(x.invars, x.outvars, x.primitive, x.params) for x in j.jaxpr.eqns]
>>> s = cloudpickle.dumps(j)
>>> j2 = cloudpickle.loads(s)
>>> jax.core.jaxpr_as_fun(j2)(1.0, 2.0)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'add' not found for platform cpu
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/core.py", line 147, in jaxpr_as_fun
return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/core.py", line 330, in eval_jaxpr
ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/core.py", line 272, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/core.py", line 275, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/core.py", line 591, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/dispatch.py", line 92, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/util.py", line 202, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/util.py", line 195, in cached
return f(*args, **kwargs)
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/dispatch.py", line 111, in xla_primitive_callable
compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/dispatch.py", line 169, in _xla_callable_uncached
return lower_xla_callable(fun, device, backend, name, donated_invars,
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/dispatch.py", line 258, in lower_xla_callable
module = mlir.lower_jaxpr_to_module(
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/interpreters/mlir.py", line 409, in lower_jaxpr_to_module
lower_jaxpr_to_fun(
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/interpreters/mlir.py", line 549, in lower_jaxpr_to_fun
out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/interpreters/mlir.py", line 628, in jaxpr_subcomp
raise NotImplementedError(
NotImplementedError: MLIR translation rule for primitive 'add' not found for platform cpu
Since the Primitives are actual Python objects, and the translation rules are determined by dictionary lookups of the objects by reference, those references aren't maintained between a pickle dump/load. I'm not sure this can be supported without a huge amount of structural changes to JAX.
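The failure mode described here can be reproduced without JAX. The sketch below is an assumption-laden stand-in: `Primitive`, `add_p`, and `translation_rules` are toy names, not JAX internals. It only shows that an object hashed by identity no longer matches its dictionary entry after a pickle round trip.

```python
import pickle


class Primitive:
    """Toy stand-in for a primitive; hashed and compared by identity."""

    def __init__(self, name):
        self.name = name


add_p = Primitive("add")

# Rules are registered against the primitive object itself.
translation_rules = {add_p: lambda x, y: x + y}

# A round trip rebuilds the object from its attributes: same data, new identity.
restored = pickle.loads(pickle.dumps(add_p))

same_name = restored.name == add_p.name
same_object = restored is add_p
rule_found = restored in translation_rules
```

The restored object carries the same name, but it is a different object, so the lookup misses.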
Well: that would suggest that you'd want to pickle primitives by name not by object identity...
I think that is the difference between pickle and cloudpickle. Using the former should avoid the problem.
EDIT: yeah, but this won't work....
I think that is the difference between pickle and cloudpickle. Using the former should avoid the problem.
I wish it was that simple. Normal pickle has its own struggles.
>>> j2 = pickle.dumps(j)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
_pickle.PicklingError: Can't pickle <function <lambda> at 0x7f55e2bf6550>: attribute lookup <lambda> on jax._src.lax.utils failed
I meant: add logic to core.Primitive to have it pickle by name, not by identity.
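A minimal sketch of what "pickle by name" could look like, using only the standard library. Everything here (`_REGISTRY`, `_lookup`, the toy `Primitive`) is hypothetical and only illustrates the `__reduce__` mechanism, not how JAX's `core.Primitive` is written.

```python
import pickle

_REGISTRY = {}  # hypothetical table: name -> the one live Primitive


def _lookup(name):
    """Called at unpickling time to return the already-registered object."""
    return _REGISTRY[name]


class Primitive:
    def __init__(self, name):
        self.name = name
        _REGISTRY[name] = self

    def __reduce__(self):
        # Serialize only the name; loading calls _lookup(name).
        return _lookup, (self.name,)


add_p = Primitive("add")
translation_rules = {add_p: lambda x, y: x + y}

restored = pickle.loads(pickle.dumps(add_p))
result = translation_rules[restored](1.0, 2.0)
```

Because loading returns the registered instance, identity-keyed lookups keep working. The cost is that the receiving process must have registered a primitive under the same name before unpickling.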
To add to the comments about Traceback above: yes, Stephan is right that the Traceback objects in jaxprs are JAX-internal traceback objects, not Python tracebacks. They exist for one reason only: they are optimized to be fast to collect.
e.g.:
In [1]: import jax
In [3]: Traceback = jax._src.lib.xla_client.Traceback
In [7]: %timeit x = Traceback.get_traceback()
798 ns ± 11 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [8]: import inspect
In [10]: %timeit y = inspect.stack()
5.36 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [11]: import traceback
In [13]: %timeit x = traceback.extract_stack()
99.2 µs ± 1.04 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
i.e. calling the JAX traceback is at least 3 orders of magnitude faster than inspect.stack() for the handful of stack frames in my ipython session, and about 2 orders of magnitude faster than traceback.extract_stack().
The tracebacks are simply a list of pairs of (code, lasti) values, where code is a types.CodeType and lasti is an int. We store code values because they are extremely cheap to collect from the interpreter stack, and we do not attempt to interpret them in any way when they are collected. We defer turning the code and lasti values into values meaningful to the user until we need them.
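The collect-now, interpret-later split can be sketched in pure Python. This is an illustration under assumptions, not the C++ implementation: `collect` and `resolve` are made-up names, and `sys._getframe` is a CPython-specific call.

```python
import dis
import sys


def collect():
    """Cheap capture: record (code, lasti) per frame and nothing else."""
    frames = []
    frame = sys._getframe(1)  # start at the caller of collect()
    while frame is not None:
        frames.append((frame.f_code, frame.f_lasti))
        frame = frame.f_back
    return frames


def resolve(code, lasti):
    """Deferred work: map a bytecode offset back to a source line."""
    line = code.co_firstlineno
    for offset, lineno in dis.findlinestarts(code):
        if offset > lasti:
            break
        if lineno is not None:
            line = lineno
    return code.co_filename, code.co_name, line


def user_function():
    return collect()


raw = user_function()
readable = [resolve(code, lasti) for code, lasti in raw]
```

Only `resolve` does any decoding, so the cost of producing file names and line numbers is paid only for the tracebacks that are actually displayed or serialized.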
Now, if I were to serialize the traceback values, I'd want to look at two things:
The relevant method in the JAX tree is mainly source_info_util.user_frame(). You could in essence call that at serialization time and replace the source info value with a new alternative value that only contains a single frame of interest.
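A rough sketch of that idea, assuming frames have already been resolved to (filename, function, line) tuples ordered innermost first, and that library frames can be recognized by a path prefix. `pick_user_frame`, the paths, and the line numbers are all hypothetical; this is not the logic of `source_info_util.user_frame()` itself.

```python
import pickle


def pick_user_frame(frames, library_prefix):
    """Return the innermost frame whose file is outside the library, if any."""
    for filename, function, line in frames:
        if not filename.startswith(library_prefix):
            return (filename, function, line)
    return None


# Hypothetical resolved stack, innermost frame first.
frames = [
    ("/opt/lib/fakelib/core.py", "bind", 10),
    ("/opt/lib/fakelib/ops.py", "add", 20),
    ("/home/user/model.py", "forward", 12),
    ("/home/user/train.py", "<module>", 40),
]

frame_of_interest = pick_user_frame(frames, "/opt/lib/fakelib/")

# A plain tuple of strings and ints is picklable, unlike the live traceback.
payload = pickle.dumps(frame_of_interest)
```

Keeping one frame loses the rest of the stack, but it preserves the location a user is most likely to care about in an error message.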
Just throwing my hat in the ring here for a similar feature request.
My use case is wanting to do model cloning on model variants of an evolving codebase; i.e., make some architecture tweaks on my RL model without having to learn from scratch. Making every little thing configurable does not really scale, nor does it really work, since you are constantly breaking backwards compatibility as you evolve your model with progressive insight. One way to do this is to commit every little code change separately, run each code version in a separate Python process, and clone by exchanging data across processes; but obviously that would be a horrible pain compared to the elegance of just being able to read/write pure JAX model.apply functions to disk.
I don't know how hard it will be to obtain a serializable representation, but I do think being able to do so would allow leveraging JAX's functional abilities in a nice way compared to other frameworks.
I work on jax full time now so I'm going to try and lead this.
The two issues are the traceback thing and the global dictionary lookups. The traceback thing can be solved, I think, by just implementing the encoding/decoding APIs. The global dictionary one is harder to solve, unfortunately. I'm not sure how to make cloudpickle treat the primitives like "singletons".
Ok so hacking in this change in core.py
# Table that stores all primitive definitions.
# Needed so that primitives are treated as singletons
# when using cloudpickle.
_PRIMITIVES_TABLE_: dict[tuple[str, str], Primitive] = {}

def primitives_table_get(namespace: str, name: str):
  if (namespace, name) in _PRIMITIVES_TABLE_:
    return _PRIMITIVES_TABLE_[(namespace, name)]
  raise NotImplementedError(
    f"Op {(namespace, name)} not found in primitives table. (Did you import it yet?).")

def primitives_table_set(namespace: str, name: str, primitive: Primitive):
  assert (namespace, name) not in _PRIMITIVES_TABLE_, (
    f"The op name {(namespace, name)} is already taken. "
    "Try changing the namespace with Primitive(..., namespace='<YOUR_PROJECT_NAME>')."
  )
  _PRIMITIVES_TABLE_[(namespace, name)] = primitive

class Primitive:
  name: str
  namespace: str
  # set for multi-output primitives.
  multiple_results: bool = False
  # set for call primitives processed in final style.
  call_primitive: bool = False
  # set for map primitives processed in final style.
  map_primitive: bool = False

  def __init__(self, name: str, namespace: str='__jax_internal__'):
    self.name = name
    self.namespace = namespace
    primitives_table_set(namespace, name, self)

  def __reduce__(self):
    return primitives_table_get, (self.namespace, self.name)
And this change just somewhere.
import jaxlib
jaxlib.xla_extension.Traceback.__reduce__ = lambda a: (lambda: None, ())
And now pickle works like a charm.
jxpr = jax.make_jaxpr(lambda a: a + a)(1)
res = cloudpickle.dumps(jxpr)
new_jxpr = cloudpickle.loads(res)
jax.core.jaxpr_as_fun(new_jxpr)(2)
# [Array(4, dtype=int32, weak_type=True)]
Nothing broke in vanilla JAX, so there are no obvious name collisions. I imagine, however, that something in google3 will collide. When that happens, using the namespace='...' trick should fix most of the issues quickly.
Will raise a PR tomorrow.
https://gist.github.com/jjyyxx/f64e28f6ccc37c24af9fd17649710b26
I created a demo that successfully saves a traced JAX function using pickle only. This is important to me because during quick research, I often make many small tweaks without version control. Saving the traced JAX function provides a minimal interface to reproduce results later, even if the source code has changed.
Previously, I used ONNX, but it's slow, TensorFlow-dependent, and has many unsupported operations. Saving Jaxpr is much lighter and allows for separating the computation graph from data. I can save the Jaxpr once and only save parameters periodically during training.
Additionally, saving Jaxpr instead of compiled binaries or lowered IR gives me almost full control over further inference. It can be jitted, used on CPU/GPU, vmap-transformed, and exported to ONNX.
While it works for me, there are some unhandled edge cases:
The hackiest part is mapping primitive names to primitives, which currently involves scanning through modules.
Hey, I agree the jaxpr makes sense to pickle since it's a portable IR for jax jit functions. However, if we just had "fn_from_jaxpr" then we could save the jaxprs as strings instead of pickles, which would facilitate human readability of serialized jit functions.
One idea to do this would be:
make_jaxpr to jaxpr_from_fn, and ideally remove the requirement to pass input
fn_from_jaxpr to pass the test (how?)
Created issue:
jax jitted functions cloudpickled work but include some error messages #537 in cloudpipe / cloudpickle repo
Link: https://github.com/cloudpipe/cloudpickle/issues/537
Right now, if you try to pickle a jaxpr, you are given an error.
However, there isn't much that needs to change. Simply rebuilding each of the jaxpr.eqns with new_jaxpr_eqn, without passing its source info, allows the jaxpr to be pickleable.