Closed benjaminpope closed 3 years ago
Yes, you can use static_argnums to jit-compile arange:
test = jit(test_arange, static_argnums=0)
Does that answer your question?
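A minimal sketch of the static_argnums approach. The original snippet is not shown in the thread, so the function below (named make_arange here, standing in for test_arange) and its single-argument signature are assumptions:

```python
from jax import jit
import jax.numpy as jnp

# Hypothetical stand-in for the function from the original report.
def make_arange(n):
    return jnp.arange(n)

# Marking argument 0 as static makes n a concrete Python value at trace time,
# so jnp.arange knows its output shape.
make_arange_jit = jit(make_arange, static_argnums=0)

result = make_arange_jit(5)  # each distinct n triggers its own compilation
```

The trade-off is that every new value of n causes a recompile, so this suits sizes that change rarely.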
Also, I tried your code snippet with jax 0.1.75 & jaxlib 0.1.52, and it raises the same Concretization error as in the current release, so the issue in your full code might be something different.
Here's an example of an arange snippet that would execute in old versions, and leads to a concretization error now:
from jax import jit
import jax.numpy as jnp
import numpy as np
def test_arange():
    n = jnp.add(2, 3)
    return jnp.arange(n)
test = jit(test_arange)
test()
In this case, the issue is due to omnistaging (see #3370). Previously, jnp.add(2, 3) would be computed statically at trace time because the inputs are static; with omnistaging, every jnp operation is staged for compilation, even if the inputs happen to be static.
The solution in a case like this would be to replace jnp.add() with np.add(), as numpy operations will not be staged.
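As a sketch of that fix applied to the snippet above (the function name arange_static is made up for illustration):

```python
import numpy as np
import jax.numpy as jnp
from jax import jit

@jit
def arange_static():
    # Plain NumPy runs immediately at trace time, so n stays a concrete value.
    n = np.add(2, 3)
    # The length is known while tracing, so arange can fix its output shape.
    return jnp.arange(n)

result = arange_static()
```

Only the size computation moves to NumPy; the array itself is still created with jnp.arange.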
Hi @jakevdp,
Thanks for getting to this so soon.
An example of something that will break in the new version is https://github.com/benjaminpope/morphine/blob/stable/notebooks/jit_demo.ipynb
I can't just do static_argnums because I'm jitting a larger thing of which the arange is a small part. It actually will vary.
And we also need to do gradients - so I don't imagine original numpy will be satisfactory here?
Apologies,
Ben
I think our comments crossed – try to see if you're using any jax.numpy operations on static values, and then feeding the results into arange. I suspect that's what's causing your issue.
Here's the function where it breaks - first the non-jit function, then we jit it with static_argnums = 2. It dies at Xs = (1.0*np.arange(npupX) - (npupX) / 2.0 + 0.5) * dX. Would this be better with jax.np.shape or something?
def minimal_dft_prim(plane, nlamD, npix):
    """Perform a matrix discrete Fourier transform with selectable
    output sampling and centering.
    Where parameters can be supplied as either scalars or 2-tuples, the first
    element of the 2-tuple is used for the Y dimension and the second for the
    X dimension. This ordering matches that of numpy.ndarray.shape attributes
    and that of Python indexing.
    To achieve exact correspondence to the FFT set nlamD and npix to the size
    of the input array in pixels and use 'FFTSTYLE' centering. (n.b. When
    using `numpy.fft.fft2` you must `numpy.fft.fftshift` the input pupil both
    before and after applying fft2 or else it will introduce a checkerboard
    pattern in the signs of alternating pixels!)
    Parameters
    ----------
    plane : 2D ndarray
        2D array (either real or complex) representing the input image plane or
        pupil plane to transform.
    nlamD : float or 2-tuple of floats (nlamDY, nlamDX)
        Size of desired output region in lambda / D units, assuming that the
        pupil fills the input array (corresponds to 'm' in
        Soummer et al. 2007 4.2). This is in units of the spatial frequency that
        is just Nyquist sampled by the input array. If given as a tuple,
        interpreted as (nlamDY, nlamDX).
    npix : int or 2-tuple of ints (npixY, npixX)
        Number of pixels per side of destination plane array (corresponds
        to 'N_B' in Soummer et al. 2007 4.2). This will be the # of pixels in
        the image plane for a forward transformation, in the pupil plane for an
        inverse. If given as a tuple, interpreted as (npixY, npixX).
    """
    npupY, npupX = plane.shape # 32, be careful
    npixY, npixX = 1.0*npix, 1.0*npix
    nlamDY, nlamDX = 1.0*nlamD, 1.0*nlamD
    dU = nlamDX / (npixX)
    dV = nlamDY / (npixY)
    dX = 1.0 / (1.0*npupX)
    dY = 1.0 / (1.0*npupY)
    Xs = (1.0*np.arange(npupX) - (npupX) / 2.0 + 0.5) * dX
    Ys = (1.0*np.arange(npupY) - (npupY) / 2.0 + 0.5) * dY
    Us = (1.0*np.arange(npixX) - (npixX) / 2.0 + 0.5) * dU
    Vs = (1.0*np.arange(npixY) - (npixY) / 2.0 + 0.5) * dV
    XU = np.outer(Xs, Us)
    YV = np.outer(Ys, Vs)
    expXU = np.exp(-2.0 * np.pi * 1j * XU)
    expYV = np.exp(-2.0 * np.pi * 1j * YV).T
    t1 = np.dot(expYV, plane)
    t2 = np.dot(t1, expXU)
    norm_coeff = np.sqrt((nlamDY * nlamDX) / (npupY * npupX * npixY * npixX))
    return norm_coeff * t2
minimal_dft = jit(minimal_dft_prim, static_argnums=2)
Can you provide a self-contained snippet that reproduces the error you're seeing? I tried executing this function above with code like the following, but was unable to reproduce any error, either in old or new JAX versions:
plane = np.ones((4, 5))
nlamD = 1
npix = 1
minimal_dft_prim(plane, nlamD, npix)
# DeviceArray([[4.472136+0.j]], dtype=complex64)
minimal_dft(plane, nlamD, npix)
# DeviceArray([[4.472136+0.j]], dtype=complex64)
I'm going to close this issue due to radio silence, but re-open if needed.
From the description, my strong suspicion is that this was due to omnistaging (which happened in the 0.2.0 version bump), and to fix the issue the jnp.arange in question should be replaced with an np.arange (meaning import numpy as np).
Hi @mattjj & @jakevdp,
Sorry for taking so long to reopen this. Thanks for all your help. I've had to do a lot of teaching this semester and haven't had the chance to sit down and figure out what's wrong.
I've taken your advice and made some jax.numpy calls into regular numpy calls - though I'm concerned that when I finally get it working it won't be differentiable anymore? How does this work?
Regardless, pushing forward - here is an example notebook where it breaks and I can't decide how to progress:
https://github.com/benjaminpope/morphine/blob/stable/notebooks/jit_demo_broken.ipynb
I get the error
FilteredStackTrace: IndexError: Array boolean indices must be concrete.
when it tries to do
200 r = np.sqrt( (xarray-xcenter)**2 + (yarray-ycenter)**2)
--> 201 array = index_update(array,r < radius,fillvalue)
202
The issue is that I make an array with a binary mask - this is an essential bit of the code, and it has to be differentiable and ideally jit-able, as it's how we define the shape of telescope components.
From reading the piece on omnistaging, I think it doesn't like boolean masks, but I don't know what the preferred approach is.
This error doesn't have to do with omnistaging – it has to do with dynamic array shapes not being allowed within JIT. But often with boolean index issues, the operation can be re-expressed in a JIT-compatible way. For example here, rather than
array = index_update(array,r < radius,fillvalue)
you can use the three-term where function:
array = jnp.where(r < radius, fillvalue, array)
So long as fillvalue is a scalar, the latter does effectively the same thing, but in a way that is JIT-compatible.
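To illustrate the three-term where pattern in a self-contained form, here is a sketch of a jitted circular mask. The function name fill_circle and the centred pixel grid are assumptions made for the example, not the code from the notebook:

```python
import jax
import jax.numpy as jnp

@jax.jit
def fill_circle(array, radius, fillvalue):
    n_y, n_x = array.shape  # shapes are static under jit
    # Pixel coordinates measured from the centre of the array.
    ys = jnp.arange(n_y)[:, None] - (n_y - 1) / 2.0
    xs = jnp.arange(n_x)[None, :] - (n_x - 1) / 2.0
    r = jnp.sqrt(xs ** 2 + ys ** 2)
    # The output has the same shape as the input, so this traces cleanly.
    return jnp.where(r < radius, fillvalue, array)

array = jnp.zeros((5, 5))
filled = fill_circle(array, 1.5, 1.0)

# Gradients flow to fillvalue (and to array), one unit per masked pixel.
grad_fill = jax.grad(lambda f: fill_circle(array, 1.5, f).sum())(1.0)
```

Note that the comparison r < radius is piecewise constant, so no gradient reaches radius through a hard mask like this one.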
Hi @jakevdp,
Super helpful, thanks. Got the boolean masks working ok now.
Hit another hurdle - jit through logic.
https://github.com/benjaminpope/morphine/blob/stable/notebooks/jit_demo_broken.ipynb
It is choking on statements like
if np.all((x==0)): return x
and
if np.abs(pixscale - 1.0) > 0.01:
    import warnings
    warnings.warn('filled_circle_aa may not yield exact results for grey pixels when pixel scale <1')
and it didn't seem to have too much trouble with this before. How do we incorporate logic in jitted functions?
PS another thing - as you noted above it's ok when you run
plane = np.ones((4, 5))
nlamD = 1
npix = 1
minimal_dft_prim(plane, nlamD, npix)
# DeviceArray([[4.472136+0.j]], dtype=complex64)
minimal_dft(plane, nlamD, npix)
# DeviceArray([[4.472136+0.j]], dtype=complex64)
but if npix is a DeviceArray it breaks:
plane = np.ones((4, 5))
nlamD = 1
npix = np.array(1)
minimal_dft_prim(plane, nlamD, npix)
# DeviceArray([[4.472136+0.j]], dtype=complex64)
minimal_dft(plane, nlamD, npix)
---------------------------------------------------------------------------
FilteredStackTrace Traceback (most recent call last)
<ipython-input-15-f939324c9ede> in <module>
6 # DeviceArray([[4.472136+0.j]], dtype=complex64)
----> 7 test_jit = minimal_dft(plane, nlamD, npix)
8 # DeviceArray([[4.472136+0.j]], dtype=complex64)
FilteredStackTrace: ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jax.interpreters.xla._DeviceArray'> for function minimal_dft_prim is non-hashable.
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
<ipython-input-15-f939324c9ede> in <module>
5 test_nojit = minimal_dft_prim(plane, nlamD, npix)
6 # DeviceArray([[4.472136+0.j]], dtype=complex64)
----> 7 test_jit = minimal_dft(plane, nlamD, npix)
8 # DeviceArray([[4.472136+0.j]], dtype=complex64)
~/opt/anaconda3/lib/python3.8/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
137 def reraise_with_filtered_traceback(*args, **kwargs):
138 try:
--> 139 return fun(*args, **kwargs)
140 except Exception as e:
141 if not is_under_reraiser(e):
~/opt/anaconda3/lib/python3.8/site-packages/jax/api.py in f_jitted(*args, **kwargs)
396 return cpp_jitted_f(*args, **kwargs)
397 else:
--> 398 return cpp_jitted_f(context, *args, **kwargs)
399 f_jitted._cpp_jitted_f = cpp_jitted_f
400
~/opt/anaconda3/lib/python3.8/site-packages/jax/api.py in cache_miss(_, *args, **kwargs)
275 f = lu.wrap_init(fun)
276 if static_argnums:
--> 277 f, dyn_args = argnums_partial_except(f, static_argnums, args)
278 else:
279 dyn_args = args
~/opt/anaconda3/lib/python3.8/site-packages/jax/api_util.py in argnums_partial_except(f, static_argnums, args)
98 hash(static_arg)
99 except TypeError:
--> 100 raise ValueError(
101 "Non-hashable static arguments are not supported, as this can lead "
102 f"to unexpected cache-misses. Static argument (index {i}) of type "
ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jax.interpreters.xla._DeviceArray'> for function minimal_dft_prim is non-hashable.
The reason this breaks for a device array is because you have declared that argument to be static, and static arguments must be hashable. Neither numpy arrays nor JAX arrays are hashable, so they cannot be passed as static arguments.
For example:
from jax import jit
import jax.numpy as jnp
import numpy as np
f = jit(lambda x: x, static_argnums=0)
f(np.array(1))
# ValueError: Non-hashable static arguments are not supported.
f(jnp.array(1))
# ValueError: Non-hashable static arguments are not supported.
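One way around this, sketched below, is to convert the array-valued size to a plain Python int before the call, since ints are hashable. The function make_ramp and its arguments are hypothetical and only illustrate the pattern:

```python
import numpy as np
import jax.numpy as jnp
from jax import jit

# Hypothetical function whose output length depends on npix.
def make_ramp(scale, npix):
    return scale * jnp.arange(npix)

make_ramp_jit = jit(make_ramp, static_argnums=1)

npix_array = np.array(4)  # zero-dimensional array: not hashable
# int() yields a hashable Python scalar that is valid as a static argument.
out = make_ramp_jit(2.0, int(npix_array))
```

This only works when the value is concrete at call time; a traced value cannot be converted with int().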
I think what I don’t understand is why this used to work? Has this changed?
I think what I don’t understand is why this used to work? Has this changed?
I don't think so... I think arguments marked as static have always been required to be hashable. Can you show a minimal example of something like this that used to work and now doesn't?
I think this changed in https://github.com/google/jax/pull/4717 . See the description of that PR for details as to why we changed this.
Thanks Peter, I forgot about that change.
@benjaminpope - about your other question above: In JIT-compiled functions, Python control flow can only depend on static quantities such as explicitly static variables or array shapes and sizes. The reason your examples fail under jit is because the control flow depends on the values within a traced array.
Control flow dependent on traced values has never been allowed in JAX, but what changed in version 0.2.0 is that all JAX operations are staged (the "omni" in omnistaging), where previously operations were staged only if the inputs were traced. This means that a function like this would compile successfully in v0.1.x, but not in v0.2.x:
import numpy as np
import jax.numpy as jnp
from jax import jit
@jit
def f():
    x = np.arange(10)
    if jnp.any(x == 0):
        return
This is because before omnistaging jnp.any would not be staged as it is acting on a static array, and would instead be computed statically at compile-time. After omnistaging, jnp.any would be staged, because every JAX operation is staged.
If you want an operation like this to happen at compile-time, the recommended approach is to use a numpy operation rather than a jax.numpy operation. Thus this version will compile successfully in both 0.1.x and 0.2.x:
@jit
def f():
    x = np.arange(10)
    if np.any(x == 0):
        return
(note using np.any rather than jnp.any).
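When a branch really does depend on the values in a traced array, one option not covered above is jax.lax.cond, which stages both branches and picks one at run time. The sketch below is modelled loosely on the `if np.all(x == 0): return x` pattern from the question; the function normalize is invented for illustration:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def normalize(x):
    total = jnp.sum(jnp.abs(x))
    # A Python `if total == 0` would fail here because total is traced.
    # lax.cond takes a traced scalar predicate; both branches must return
    # outputs with the same shape and dtype.
    return lax.cond(
        total == 0,
        lambda v, t: v,       # all zeros: return the input unchanged
        lambda v, t: v / t,   # otherwise: scale so the absolute values sum to 1
        x,
        total,
    )

zeros_out = normalize(jnp.zeros(3))
scaled_out = normalize(jnp.array([1.0, 3.0]))
```

Side effects such as warnings.warn do not fit this pattern, because Python code in a jitted function runs only while tracing.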
Tracing, static variables, and a mental model of JAX JIT compilation are all summarized in the How to think in JAX section of the documentation. Let me know if you have any suggestions!
Right, ok, I really appreciate the explanation.
I'm hoping to use both jit compilation and gradients - everything worked fine on 0.1.x. The goal of this project is to fork the existing optical simulation package poppy and keep the same API and range of functionality, while providing autodiff and jit.
I didn't realize we were taking a hit on the jit compilation bits with staging unhashable things, but it certainly ran about 4x faster than the existing state of the art using numexpr & numba in poppy. (https://github.com/benjaminpope/morphine/blob/a6acf3cf18a630465b351d8ca678de0c2b54adc0/notebooks/jit_demo.ipynb)
Control flow involving if statements and operations involving binary masks are pretty unavoidable in building optical simulations. Before I bite the bullet and rewrite everything, will it still be differentiable and jittable with np rather than jnp?
will it still be differentiable and jittable with np rather than jnp?
No, only traced operations are jittable and differentiable, and operations performed with np rather than jnp will not be traced by design. In 0.1.x, some jnp operations were not traced. In 0.2.x, all jnp operations are traced.
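In practice the two can be mixed: quantities that depend only on shapes can be built with np and act as constants, while anything that depends on the inputs goes through jnp and stays differentiable. A minimal sketch, assuming a made-up function weighted_energy whose coordinate grid mirrors the Xs construction in minimal_dft_prim:

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def weighted_energy(params):
    n = params.shape[0]                        # shapes are static under jit
    # NumPy: evaluated once at trace time and baked in as a constant.
    grid = (np.arange(n) - n / 2.0 + 0.5) / n
    # jnp: traced, so gradients flow back to params.
    return jnp.sum(params ** 2 * grid)

params = jnp.ones(4)
grads = jax.grad(weighted_energy)(params)      # equals 2 * params * grid
```

The gradient with respect to params is unaffected by the grid being a NumPy array; what is lost is only the ability to differentiate with respect to the grid itself.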
I'm a bit confused by your questions, though, because value-dependent if statements and boolean masking operations have never been traceable, and have therefore never been affected by JIT or compatible with autodiff, even in version 0.1.x. I wonder if in 0.1.x all of these operations in your code were simply being statically computed?
I suppose they must have been. In practice many of these if-statements catch exceptions, and will all go one way in the normal course of (say) optimizing an optical system. The derivatives computed with this code were exactly as expected analytically and (where not possible analytically) heuristically, and optimizations using these gradients converge to excellent solutions. So I think in practice it worked just fine... I just don't know what to do to make things Jax compatible in 0.2.x.
Hi Jax team,
Working with @alipwong on this.
Just upgraded from jax 0.1.75 & jaxlib 0.1.52 where everything was working fine, to jax-0.2.7 and jaxlib-0.1.57.
I have a longer program using jax.np.arange that used to work just fine. Now it breaks and I can't understand what changed.
Rather than the full thing, here is an example simple code
returns errors
Is there a safe way to jit np.arange in the new version of Jax?