jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

new version breaks arange? #5186

Closed · benjaminpope closed this issue 3 years ago

benjaminpope commented 3 years ago

Hi Jax team,

Working with @alipwong on this.

Just upgraded from jax 0.1.75 & jaxlib 0.1.52 where everything was working fine, to jax-0.2.7 and jaxlib-0.1.57.

I have a longer program using jax.np.arange that used to work just fine. Now it breaks and I can't understand what changed.

Rather than the full thing, here is a simple example:

import jax.numpy as np
from jax import jit

def test_arange(n):
    return np.arange(0, n, 1)

test = jit(test_arange)

test(5)

which raises the following error:

---------------------------------------------------------------------------
FilteredStackTrace                        Traceback (most recent call last)
<ipython-input-27-613700c58fec> in <module>

----> 6 test(5)

~/opt/anaconda3/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)

~/opt/anaconda3/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)

<ipython-input-27-613700c58fec> in test_arange(n)
      2 #     n = x.shape[0]
----> 3     return np.arange(0,n,1)
      4 

~/opt/anaconda3/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py in arange(start, stop, step, dtype)

FilteredStackTrace: jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

It arose in jax.numpy.arange argument `stop`.

While tracing the function test_arange at <ipython-input-27-613700c58fec>:1, this concrete value was not available in Python because it depends on the value of the arguments to test_arange at <ipython-input-27-613700c58fec>:1 at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

ConcretizationTypeError                   Traceback (most recent call last)
<ipython-input-27-613700c58fec> in <module>
      6 
      7 # test(np.arange(5))
----> 8 test(5)

~/opt/anaconda3/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)

~/opt/anaconda3/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    369     c = xb.make_computation_builder('xla_computation_{}'.format(fun_name))
    370     xla_consts = map(partial(xb.constant, c), consts)
--> 371     xla_args = xla._xla_callable_args(c, avals, tuple_args)
    372     outs = xla.jaxpr_subcomp(
    373         c, jaxpr, backend, axis_env_, xla_consts,

~/opt/anaconda3/lib/python3.7/site-packages/jax/api.py in cache_miss(*args, **kwargs)
    282     sine.4 = f32[] sine(cosine.3)
    283     ROOT tuple.5 = (f32[]) tuple(sine.4)
--> 284   }
    285   <BLANKLINE>
    286   <BLANKLINE>

~/opt/anaconda3/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1227     else:
   1228       msg, = e.args
-> 1229       jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, 0, 20))
   1230     msg = "\n\n".join([msg, "while checking jaxpr:", jaxpr_str])
   1231     raise JaxprTypeError(msg) from None

~/opt/anaconda3/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1218 
   1219   Raises `TypeError` if `jaxpr` is determined invalid. Returns `None` otherwise.
-> 1220   """
   1221   try:
   1222     _check_jaxpr(jaxpr, [v.aval for v in jaxpr.invars])

~/opt/anaconda3/lib/python3.7/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1230     msg = "\n\n".join([msg, "while checking jaxpr:", jaxpr_str])
   1231     raise JaxprTypeError(msg) from None
-> 1232 
   1233 def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
   1234 

~/opt/anaconda3/lib/python3.7/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    596 
    597   def __hash__(self) -> int:
--> 598     return hash((self.level, self.trace_type))
    599 
    600   def __eq__(self, other: object) -> bool:

~/opt/anaconda3/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    568     else:
    569       assert s.is_tuple()
--> 570       for i, sub in enumerate(s.tuple_shapes()):
    571         subindex = index + (i,)
    572         if sub.is_tuple():

~/opt/anaconda3/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    249     # store 1 was occupied, so let's check store 2 is not occupied
    250     try:
--> 251       out2 = aux2()
    252     except StoreException:
    253       return True, out1

~/opt/anaconda3/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    643   out_nodes = jaxpr_subcomp(
    644       c, jaxpr, backend, AxisEnv(nreps, (), (), None), xla_consts,
--> 645       extend_name_stack(wrap_name(name, 'jit')), *xla_args)
    646   out_tuple = xops.Tuple(c, out_nodes)
    647   backend = xb.get_backend(backend)

~/opt/anaconda3/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals)

~/opt/anaconda3/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)

~/opt/anaconda3/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    158 
    159     return ans
--> 160 
    161   def __repr__(self):
    162     def transform_to_str(x):

<ipython-input-27-613700c58fec> in test_arange(n)
      1 def test_arange(n):
      2 #     n = x.shape[0]
----> 3     return np.arange(0,n,1)
      4 
      5 test = jit(test_arange)

~/opt/anaconda3/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py in arange(start, stop, step, dtype)

~/opt/anaconda3/lib/python3.7/site-packages/jax/core.py in concrete_or_error(force, val, context)
    920   @property
    921   def shape(self):
--> 922     msg = ("UnshapedArray has no shape. Please open an issue at "
    923            "https://github.com/google/jax/issues because it's unexpected for "
    924            "UnshapedArray instances to ever be produced.")

~/opt/anaconda3/lib/python3.7/site-packages/jax/core.py in raise_concretization_error(val, context)
    897       complex, "Try using `x.astype(complex)` instead.")
    898   _hex     = concretization_function_error(hex)
--> 899   _oct     = concretization_function_error(oct)
    900 
    901   def at_least_vspace(self) -> AbstractValue:

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

It arose in jax.numpy.arange argument `stop`.

While tracing the function test_arange at <ipython-input-27-613700c58fec>:1, this concrete value was not available in Python because it depends on the value of the arguments to test_arange at <ipython-input-27-613700c58fec>:1 at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>

Is there a safe way to jit np.arange in the new version of JAX?

jakevdp commented 3 years ago

Yes, you can use static_argnums to jit-compile arange:

test = jit(test_arange, static_argnums=0)
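
For reference, a complete runnable version of that snippet might look like this (a minimal sketch; the imports are assumed from your example):

import jax.numpy as np
from jax import jit

def test_arange(n):
    return np.arange(0, n, 1)

# n is now a static (compile-time) value rather than a traced one
test = jit(test_arange, static_argnums=0)

test(5)  # DeviceArray([0, 1, 2, 3, 4], dtype=int32)
test(7)  # a new static value triggers a recompile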

Does that answer your question?

jakevdp commented 3 years ago

Also, I tried your code snippet with jax 0.1.75 & jaxlib 0.1.52, and it raises the same Concretization error as in the current release, so the issue in your full code might be something different.

jakevdp commented 3 years ago

Here's an example of an arange snippet that would execute in old versions, and leads to a concretization error now:

from jax import jit
import jax.numpy as jnp
import numpy as np

def test_arange():
  n = jnp.add(2, 3)
  return jnp.arange(n)

test = jit(test_arange)

test()

In this case, the issue is due to omnistaging (see #3370). Previously, jnp.add(2, 3) would be computed statically at trace time because the inputs are static; with omnistaging, every jnp operation is staged for compilation, even if the inputs happen to be static.

The solution in a case like this would be to replace jnp.add() with np.add(), as numpy operations will not be staged.
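
Concretely, this variant compiles in both old and new versions, since np.add(2, 3) is evaluated eagerly at trace time and jnp.arange receives a concrete value (a minimal sketch of that fix):

from jax import jit
import jax.numpy as jnp
import numpy as np

def test_arange():
  n = np.add(2, 3)   # plain numpy: computed at trace time, so n is concrete
  return jnp.arange(n)

test = jit(test_arange)
test()  # DeviceArray([0, 1, 2, 3, 4], dtype=int32)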

benjaminpope commented 3 years ago

Hi @jakevdp,

Thanks for getting to this so soon.

An example of something that will break in the new version is https://github.com/benjaminpope/morphine/blob/stable/notebooks/jit_demo.ipynb

I can't just use static_argnums, because I'm jitting a larger function of which the arange is only a small part - and its argument will actually vary.

And we also need to compute gradients - so I don't imagine plain numpy will be satisfactory here?

Apologies,

Ben

jakevdp commented 3 years ago

I think our comments crossed – check whether you're using any jax.numpy operations on static values and then feeding the results into arange. I suspect that's what's causing your issue.

benjaminpope commented 3 years ago

Here's the function where it breaks - first the un-jitted function, then we jit it with static_argnums=2. It dies at Xs = (1.0*np.arange(npupX) - (npupX) / 2.0 + 0.5) * dX. Would this be better with jax.np.shape or something?

import jax.numpy as np
from jax import jit

def minimal_dft_prim(plane, nlamD, npix):
    """Perform a matrix discrete Fourier transform with selectable
    output sampling and centering.

    Where parameters can be supplied as either scalars or 2-tuples, the first
    element of the 2-tuple is used for the Y dimension and the second for the
    X dimension. This ordering matches that of numpy.ndarray.shape attributes
    and that of Python indexing.

    To achieve exact correspondence to the FFT set nlamD and npix to the size
    of the input array in pixels and use 'FFTSTYLE' centering. (n.b. When
    using `numpy.fft.fft2` you must `numpy.fft.fftshift` the input pupil both
    before and after applying fft2 or else it will introduce a checkerboard
    pattern in the signs of alternating pixels!)

    Parameters
    ----------
    plane : 2D ndarray
        2D array (either real or complex) representing the input image plane or
        pupil plane to transform.
    nlamD : float or 2-tuple of floats (nlamDY, nlamDX)
        Size of desired output region in lambda / D units, assuming that the
        pupil fills the input array (corresponds to 'm' in
        Soummer et al. 2007 4.2). This is in units of the spatial frequency that
        is just Nyquist sampled by the input array. If given as a tuple,
        interpreted as (nlamDY, nlamDX).
    npix : int or 2-tuple of ints (npixY, npixX)
        Number of pixels per side of the destination plane array (corresponds
        to 'N_B' in Soummer et al. 2007 4.2). This will be the # of pixels in
        the image plane for a forward transformation, in the pupil plane for an
        inverse. If given as a tuple, interpreted as (npixY, npixX).
    """

    npupY, npupX = plane.shape # 32, be careful

    npixY, npixX = 1.0*npix, 1.0*npix

    nlamDY, nlamDX = 1.0*nlamD, 1.0*nlamD

    dU = nlamDX / (npixX)
    dV = nlamDY / (npixY)
    dX = 1.0 / (1.0*npupX)
    dY = 1.0 / (1.0*npupY)

    Xs = (1.0*np.arange(npupX) - (npupX) / 2.0 + 0.5) * dX
    Ys = (1.0*np.arange(npupY) - (npupY) / 2.0 + 0.5) * dY

    Us = (1.0*np.arange(npixX) - (npixX) / 2.0 + 0.5) * dU
    Vs = (1.0*np.arange(npixY) - (npixY) / 2.0 + 0.5) * dV

    XU = np.outer(Xs, Us)
    YV = np.outer(Ys, Vs)

    expXU = np.exp(-2.0 * np.pi * 1j * XU)
    expYV = np.exp(-2.0 * np.pi * 1j * YV).T
    t1 = np.dot(expYV, plane)
    t2 = np.dot(t1, expXU)

    norm_coeff = np.sqrt((nlamDY * nlamDX) / (npupY * npupX * npixY * npixX))
    return norm_coeff * t2

minimal_dft = jit(minimal_dft_prim, static_argnums=2)

jakevdp commented 3 years ago

Can you provide a self-contained snippet that reproduces the error you're seeing? I tried executing this function above with code like the following, but was unable to reproduce any error, either in old or new JAX versions:

plane = np.ones((4, 5))
nlamD = 1
npix = 1

minimal_dft_prim(plane, nlamD, npix)
# DeviceArray([[4.472136+0.j]], dtype=complex64)
minimal_dft(plane, nlamD, npix)
# DeviceArray([[4.472136+0.j]], dtype=complex64)
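
That's expected, because plane.shape is a static Python tuple even under jit, so np.arange(npupX) receives a concrete value. A minimal sketch illustrating this (not from the thread):

from jax import jit
import jax.numpy as jnp

@jit
def g(plane):
  n = plane.shape[0]    # shapes are static during tracing: n is a Python int
  return jnp.arange(n)  # fine: concrete stop value

g(jnp.ones((4, 5)))  # DeviceArray([0, 1, 2, 3], dtype=int32)
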
mattjj commented 3 years ago

I'm going to close this issue due to radio silence, but re-open if needed.

From the description, my strong suspicion is that this was due to omnistaging (which happened in the 0.2.0 version bump), and to fix the issue the jnp.arange in question should be replaced with an np.arange (meaning import numpy as np).

benjaminpope commented 3 years ago

Hi @mattjj & @jakevdp,

Sorry for taking so long to reopen this. Thanks for all your help. I've had to do a lot of teaching this semester and haven't had the chance to sit down and figure out what's wrong.

I've taken your advice and converted some jax.numpy calls into regular numpy calls - though I'm concerned that when I finally get it working it won't be differentiable anymore? How does this work?

Regardless, pushing forward - here is an example notebook where it breaks and I can't decide how to progress:

https://github.com/benjaminpope/morphine/blob/stable/notebooks/jit_demo_broken.ipynb

I get the error

FilteredStackTrace: IndexError: Array boolean indices must be concrete.

when it tries to do

    200     r = np.sqrt( (xarray-xcenter)**2 + (yarray-ycenter)**2)
--> 201     array = index_update(array,r < radius,fillvalue)
    202 

The issue is that I make an array with a binary mask - this is an essential bit of the code, and it has to be differentiable and ideally jit-able, as it's how we define the shape of telescope components.

But from reading the piece on omnistaging, I gather it doesn't like boolean masks - I just don't know what the preferred approach is.

jakevdp commented 3 years ago

This error doesn't have to do with omnistaging – it has to do with dynamic array shapes not being allowed within JIT. But often with boolean index issues, the operation can be re-expressed in a JIT-compatible way. For example here, rather than

array = index_update(array,r < radius,fillvalue)

you can use the three-term where function:

array = jnp.where(r < radius, fillvalue, array)

So long as fillvalue is a scalar, the latter does effectively the same thing, but in a way that is JIT-compatible.
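
For instance, here is a self-contained sketch of that pattern (the array names are hypothetical):

import jax.numpy as jnp
from jax import jit

@jit
def fill_circle(array, r, radius, fillvalue):
  # JIT-compatible: the output shape equals the input shape regardless of the mask
  return jnp.where(r < radius, fillvalue, array)

x = jnp.linspace(-1, 1, 5)
r = jnp.abs(x)
fill_circle(jnp.zeros(5), r, 0.5, 1.0)
# DeviceArray([0., 0., 1., 0., 0.], dtype=float32)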

jakevdp commented 3 years ago

#6186 improves this error message & adds a link to suggestions for how to proceed.

benjaminpope commented 3 years ago

Hi @jakevdp,

Super helpful, thanks. Got the boolean masks working ok now.

Hit another hurdle - jitting through control-flow logic.

https://github.com/benjaminpope/morphine/blob/stable/notebooks/jit_demo_broken.ipynb

It is choking on statements like

if np.all((x==0)): return x

and

if np.abs(pixscale -1.0) > 0.01:
    import warnings
    warnings.warn('filled_circle_aa may not yield exact results for grey pixels when pixel scale <1')

and it didn't seem to have too much trouble with this before. How do we incorporate logic in jitted functions?

benjaminpope commented 3 years ago

PS: another thing - as you noted above, it's ok when you run

plane = np.ones((4, 5))
nlamD = 1
npix = 1

minimal_dft_prim(plane, nlamD, npix)
# DeviceArray([[4.472136+0.j]], dtype=complex64)
minimal_dft(plane, nlamD, npix)
# DeviceArray([[4.472136+0.j]], dtype=complex64)

but if npix is a DeviceArray it breaks:

plane = np.ones((4, 5))
nlamD = 1
npix = np.array(1)

minimal_dft_prim(plane, nlamD, npix)
# DeviceArray([[4.472136+0.j]], dtype=complex64)
minimal_dft(plane, nlamD, npix)
---------------------------------------------------------------------------
FilteredStackTrace                        Traceback (most recent call last)
<ipython-input-15-f939324c9ede> in <module>
      6 # DeviceArray([[4.472136+0.j]], dtype=complex64)
----> 7 test_jit = minimal_dft(plane, nlamD, npix)
      8 # DeviceArray([[4.472136+0.j]], dtype=complex64)

FilteredStackTrace: ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jax.interpreters.xla._DeviceArray'> for function minimal_dft_prim is non-hashable.

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
<ipython-input-15-f939324c9ede> in <module>
      5 test_nojit = minimal_dft_prim(plane, nlamD, npix)
      6 # DeviceArray([[4.472136+0.j]], dtype=complex64)
----> 7 test_jit = minimal_dft(plane, nlamD, npix)
      8 # DeviceArray([[4.472136+0.j]], dtype=complex64)

~/opt/anaconda3/lib/python3.8/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    137   def reraise_with_filtered_traceback(*args, **kwargs):
    138     try:
--> 139       return fun(*args, **kwargs)
    140     except Exception as e:
    141       if not is_under_reraiser(e):

~/opt/anaconda3/lib/python3.8/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    396       return cpp_jitted_f(*args, **kwargs)
    397     else:
--> 398       return cpp_jitted_f(context, *args, **kwargs)
    399   f_jitted._cpp_jitted_f = cpp_jitted_f
    400 

~/opt/anaconda3/lib/python3.8/site-packages/jax/api.py in cache_miss(_, *args, **kwargs)
    275     f = lu.wrap_init(fun)
    276     if static_argnums:
--> 277       f, dyn_args = argnums_partial_except(f, static_argnums, args)
    278     else:
    279       dyn_args = args

~/opt/anaconda3/lib/python3.8/site-packages/jax/api_util.py in argnums_partial_except(f, static_argnums, args)
     98       hash(static_arg)
     99     except TypeError:
--> 100       raise ValueError(
    101           "Non-hashable static arguments are not supported, as this can lead "
    102           f"to unexpected cache-misses. Static argument (index {i}) of type "

ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jax.interpreters.xla._DeviceArray'> for function minimal_dft_prim is non-hashable.

jakevdp commented 3 years ago

The reason this breaks for a device array is because you have declared that argument to be static, and static arguments must be hashable. Neither numpy arrays nor JAX arrays are hashable, so they cannot be passed as static arguments.

For example:

from jax import jit
import jax.numpy as jnp
import numpy as np

f = jit(lambda x: x, static_argnums=0)

f(np.array(1))
# ValueError: Non-hashable static arguments are not supported.

f(jnp.array(1))
# ValueError: Non-hashable static arguments are not supported.
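
If you truly want the value to be static, the workaround is to pass a hashable Python scalar instead (a sketch, reusing f from above):

f(int(np.array(1)))   # OK: a Python int is hashable
f(int(jnp.array(1)))  # likewise
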
benjaminpope commented 3 years ago

I think what I don’t understand is why this used to work? Has this changed?


jakevdp commented 3 years ago

I think what I don’t understand is why this used to work? Has this changed?

I don't think so... I think arguments marked as static have always been required to be hashable. Can you show a minimal example of something like this that used to work and now doesn't?

hawkinsp commented 3 years ago

I think this changed in https://github.com/google/jax/pull/4717. See the description of that PR for details as to why we changed this.

jakevdp commented 3 years ago

Thanks Peter, I forgot about that change.

@benjaminpope - about your other question above: In JIT-compiled functions, Python control flow can only depend on static quantities such as explicitly static variables or array shapes and sizes. The reason your examples fail under jit is because the control flow depends on the values within a traced array.

Control flow dependent on traced values has never been allowed in JAX, but what changed in version 0.2.0 is that all JAX operations are staged (the "omni" in omnistaging), where previously operations were staged only if the inputs were traced. This means that a function like this would compile successfully in v0.1.x, but not in v0.2.x:

import numpy as np
import jax.numpy as jnp
from jax import jit

@jit
def f():
  x = np.arange(10)
  if jnp.any(x == 0):
    return

This is because before omnistaging jnp.any would not be staged as it is acting on a static array, and would instead be computed statically at compile-time. After omnistaging, jnp.any would be staged, because every JAX operation is staged.

If you want an operation like this to happen at compile-time, the recommended approach is to use a numpy operation rather than a jax.numpy operation. Thus this version will compile successfully in both 0.1.x and 0.2.x:

@jit
def f():
  x = np.arange(10)
  if np.any(x == 0):
    return

(note using np.any rather than jnp.any).
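
One more option, for completeness: when a branch genuinely must depend on a traced value, JAX's structured control-flow primitives such as jax.lax.cond can express it inside jit. This is a minimal sketch, not something discussed above:

from jax import jit, lax
import jax.numpy as jnp

@jit
def f(x):
  # Both branches are compiled; the selection happens at runtime on the device.
  return lax.cond(jnp.all(x == 0), lambda x: x, lambda x: x + 1, x)

f(jnp.zeros(3))  # DeviceArray([0., 0., 0.], dtype=float32)
f(jnp.ones(3))   # DeviceArray([2., 2., 2.], dtype=float32)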

Tracing, static variables, and a mental model of JAX JIT compilation are all summarized in the How to think in JAX section of the documentation. Let me know if you have any suggestions!

benjaminpope commented 3 years ago

Right, ok, I really appreciate the explanation.

I'm hoping to use both jit compilation and gradients - everything worked fine on the 0.1.x. The goal of this project is to fork the existing optical simulation package poppy and keep the same API and range of functionality, while providing autodiff and jit.

I didn't realize we were taking a hit on the jit compilation by staging unhashable things, but it certainly ran roughly 4x faster than the existing state of the art using numexpr & numba in poppy. (https://github.com/benjaminpope/morphine/blob/a6acf3cf18a630465b351d8ca678de0c2b54adc0/notebooks/jit_demo.ipynb)

Control flow involving if statements and operations involving binary masks are pretty unavoidable in building optical simulations. Before I bite the bullet and rewrite everything, will it still be differentiable and jittable with np rather than jnp?

jakevdp commented 3 years ago

will it still be differentiable and jittable with np rather than jnp?

No, only traced operations are jittable and differentiable, and operations performed with np rather than jnp will not be traced by design. In 0.1.x, some jnp operations were not traced. In 0.2.x, all jnp operations are traced.
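
A quick way to see the difference is to inspect the jaxpr (for illustration):

import numpy as np
import jax.numpy as jnp
from jax import make_jaxpr

make_jaxpr(lambda: jnp.add(2, 3))()  # contains a staged `add` equation
make_jaxpr(lambda: np.add(2, 3))()   # just a constant: computed at trace time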

I'm a bit confused by your questions, though, because value-dependent if statements and boolean masking operations have never been traceable, and so have never worked under JIT or been compatible with autodiff, even in version 0.1.x. I wonder if in 0.1.x all of these operations in your code were simply being statically computed?

benjaminpope commented 3 years ago

I suppose they must have been. In practice many of these if-statements catch exceptions, and will all go one way in the normal course of (say) optimizing an optical system. The derivatives computed with this code were exactly as expected analytically and (where not possible analytically) heuristically, and optimizations using these gradients converge to excellent solutions. So I think in practice it worked just fine... I just don't know what to do to make things Jax compatible in 0.2.x.