google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Cannot bind to primitive Zero(AbstractToken()) #16303

Open PhilipVinc opened 1 year ago

PhilipVinc commented 1 year ago

Description

In mpi4jax we make heavy use of tokens to prevent XLA from reordering our MPI calls, which is particularly a problem on XLA:GPU.

A standard function would look something like

def f(a):
  ta = create_token()
  b, tb = mpi_fun(a, token=ta)
  c, tc = mpi_fun2(a, token=tb)
  d = b+c
  return d

We do not return the token because in general jax.jit functions cannot return tokens. Still, a strong ordering is enforced within the compiled function because of the tokens.
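
To illustrate why threading tokens constrains the order, here is a minimal pure-Python sketch. It is a toy dependency model, not how XLA actually schedules operations; the graph layout and the helper `depends_on` are assumptions made for this example only.

```python
# Toy model: each op lists the ops whose outputs it consumes.
# This is NOT XLA's scheduler; it only shows which reorderings the
# data dependencies rule out.

def depends_on(graph, later, earlier):
    """Return True if `later` transitively consumes an output of `earlier`."""
    stack = list(graph[later])
    seen = set()
    while stack:
        op = stack.pop()
        if op == earlier:
            return True
        if op not in seen:
            seen.add(op)
            stack.extend(graph[op])
    return False

# Without tokens: both MPI calls only read `a`, so nothing orders them.
no_tokens = {
    "a": [],
    "mpi_fun": ["a"],
    "mpi_fun2": ["a"],
    "add": ["mpi_fun", "mpi_fun2"],
}

# With tokens: mpi_fun2 consumes the token produced by mpi_fun.
with_tokens = {
    "a": [],
    "create_token": [],
    "mpi_fun": ["a", "create_token"],
    "mpi_fun2": ["a", "mpi_fun"],
    "add": ["mpi_fun", "mpi_fun2"],
}

print(depends_on(no_tokens, "mpi_fun2", "mpi_fun"))    # False
print(depends_on(with_tokens, "mpi_fun2", "mpi_fun"))  # True
```

In the second graph a compiler that respects data dependencies cannot move mpi_fun2 ahead of mpi_fun, which is the property the token chain is meant to provide.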

When we transpose our functions with jax.linear_transpose, we expect the transposed function to also enforce a strong ordering, but in the reverse order, for example:

f_t = jax.linear_transpose(f)
# should match roughly
def f_t(d_t):
  b_t, c_t = d_t, d_t
  tc_t = Zero(AbstractToken)  # automatically there because we don't return tc
  a_t, tb_t = mpi_fun2_transpose(c_t, token=tc_t)
  a_t2, ta_t = mpi_fun_transpose(b_t, token=tb_t)
  _ = create_token_transpose(ta_t)
  return a_t + a_t2

However, when mpi4jax attempts to bind the transposed token Zero(AbstractToken), an error is raised saying that it cannot be bound because XLA does not know how to represent it.

I suspect that the correct way to treat the Zero(AbstractToken) should be exactly like a standard token, such that tc_t in the example above is

An example of how this impacts mpi4jax can be reproduced by installing the branch pv/fix-token, for example by running pip install git+https://github.com/mpi4jax/mpi4jax.git@pv/fix-token, and then using the following MWE:

from mpi4py import MPI

import jax
import jax.numpy as jnp

import numpy as np

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

from mpi4jax import allreduce

arr = jnp.ones((3, 2))
_arr = arr.copy()

def f(x):
    (res,) = jax.linear_transpose(lambda x: allreduce(x, op=MPI.SUM)[0], arr)(x)
    return res

res = jax.jit(f)(arr)

which raises the error:

TypeError: Argument 'Zero(AbstractToken())' of type '<class 'jax._src.ad_util.Zero'>' is not a valid JAX type
Whole stack trace:

~/Dropbox/Ricerca/Codes/Python/mpi4jax pv/fix-token 35s python-3.11.2 ❯ python ex.py
Traceback (most recent call last):
  File "/Users/filippo.vicentini/Dropbox/Ricerca/Codes/Python/mpi4jax/ex.py", line 39, in <module>
    arr = jnp.ones((3, 2))
  File "/Users/filippo.vicentini/Dropbox/Ricerca/Codes/Python/mpi4jax/ex.py", line -1, in test_allreduce_transpose
  File "/Users/filippo.vicentini/Dropbox/Ricerca/Codes/Python/mpi4jax/ex.py", line -1, in <lambda>
  File "/Users/filippo.vicentini/Dropbox/Ricerca/Codes/Python/mpi4jax/mpi4jax/_src/collective_ops/allreduce.py", line -1, in allreduce
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: Argument 'Zero(AbstractToken())' of type '<class 'jax._src.ad_util.Zero'>' is not a valid JAX type

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/filippo.vicentini/Dropbox/Ricerca/Codes/Python/mpi4jax/ex.py", line 22, in <module>
    test_allreduce_transpose()
  File "/Users/filippo.vicentini/Dropbox/Ricerca/Codes/Python/mpi4jax/ex.py", line 19, in test_allreduce_transpose
    (res,) = jax.linear_transpose(lambda x: allreduce(x, op=MPI.SUM)[0], arr)(_arr)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/mpi4jax/python-3.11.2/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/mpi4jax/python-3.11.2/lib/python3.11/site-packages/jax/_src/api.py", line 2270, in transposed_fun
    in_cts = ad.backward_pass(jaxpr, reduce_axes, True, const, dummies, out_cts)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/mpi4jax/python-3.11.2/lib/python3.11/site-packages/jax/_src/interpreters/ad.py", line 253, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Dropbox/Ricerca/Codes/Python/mpi4jax/mpi4jax/_src/collective_ops/allreduce.py", line 209, in mpi_allreduce_transpose_rule
    res, token = mpi_allreduce_p.bind(
                 ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/mpi4jax/python-3.11.2/lib/python3.11/site-packages/jax/_src/core.py", line 380, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/mpi4jax/python-3.11.2/lib/python3.11/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/mpi4jax/python-3.11.2/lib/python3.11/site-packages/jax/_src/core.py", line 790, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/mpi4jax/python-3.11.2/lib/python3.11/site-packages/jax/_src/dispatch.py", line 131, in apply_primitive
    in_avals, in_shardings = util.unzip2([arg_spec(a) for a in args])
                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/mpi4jax/python-3.11.2/lib/python3.11/site-packages/jax/_src/dispatch.py", line 131, in <listcomp>
    in_avals, in_shardings = util.unzip2([arg_spec(a) for a in args])
                                          ^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/mpi4jax/python-3.11.2/lib/python3.11/site-packages/jax/_src/dispatch.py", line 102, in arg_spec
    aval = xla.abstractify(x)
           ^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/mpi4jax/python-3.11.2/lib/python3.11/site-packages/jax/_src/interpreters/xla.py", line 200, in abstractify
    raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Argument 'Zero(AbstractToken())' of type '<class 'jax._src.ad_util.Zero'>' is not a valid JAX type

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/filippo.vicentini/Dropbox/Ricerca/Codes/Python/mpi4jax/ex.py", line 22, in <module>
    test_allreduce_transpose()
  File "/Users/filippo.vicentini/Dropbox/Ricerca/Codes/Python/mpi4jax/ex.py", line 19, in test_allreduce_transpose
    (res,) = jax.linear_transpose(lambda x: allreduce(x, op=MPI.SUM)[0], arr)(_arr)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Dropbox/Ricerca/Codes/Python/mpi4jax/mpi4jax/_src/collective_ops/allreduce.py", line 209, in mpi_allreduce_transpose_rule
    res, token = mpi_allreduce_p.bind(
                 ^^^^^^^^^^^^^^^^^^^^^
TypeError: Argument 'Zero(AbstractToken())' of type '<class 'jax._src.ad_util.Zero'>' is not a valid JAX type

What jax/jaxlib version are you using?

jax 0.4.11 jaxlib 0.4.11

Which accelerator(s) are you using?

CPU

Additional system info

macOS M1

NVIDIA GPU info

No response

PhilipVinc commented 1 year ago

After diving deep into jax's internals, I managed to devise this possible fix: every time you try to construct a Zero(AbstractToken), build a token instead.

So simply editing the read_cotangent function to

  from jax._src.core import (AbstractToken, Token)

  def read_cotangent(v):
    if isinstance(v.aval, AbstractToken):
      return ct_env.pop(v, Token())
    else:
      return ct_env.pop(v, Zero(v.aval))

fixes my reproducer above.

Of course this might not be semantically correct but I'm sure you know more about it...
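
To make the effect of that patch concrete, here is a minimal pure-Python sketch of a cotangent environment. The classes `Zero` and `Token`, the string avals, and the `patch_tokens` flag are toy stand-ins invented for this example; this is not JAX's actual backward pass.

```python
class Zero:
    """Toy stand-in for a symbolic zero cotangent (not JAX's class)."""
    def __init__(self, aval):
        self.aval = aval

class Token:
    """Toy stand-in for a concrete ordering token."""

def read_cotangent(ct_env, var, aval, patch_tokens=False):
    # A variable with no recorded cotangent reads as a symbolic zero,
    # unless the patch is on and the variable is a token.
    if patch_tokens and aval == "Tok":
        return ct_env.pop(var, Token())
    return ct_env.pop(var, Zero(aval))

# Only the array output `d` received a cotangent; the token `tc` never did,
# because the forward function does not return it.
ct_env = {"d": 1.0}

stock = read_cotangent(dict(ct_env), "tc", "Tok")
patched = read_cotangent(dict(ct_env), "tc", "Tok", patch_tokens=True)
known = read_cotangent(dict(ct_env), "d", "f32[1]", patch_tokens=True)

print(type(stock).__name__)    # Zero
print(type(patched).__name__)  # Token
print(known)                   # 1.0
```

The unpatched read hands the transpose rule a symbolic zero for the token, which is the value that later fails to bind; the patched read substitutes a concrete token at that point.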

PhilipVinc commented 1 year ago

@mattjj if the issue is not very clear please let me know, and I can try to clarify it further. This is a big blocker for us.

mattjj commented 1 year ago

Thanks for the ping! I managed not to notice until just now.

> because in general jax.jit functions cannot return them

Can you say more, and/or share a reproducer? If these are JAX tokens then they should be returnable from jitted functions, otherwise that's a JAX bug (even if it's not the main bug you're talking about).

As for the main issue, I understand the general outline, but I need to look at mpi4jax more closely, or alternatively set up a toy model, to understand better. I have two gut reactions:

  1. taking a narrow pigeon-holed view, anywhere you see a symbolic zero Zero(AbstractToken), i.e. in a JVP or transpose rule (not in ad.py's backward_pass), you probably want to instantiate it so that it's no longer symbolic; but in the bigger picture...
  2. I don't think we want to rely on tangents-of-tokens to be token-like at all, since throughout JAX's AD system we assume tangent types are vector-space-like, in particular in that they have zero elements which have the behavior that any linear function applied to them is zero.

In particular I don't think the fix in this comment is on the right track, unfortunately.
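
The two reactions above can be sketched in pure Python. This is a toy model under stated assumptions: `Zero`, `scale_transpose`, and `comm_transpose` are hypothetical names, and `object()` merely stands in for creating a fresh token; none of this is JAX's or mpi4jax's code.

```python
class Zero:
    """Toy symbolic zero; `kind` is "array" or "token" (hypothetical labels)."""
    def __init__(self, kind):
        self.kind = kind

def scale_transpose(ct, factor):
    # A linear map sends zero to zero, so a pure rule may propagate the
    # symbolic zero and skip the computation entirely.
    if isinstance(ct, Zero):
        return Zero("array")
    return factor * ct

calls = []  # records which effectful rules actually ran

def comm_transpose(ct, token_ct):
    # Reaction 1: instantiate the symbolic token inside the rule so the
    # value that gets bound is no longer symbolic.
    if isinstance(token_ct, Zero):
        token_ct = object()  # stand-in for creating a fresh token
    calls.append("comm_transpose")
    return ct, token_ct

skipped = scale_transpose(Zero("array"), 2.0)
res, tok = comm_transpose(1.0, Zero("token"))

print(isinstance(skipped, Zero))  # True
print(isinstance(tok, Zero))      # False
print(calls)                      # ['comm_transpose']
```

The contrast is the point of reaction 2: skipping work on a zero is valid for a vector-space cotangent, whereas a token has no zero element with that meaning, so an effectful rule has to decide for itself what a missing token cotangent means.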

> an example code of how this impact mpi4jax can be had by installing the branch pv/fix-token by for example running pip install git+https://github.com/mpi4jax/mpi4jax.git@pv/fix-token and then using the following MWE:

Where is the token in this example?

PhilipVinc commented 1 year ago

Thanks for answering!

In the example I shared above the token is automatically generated by mpi4jax, but let me share a clearer example. I hope you don't mind installing mpi4jax (unfortunately tokens are used nowhere in jax, so I can't build a reproducer there).

The reproducer is the following:

import jax
import jax.numpy as jnp
import mpi4jax
from mpi4py import MPI

def f(a, b):
    token_a = jax.lax.create_token()
    c, token_b = mpi4jax.allreduce(a, MPI.SUM, token=token_a)
    d, token_c = mpi4jax.allreduce(b, MPI.SUM, token=token_b)
    e = c+d
    return d

x = jnp.ones(1)
y = jnp.ones(1)
r = f(x, y)

# jax.make_jaxpr(f)(x, y)

# jax.make_jaxpr(jax.linear_transpose(f, x, y))(r)
jax.linear_transpose(f, x, y)(r)

Let me comment on what is going on in here by inspecting the jaxpr:

In [4]: jax.make_jaxpr(f)(x,y)
Out[4]:
{ lambda ; a:f32[1] b:f32[1]. let
    c:Tok = create_token
    d:f32[1] e:Tok = allreduce_mpi[
      comm=<mpi4jax._src.utils.HashableMPIType object at 0x13f669b50>
      op=<mpi4jax._src.utils.HashableMPIType object at 0x13f668fd0>
      transpose=False
    ] a c
    f:f32[1] _:Tok = allreduce_mpi[
      comm=<mpi4jax._src.utils.HashableMPIType object at 0x13f509710>
      op=<mpi4jax._src.utils.HashableMPIType object at 0x13f0e03d0>
      transpose=False
    ] b e
    _:f32[1] = add d f
  in (f,) }

You can see that I have two calls to the primitive allreduce_mpi, which is defined in mpi4jax. This primitive takes two inputs: the array to be reduced and a token to prevent reordering.

Now, what would be the correct transposition of this jaxpr? I would assume it is the execution of the operations in reverse order. The transposition rule on mpi4jax master is essentially:

def mpi_allreduce_transpose_rule(tan_args, *x_args, op, comm,):
    _, token = x_args
    x_tan, token_tan = tan_args

    res, token = mpi_allreduce_transpose_p.bind(
        x_tan, token, op=op, comm=comm,
    )
    return res, token_tan

Notice that I bind the primal token instead of the tangent token. Is this correct? It seems not, as this fails with the error

File ~/Documents/pythonenvs/mpi4jax/python-3.11.1/lib/python3.11/site-packages/jax/_src/core.py:1326, in concrete_aval(x)
   1324 if hasattr(x, '__jax_array__'):
   1325   return concrete_aval(x.__jax_array__())
-> 1326 raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
   1327                  "type")

TypeError: Value UndefinedPrimal(AbstractToken()) with type <class 'jax._src.interpreters.ad.UndefinedPrimal'> is not a valid JAX type

Another reason to think that I should bind the tangent of the token here, rather than the primal token, is that I would like the linear transposition to execute in reversed order, which I only get by binding the tangent token to the tangent primitive. Does this make sense?
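
As a sanity check of the reversed-order expectation, here is a minimal pure-Python sketch. Everything in it is a toy assumption: the "MPI calls" are plain scalings, tokens are integers, and `make_call` is a hypothetical helper. It only shows the order and the transpose identity that the transposed function would be expected to satisfy.

```python
log = []  # records the order in which the toy "MPI calls" run

def make_call(name, scale):
    """Build a toy linear op that scales x and threads an integer token."""
    def call(x, token):
        log.append(name)
        return scale * x, token + 1
    return call

mpi_fun = make_call("mpi_fun", 2.0)
mpi_fun2 = make_call("mpi_fun2", 3.0)
# Scaling by a real number is its own transpose, so only the name differs.
mpi_fun_transpose = make_call("mpi_fun_transpose", 2.0)
mpi_fun2_transpose = make_call("mpi_fun2_transpose", 3.0)

def f(a):
    ta = 0
    b, tb = mpi_fun(a, ta)
    c, tc = mpi_fun2(a, tb)
    return b + c

def f_t(d_t):
    b_t, c_t = d_t, d_t  # transpose of d = b + c
    tc_t = 0  # fresh token standing in for the cotangent of the unused tc
    a_t, tb_t = mpi_fun2_transpose(c_t, tc_t)
    a_t2, ta_t = mpi_fun_transpose(b_t, tb_t)
    return a_t + a_t2

a, d_t = 1.5, 4.0
lhs = f(a) * d_t    # <f(a), d_t>
rhs = a * f_t(d_t)  # <a, f_t(d_t)>
print(lhs == rhs)
print(log)
```

The token cotangents flow from the last call back to the first, so the transposed calls appear in the log in the opposite order to the forward calls.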

So in the branch mpi4jax@pv/fix-token I tried to modify the transposition rule to read

def mpi_allreduce_transpose_rule(tan_args, *x_args, op, comm):
    _, _ = x_args
    x_tan, token_tan = tan_args

    res, token = mpi_allreduce_transpose_p.bind(
        x_tan, token_tan, op=op, comm=comm,
    )
    return res, token

but this fails as well with the error I shared in the original post, namely

File ~/Documents/pythonenvs/mpi4jax/python-3.11.1/lib/python3.11/site-packages/jax/_src/core.py:1326, in concrete_aval(x)
   1324 if hasattr(x, '__jax_array__'):
   1325   return concrete_aval(x.__jax_array__())
-> 1326 raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
   1327                  "type")

TypeError: Value Zero(AbstractToken()) with type <class 'jax._src.ad_util.Zero'> is not a valid JAX type

PhilipVinc commented 1 year ago

> because in general jax.jit functions cannot return them?
>
> Can you say more, and/or share a reproducer? If these are JAX tokens then they should be returnable from jitted functions, otherwise that's a JAX bug (even if it's not the main bug you're talking about).

Apparently I was not up to date, and it seems that it is now possible to return tokens (I remember about a year ago it was not possible). However it will still error if you try to transpose a token:

import jax
import jax.numpy as jnp
import mpi4jax
from mpi4py import MPI

def f(a):
    token_a = jax.lax.create_token()
    b, token_b = mpi4jax.allreduce(a, MPI.SUM, token=token_a)
    return b, token_b

x = jnp.ones(1)
r,s = f(x)

jax.make_jaxpr(f)(x)

jax.make_jaxpr(jax.linear_transpose(f, x))(r)

that fails with

File ~/Documents/pythonenvs/mpi4jax/python-3.11.1/lib/python3.11/site-packages/jax/_src/dtypes.py:530, in dtype(x, canonicalize)
    528     dt = np.result_type(x)
    529   except TypeError as err:
--> 530     raise TypeError(f"Cannot determine dtype of {x}") from err
    531 if dt not in _jax_dtype_set and not is_opaque_dtype(dt):
    532   raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array "
    533                   "type. Only arrays of numeric types are supported by JAX.")

TypeError: Cannot determine dtype of AbstractToken()

Though that's a different bug from what I originally reported and I'm not so worried about this one because tokens usually remain inside the jitted functions...

mattjj commented 1 year ago

> I hope you don't mind installing mpi4jax

I don't mind at all!

Thanks for the detailed repro and info. I'll take a look...

mattjj commented 1 year ago

I haven't had a chance yet :/ I expect I can in the next 48 hours or so.

PhilipVinc commented 1 year ago

Thanks for the update! Looking forward to a reply.

PhilipVinc commented 1 year ago

@mattjj any luck? Can I do anything to help you nail this problem down?

PhilipVinc commented 1 year ago

@mattjj pretty please 🥹?