PhilipVinc opened this issue 1 year ago
After diving deep into JAX's internals, I devised a possible fix: every time a `Zero(AbstractToken)` would be constructed, build a token instead. Simply editing the `read_cotangent` function to
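```python
from jax._src.core import AbstractToken, Token

# Patched version of the read_cotangent helper nested inside ad.py's
# backward_pass; ct_env and Zero come from the enclosing scope.
def read_cotangent(v):
    if isinstance(v.aval, AbstractToken):
        # Missing token cotangent: hand out a concrete token instead of a Zero.
        return ct_env.pop(v, Token())
    else:
        return ct_env.pop(v, Zero(v.aval))
```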
This fixes my reproducer above. Of course, this might not be semantically correct, but I'm sure you know more about it.
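To make the proposed change concrete, here is a self-contained toy model of the cotangent environment. The classes `AbstractToken`, `ShapedArray`, `Zero`, `Token` and `Var` below are hypothetical stand-ins written for illustration, not JAX's real internals; the sketch only shows what the original and the patched `read_cotangent` return for a variable that never received a cotangent.

```python
from dataclasses import dataclass, field
import itertools

# Hypothetical stand-ins for JAX internals (illustration only).
@dataclass(frozen=True)
class AbstractToken:
    pass

@dataclass(frozen=True)
class ShapedArray:
    shape: tuple

@dataclass(frozen=True)
class Zero:
    """Symbolic zero: records only the abstract value, no data."""
    aval: object

_token_ids = itertools.count()

@dataclass(frozen=True)
class Token:
    """A concrete token; the uid just makes fresh tokens distinguishable."""
    uid: int = field(default_factory=lambda: next(_token_ids))

@dataclass(frozen=True)
class Var:
    name: str
    aval: object

def read_cotangent_original(ct_env, v):
    # Missing cotangents default to a symbolic zero, whatever the aval.
    return ct_env.pop(v, Zero(v.aval))

def read_cotangent_patched(ct_env, v):
    # Proposed patch: a missing token cotangent becomes a fresh concrete token.
    if isinstance(v.aval, AbstractToken):
        return ct_env.pop(v, Token())
    return ct_env.pop(v, Zero(v.aval))

tok_var = Var("e", AbstractToken())
arr_var = Var("d", ShapedArray((1,)))

print(read_cotangent_original({}, tok_var))  # Zero(aval=AbstractToken())
print(type(read_cotangent_patched({}, tok_var)).__name__)  # Token
print(read_cotangent_patched({}, arr_var))  # Zero(aval=ShapedArray(shape=(1,)))
```

Note that this patch instantiates every missing token cotangent globally in the backward pass, regardless of whether any primitive will actually consume it.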
@mattjj, if the issue is not clear, please let me know and I will try to clarify it further. This is a big blocker for us.
Thanks for the ping! I managed not to notice until just now.
> because in general `jax.jit` functions cannot return them
Can you say more, and/or share a reproducer? If these are JAX tokens then they should be returnable from jitted functions, otherwise that's a JAX bug (even if it's not the main bug you're talking about).
As for the main issue, I understand the general outline, but I need to look at mpi4jax more closely, or alternatively set up a toy model, to understand it better. I have two gut reactions:
… `Zero(AbstractToken)`, i.e. in a JVP or transpose rule (not in `ad.py`'s `backward_pass`), you probably want to instantiate it so that it's no longer symbolic; but in the bigger picture…

In particular, I don't think the fix in this comment is on the right track, unfortunately.
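One way to read the suggestion of instantiating the symbolic zero inside the transpose rule, rather than in `backward_pass`, is sketched below as a toy model. All names (`instantiate_token`, `bind_allreduce_transpose`, and the local `Zero`, `AbstractToken`, `Token` classes) are hypothetical; in real JAX code a rule would presumably create a concrete token with `jax.lax.create_token()`, but that is an assumption, not something confirmed in this thread.

```python
from dataclasses import dataclass, field
import itertools

# Hypothetical stand-ins for JAX internals (illustration only).
@dataclass(frozen=True)
class AbstractToken:
    pass

@dataclass(frozen=True)
class Zero:
    """Symbolic zero: carries only an abstract value, no data."""
    aval: object

_ids = itertools.count()

@dataclass(frozen=True)
class Token:
    """Concrete token; each fresh token gets a distinct uid."""
    uid: int = field(default_factory=lambda: next(_ids))

def instantiate_token(ct):
    """Replace a symbolic token zero with a fresh concrete token."""
    if isinstance(ct, Zero) and isinstance(ct.aval, AbstractToken):
        return Token()
    return ct

def bind_allreduce_transpose(x_ct, token):
    """Hypothetical primitive bind: like XLA, it only accepts concrete tokens."""
    if not isinstance(token, Token):
        raise TypeError(f"Value {token!r} is not a valid token")
    # The toy performs no actual reduction; it passes x_ct through.
    return x_ct, Token()

def allreduce_transpose_rule(cts, x, token):
    """Toy transpose rule that instantiates the token cotangent locally."""
    x_ct, token_ct = cts
    token_ct = instantiate_token(token_ct)
    res, new_token = bind_allreduce_transpose(x_ct, token_ct)
    return res, new_token

# Usage: a symbolic token cotangent no longer reaches the primitive.
res, new_token = allreduce_transpose_rule(([1.0], Zero(AbstractToken())), [1.0], None)
print(res, type(new_token).__name__)
```

The design point is locality: the instantiation happens only where a primitive actually needs a real token, so the generic backward pass can keep treating unused token cotangents symbolically.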
> An example of how this impacts mpi4jax can be obtained by installing the branch `pv/fix-token`, for example by running `pip install git+https://github.com/mpi4jax/mpi4jax.git@pv/fix-token`, and then using the following MWE:
Where is the token in this example?
Thanks for answering!
In the example I shared above, the token is generated automatically by mpi4jax, but let me share a clearer example. I hope you don't mind installing mpi4jax (unfortunately, tokens are not used anywhere in JAX itself, so I can't build a reproducer without it).
The reproducer is the following:
```python
import jax
import jax.numpy as jnp
import mpi4jax
from mpi4py import MPI

def f(a, b):
    token_a = jax.lax.create_token()
    # The token threads the two allreduces so they cannot be reordered.
    c, token_b = mpi4jax.allreduce(a, MPI.SUM, token=token_a)
    d, token_c = mpi4jax.allreduce(b, MPI.SUM, token=token_b)
    e = c + d
    return d

x = jnp.ones(1)
y = jnp.ones(1)
r = f(x, y)

# jax.make_jaxpr(f)(x, y)
# jax.make_jaxpr(jax.linear_transpose(f, x, y))(r)
jax.linear_transpose(f, x, y)(r)
```
Let me explain what is going on here by inspecting the jaxpr:
```
In [4]: jax.make_jaxpr(f)(x,y)
Out[4]:
{ lambda ; a:f32[1] b:f32[1]. let
    c:Tok = create_token
    d:f32[1] e:Tok = allreduce_mpi[
      comm=<mpi4jax._src.utils.HashableMPIType object at 0x13f669b50>
      op=<mpi4jax._src.utils.HashableMPIType object at 0x13f668fd0>
      transpose=False
    ] a c
    f:f32[1] _:Tok = allreduce_mpi[
      comm=<mpi4jax._src.utils.HashableMPIType object at 0x13f509710>
      op=<mpi4jax._src.utils.HashableMPIType object at 0x13f0e03d0>
      transpose=False
    ] b e
    _:f32[1] = add d f
  in (f,) }
```
You can see that I have two calls to the primitive `allreduce_mpi`, which is defined here. This primitive takes two inputs: the array to be reduced and a token to prevent reordering.
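The role of the token can be illustrated with a toy dependency model (not XLA's actual scheduler): an op may run only after the ops that produce its inputs. Two independent allreduces can then execute in either order, while threading a token through them leaves exactly one valid order. The op and value names are hypothetical, loosely mirroring the jaxpr above.

```python
import itertools

def valid_orders(ops):
    """All execution orders consistent with data dependencies.

    ops maps an op name to (inputs, outputs). An op may run once every op
    producing one of its inputs has run; inputs produced by no op are ready.
    """
    producer = {out: name for name, (_, outs) in ops.items() for out in outs}
    orders = []
    for perm in itertools.permutations(ops):
        done = set()
        for name in perm:
            inputs, _ = ops[name]
            deps = {producer[i] for i in inputs if i in producer}
            if not deps <= done:
                break  # this permutation runs an op before its dependencies
            done.add(name)
        else:
            orders.append(perm)
    return orders

untokened = {
    "allreduce_a": (["a"], ["c"]),
    "allreduce_b": (["b"], ["d"]),
}
tokened = {
    "create_token": ([], ["tok0"]),
    "allreduce_a": (["a", "tok0"], ["c", "tok1"]),
    "allreduce_b": (["b", "tok1"], ["d", "tok2"]),
}
print(len(valid_orders(untokened)))  # 2
print(valid_orders(tokened))  # [('create_token', 'allreduce_a', 'allreduce_b')]
```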
Now, what would be the correct transposition of this jaxpr? I would assume it is the execution of the operations in reverse order.
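The expected reversal can be sketched with a toy backward pass (not JAX's `backward_pass`) over the two `allreduce_mpi` equations above, reduced to their token inputs and outputs. Walking the equations in reverse and feeding each transposed op the cotangent token produced by the previous one yields the reversed chain. Where no cotangent token exists (the discarded `_:Tok` output), the sketch assumes a fresh token is started; that assumption is exactly the behaviour this issue is about.

```python
def transpose_token_chain(eqns):
    """Toy backward pass over a chain of token-threaded collectives.

    eqns: (name, in_token, out_token) triples in forward order.
    Returns (transposed_op, token_consumed, token_produced) in execution order.
    """
    ct_tokens = {}  # cotangent token available for each forward token
    schedule = []
    n_fresh = 0
    for name, in_tok, out_tok in reversed(eqns):
        if out_tok in ct_tokens:
            ct_in = ct_tokens.pop(out_tok)
        else:
            # No cotangent token reached this op: assume a fresh one is created.
            ct_in = f"fresh{n_fresh}"
            n_fresh += 1
        ct_out = f"ct_{in_tok}"
        schedule.append((f"{name}_T", ct_in, ct_out))
        ct_tokens[in_tok] = ct_out  # consumed by the previous op's transpose
    return schedule

eqns = [
    ("allreduce_mpi#1", "c", "e"),
    ("allreduce_mpi#2", "e", "_"),
]
for step in transpose_token_chain(eqns):
    print(step)
# ('allreduce_mpi#2_T', 'fresh0', 'ct_e')
# ('allreduce_mpi#1_T', 'ct_e', 'ct_c')
```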
The transposition rule on master is defined here, and it is essentially:
```python
def mpi_allreduce_transpose_rule(tan_args, *x_args, op, comm):
    _, token = x_args
    x_tan, token_tan = tan_args
    # Bind the transposed primitive on the *primal* token.
    res, token = mpi_allreduce_transpose_p.bind(
        x_tan, token, op=op, comm=comm,
    )
    return res, token_tan
```
Notice that I bind the primal token instead of the tangent token. Is this correct? It seems not, as this fails with the error:
```
File ~/Documents/pythonenvs/mpi4jax/python-3.11.1/lib/python3.11/site-packages/jax/_src/core.py:1326, in concrete_aval(x)
   1324 if hasattr(x, '__jax_array__'):
   1325   return concrete_aval(x.__jax_array__())
-> 1326 raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
   1327                 "type")

TypeError: Value UndefinedPrimal(AbstractToken()) with type <class 'jax._src.interpreters.ad.UndefinedPrimal'> is not a valid JAX type
```
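A plausible reading of this `UndefinedPrimal(AbstractToken())` (an interpretation, not confirmed in the thread): under `jax.linear_transpose` the inputs `a` and `b` are linear, and if partial evaluation marks every output of an equation that has a linear input as unknown, then the output token `e` of the first `allreduce_mpi` is unknown too. The second transpose rule would then receive `e` as an undefined primal, which cannot be bound. The toy below models only that propagation; `undefined_primals` is a hypothetical helper, not JAX's partial evaluator.

```python
def undefined_primals(eqns, linear_inputs):
    """Toy forward propagation of 'unknown' (linear) status through a jaxpr.

    If any input of an equation is unknown, all of its outputs are marked
    unknown too; equations with only known inputs yield known outputs.
    """
    unknown = set(linear_inputs)
    for _name, inputs, outputs in eqns:
        if any(i in unknown for i in inputs):
            unknown.update(outputs)
    return unknown

# The equations of the jaxpr above (the trailing `add` is omitted).
eqns = [
    ("create_token", [], ["c"]),
    ("allreduce_mpi#1", ["a", "c"], ["d", "e"]),
    ("allreduce_mpi#2", ["b", "e"], ["f", "_"]),
]
unknown = undefined_primals(eqns, linear_inputs={"a", "b"})
print("c" in unknown)  # False: the first rule sees a concrete token
print("e" in unknown)  # True: the second rule sees an undefined token primal
```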
Another reason suggesting that I should bind the tangent of the token instead of the primal token is that I want the linear transposition to execute the operations in reverse order, which I only get by binding the tangent token to the transposed primitive. Does this make sense?
So in the branch `mpi4jax@pv/fix-token` I modified the transposition rule to read:
```python
def mpi_allreduce_transpose_rule(tan_args, *x_args, op, comm):
    _, _ = x_args
    x_tan, token_tan = tan_args
    # Bind the transposed primitive on the *tangent* token instead.
    res, token = mpi_allreduce_transpose_p.bind(
        x_tan, token_tan, op=op, comm=comm,
    )
    return res, token
```
but this also fails, with the error I shared in the original post, namely:
```
File ~/Documents/pythonenvs/mpi4jax/python-3.11.1/lib/python3.11/site-packages/jax/_src/core.py:1326, in concrete_aval(x)
   1324 if hasattr(x, '__jax_array__'):
   1325   return concrete_aval(x.__jax_array__())
-> 1326 raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
   1327                 "type")

TypeError: Value Zero(AbstractToken()) with type <class 'jax._src.ad_util.Zero'> is not a valid JAX type
```
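Why is `token_tan` a `Zero(AbstractToken())` here? In the jaxpr above, the output token `_` of the second `allreduce_mpi` is discarded, so the backward pass never receives a cotangent for it and reads it as a symbolic zero; the earlier token `e` only gets a cotangent because the second transpose rule returns one. The toy bookkeeping below reproduces that. `zero_token_cotangents` is a hypothetical helper, not JAX code, and it assumes each transpose rule returns a cotangent for every input, as the rule above does.

```python
def zero_token_cotangents(eqns, outvars, token_vars):
    """Toy backward-pass bookkeeping.

    Walks the equations in reverse and reports, per equation, which of its
    token outputs have no cotangent yet (and would be read as a symbolic zero).
    Assumes every transpose rule returns a cotangent for each of its inputs.
    """
    have_ct = set(outvars)
    zeros = {}
    for name, inputs, outputs in reversed(eqns):
        zeros[name] = [o for o in outputs if o in token_vars and o not in have_ct]
        have_ct.update(inputs)
    return zeros

eqns = [
    ("allreduce_mpi#1", ["a", "c"], ["d", "e"]),
    ("allreduce_mpi#2", ["b", "e"], ["f", "_"]),
]
print(zero_token_cotangents(eqns, outvars=["f"], token_vars={"c", "e", "_"}))
# {'allreduce_mpi#2': ['_'], 'allreduce_mpi#1': []}
```

So only the last token in the chain lacks a cotangent: the "seed" of the reversed chain is the symbolic zero that cannot be bound.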
> because in general `jax.jit` functions cannot return them
>
> Can you say more, and/or share a reproducer? If these are JAX tokens then they should be returnable from jitted functions, otherwise that's a JAX bug (even if it's not the main bug you're talking about).
Apparently I was not up to date: it now seems possible to return tokens (I remember that about a year ago it was not). However, it still errors if you try to transpose a function that returns a token:
```python
import jax
import jax.numpy as jnp
import mpi4jax
from mpi4py import MPI

def f(a):
    token_a = jax.lax.create_token()
    b, token_b = mpi4jax.allreduce(a, MPI.SUM, token=token_a)
    # Return the token as well as the reduced array.
    return b, token_b

x = jnp.ones(1)
r, s = f(x)
jax.make_jaxpr(f)(x)
jax.make_jaxpr(jax.linear_transpose(f, x))(r)
```
This fails with:
```
File ~/Documents/pythonenvs/mpi4jax/python-3.11.1/lib/python3.11/site-packages/jax/_src/dtypes.py:530, in dtype(x, canonicalize)
    528 dt = np.result_type(x)
    529 except TypeError as err:
--> 530 raise TypeError(f"Cannot determine dtype of {x}") from err
    531 if dt not in _jax_dtype_set and not is_opaque_dtype(dt):
    532 raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array "
    533 "type. Only arrays of numeric types are supported by JAX.")

TypeError: Cannot determine dtype of AbstractToken()
```
That's a different bug from the one I originally reported, though, and I'm not too worried about it, because tokens usually remain inside the jitted functions.
> I hope you don't mind installing mpi4jax
I don't mind at all!
Thanks for the detailed repro and info. I'll take a look...
I haven't had a chance yet :/ I expect I can in the next 48 hours or so.
Thanks for the update! Looking forward to your reply.
@mattjj any luck? Can I do anything to help you nail this problem down?
@mattjj pretty please 🥹?
Description
In mpi4jax we make heavy use of tokens to prevent XLA from reordering our MPI calls, which is a particular problem on XLA:GPU. A standard function would look something like:

We do not return the token because, in general, `jax.jit` functions cannot return them, but a strong order is still enforced within the compiled function because of the tokens.

When we transpose our functions with `jax.linear_transpose`, we expect the transposed function to also enforce a strong ordering, in reverse order, for example:

However, when mpi4jax attempts to bind the transposed token `Zero(AbstractToken)`, an error is raised saying that it cannot be bound because XLA does not know how to represent it.

I suspect that `Zero(AbstractToken)` should be treated exactly like a standard token, such that when `tc_t` in the example above is `Zero(AbstractToken)`, you follow the same path as when binding a normal token.

An example of how this impacts mpi4jax can be obtained by installing the branch `pv/fix-token`, for example by running `pip install git+https://github.com/mpi4jax/mpi4jax.git@pv/fix-token`, and then using the following MWE:

which raises the error:
What jax/jaxlib version are you using?
jax 0.4.11 jaxlib 0.4.11
Which accelerator(s) are you using?
CPU
Additional system info
macOS (M1)
NVIDIA GPU info
No response