Closed AdrienCorenflos closed 3 years ago
Thanks for raising this!
What's the nature of the regression? An error, or a performance regression, or something else?
The error appears to be:
FAILED tests/lax_control_flow_test.py::LaxControlFlowTest::testAssociativeScanFailing_2 - TypeError: Slice size at index 0 in gather op is out of range, must be within [0, 1), got 1.
with traceback:
Traceback (most recent call last):
File "/Users/phawkins/.pyenv/versions/py3.9.0/lib/python3.9/site-packages/absl/testing/parameterized.py", line 282, in bound_param_test
return test_method(self, **testcase_params)
File "/Users/phawkins/t/issue5164/tests/lax_control_flow_test.py", line 2488, in testAssociativeScanFailing
_ = lax.associative_scan(fn, elems=(ms, vs))
File "/Users/phawkins/p/jax/jax/_src/lax/control_flow.py", line 2492, in associative_scan
scans = _scan(elems_flat)
File "/Users/phawkins/p/jax/jax/_src/lax/control_flow.py", line 2476, in _scan
even_elems = combine(
File "/Users/phawkins/p/jax/jax/_src/lax/control_flow.py", line 2431, in combine
c = fn(a, b)
File "/Users/phawkins/p/jax/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/phawkins/p/jax/jax/api.py", line 1197, in batched_fun
out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
File "/Users/phawkins/p/jax/jax/interpreters/batching.py", line 35, in batch
return batched_fun.call_wrapped(*in_vals)
File "/Users/phawkins/p/jax/jax/linear_util.py", line 160, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/phawkins/t/issue5164/tests/lax_control_flow_test.py", line 2486, in fn
return m1 + m2, jsp.linalg.solve(m1, v2) + jsp.linalg.solve(m2, v1)
File "/Users/phawkins/p/jax/jax/_src/scipy/linalg.py", line 193, in solve
return _solve(a, b, sym_pos, lower)
File "/Users/phawkins/p/jax/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/phawkins/p/jax/jax/api.py", line 382, in f_jitted
return cpp_jitted_f(*args, **kwargs)
File "/Users/phawkins/p/jax/jax/api.py", line 278, in cache_miss
out_flat = xla.xla_call(
File "/Users/phawkins/p/jax/jax/core.py", line 1229, in bind
return call_bind(self, fun, *args, **params)
File "/Users/phawkins/p/jax/jax/core.py", line 1220, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/Users/phawkins/p/jax/jax/core.py", line 1232, in process
return trace.process_call(self, fun, tracers, params)
File "/Users/phawkins/p/jax/jax/interpreters/batching.py", line 163, in process_call
vals_out = call_primitive.bind(f, *vals, **params)
File "/Users/phawkins/p/jax/jax/core.py", line 1229, in bind
return call_bind(self, fun, *args, **params)
File "/Users/phawkins/p/jax/jax/core.py", line 1220, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/Users/phawkins/p/jax/jax/core.py", line 1232, in process
return trace.process_call(self, fun, tracers, params)
File "/Users/phawkins/p/jax/jax/core.py", line 598, in process_call
return primitive.impl(f, *tracers, **params)
File "/Users/phawkins/p/jax/jax/interpreters/xla.py", line 569, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/Users/phawkins/p/jax/jax/linear_util.py", line 251, in memoized_fun
ans = call(fun, *args)
File "/Users/phawkins/p/jax/jax/interpreters/xla.py", line 645, in _xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
File "/Users/phawkins/p/jax/jax/interpreters/partial_eval.py", line 1230, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/Users/phawkins/p/jax/jax/interpreters/partial_eval.py", line 1211, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/Users/phawkins/p/jax/jax/linear_util.py", line 160, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/phawkins/p/jax/jax/_src/scipy/linalg.py", line 168, in _solve
return np_linalg.solve(a, b)
File "/Users/phawkins/p/jax/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/phawkins/p/jax/jax/api.py", line 382, in f_jitted
return cpp_jitted_f(*args, **kwargs)
File "/Users/phawkins/p/jax/jax/api.py", line 278, in cache_miss
out_flat = xla.xla_call(
File "/Users/phawkins/p/jax/jax/core.py", line 1229, in bind
return call_bind(self, fun, *args, **params)
File "/Users/phawkins/p/jax/jax/core.py", line 1220, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/Users/phawkins/p/jax/jax/core.py", line 1232, in process
return trace.process_call(self, fun, tracers, params)
File "/Users/phawkins/p/jax/jax/interpreters/batching.py", line 163, in process_call
vals_out = call_primitive.bind(f, *vals, **params)
File "/Users/phawkins/p/jax/jax/core.py", line 1229, in bind
return call_bind(self, fun, *args, **params)
File "/Users/phawkins/p/jax/jax/core.py", line 1220, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/Users/phawkins/p/jax/jax/core.py", line 1232, in process
return trace.process_call(self, fun, tracers, params)
File "/Users/phawkins/p/jax/jax/interpreters/partial_eval.py", line 1085, in process_call
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
File "/Users/phawkins/p/jax/jax/interpreters/partial_eval.py", line 1211, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/Users/phawkins/p/jax/jax/linear_util.py", line 160, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/phawkins/p/jax/jax/_src/numpy/linalg.py", line 450, in solve
return lax_linalg._solve(a, b)
File "/Users/phawkins/p/jax/jax/_src/lax/linalg.py", line 263, in _solve
return custom_solve(b)
File "/Users/phawkins/p/jax/jax/_src/lax/control_flow.py", line 2211, in custom_linear_solve
out_flat = linear_solve_p.bind(
File "/Users/phawkins/p/jax/jax/core.py", line 271, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/Users/phawkins/p/jax/jax/interpreters/batching.py", line 149, in process_primitive
val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
File "/Users/phawkins/p/jax/jax/_src/lax/control_flow.py", line 2296, in _linear_solve_batching_rule
solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
File "/Users/phawkins/p/jax/jax/interpreters/batching.py", line 411, in batch_jaxpr
jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
File "/Users/phawkins/p/jax/jax/interpreters/partial_eval.py", line 1201, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/Users/phawkins/p/jax/jax/interpreters/partial_eval.py", line 1211, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/Users/phawkins/p/jax/jax/linear_util.py", line 160, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/phawkins/p/jax/jax/core.py", line 141, in jaxpr_as_fun
return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
File "/Users/phawkins/p/jax/jax/core.py", line 352, in eval_jaxpr
ans = eqn.primitive.bind(*(subfuns + in_vals), **bind_params)
File "/Users/phawkins/p/jax/jax/core.py", line 1229, in bind
return call_bind(self, fun, *args, **params)
File "/Users/phawkins/p/jax/jax/core.py", line 1220, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/Users/phawkins/p/jax/jax/core.py", line 1232, in process
return trace.process_call(self, fun, tracers, params)
File "/Users/phawkins/p/jax/jax/interpreters/batching.py", line 163, in process_call
vals_out = call_primitive.bind(f, *vals, **params)
File "/Users/phawkins/p/jax/jax/core.py", line 1229, in bind
return call_bind(self, fun, *args, **params)
File "/Users/phawkins/p/jax/jax/core.py", line 1220, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/Users/phawkins/p/jax/jax/core.py", line 1232, in process
return trace.process_call(self, fun, tracers, params)
File "/Users/phawkins/p/jax/jax/interpreters/partial_eval.py", line 1085, in process_call
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
File "/Users/phawkins/p/jax/jax/interpreters/partial_eval.py", line 1211, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/Users/phawkins/p/jax/jax/linear_util.py", line 160, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/phawkins/p/jax/jax/core.py", line 352, in eval_jaxpr
ans = eqn.primitive.bind(*(subfuns + in_vals), **bind_params)
File "/Users/phawkins/p/jax/jax/core.py", line 271, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/Users/phawkins/p/jax/jax/interpreters/batching.py", line 149, in process_primitive
val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
File "/Users/phawkins/p/jax/jax/_src/lax/lax.py", line 4324, in _gather_batching_rule
return gather(operand, start_indices, dimension_numbers=dnums,
File "/Users/phawkins/p/jax/jax/_src/lax/lax.py", line 868, in gather
return gather_p.bind(
File "/Users/phawkins/p/jax/jax/core.py", line 271, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/Users/phawkins/p/jax/jax/interpreters/partial_eval.py", line 1073, in process_primitive
out_avals = primitive.abstract_eval(*avals, **params)
File "/Users/phawkins/p/jax/jax/_src/lax/lax.py", line 1989, in standard_abstract_eval
shapes, dtypes = shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)
File "/Users/phawkins/p/jax/jax/_src/lax/lax.py", line 4223, in _gather_shape_rule
raise TypeError(f"Slice size at index {i} in gather op is out of range, "
TypeError: Slice size at index 0 in gather op is out of range, must be within [0, 1), got 1.
I didn't yet have time to try to decode what this means.
Hi,
It's an error. I have some code which implements Kalman filtering using prefix sum operations (see this). It used to work like a charm, but when I came back to it recently I found that it fails somewhat randomly (it seems to be related to batching). I've narrowed it down to the combination of associative_scan and LU permutations, hence the test added in the PR.
Adrien
@hawkinsp thanks. If you disable jit in your env, you'll see that the stack stops at the LU solving bit. I'm not 100% sure whether it's the LU that's to blame or associative_scan, but looking at the file history it seemed more plausible for it to be associative_scan.
@hawkinsp @mattjj by the way, this is why I was using prefix sums:
https://twitter.com/simosarkka/status/1347521322486812675
I also patted myself on the back in the JAX discussions tab (what's the point of doing something if you can't brag about it :smile:): https://github.com/google/jax/discussions/5353
So awesome! Thanks for sharing that. I used to work on Bayesian smoothing methods, and I own both of Prof Sarkka's books (Bayesian Filtering and Smoothing, and his recent one on SDEs)!
Hi,
There has been a regression in JAX code, which I believe to be related to this commit.
I have created PR https://github.com/google/jax/pull/5165 with a reproducing failing test.
I will try and look further into what the problem could be, but if the person who made said commit could also have a look it would be great (@hawkinsp I believe this is you).
Adrien