jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Regression bug in associative_scan #5164

Closed AdrienCorenflos closed 3 years ago

AdrienCorenflos commented 3 years ago

Hi,

There has been a regression in JAX code, which I believe to be related to this commit.

I have created PR https://github.com/google/jax/pull/5165 with a reproducing failing test.

I will try to look further into what the problem could be, but it would be great if the author of that commit could also take a look (@hawkinsp, I believe this is you).

Adrien

mattjj commented 3 years ago

Thanks for raising this!

What's the nature of the regression? An error, or a performance regression, or something else?

hawkinsp commented 3 years ago

The error appears to be:

FAILED tests/lax_control_flow_test.py::LaxControlFlowTest::testAssociativeScanFailing_2 - TypeError: Slice size at index 0 in gather op is out of range, must be within [0, 1), got 1.

with traceback:

Traceback (most recent call last):
  File "/Users/phawkins/.pyenv/versions/py3.9.0/lib/python3.9/site-packages/absl/testing/parameterized.py", line 282, in bound_param_test
    return test_method(self, **testcase_params)
  File "/Users/phawkins/t/issue5164/tests/lax_control_flow_test.py", line 2488, in testAssociativeScanFailing
    _ = lax.associative_scan(fn, elems=(ms, vs))
  File "/Users/phawkins/p/jax/jax/_src/lax/control_flow.py", line 2492, in associative_scan
    scans = _scan(elems_flat)
  File "/Users/phawkins/p/jax/jax/_src/lax/control_flow.py", line 2476, in _scan
    even_elems = combine(
  File "/Users/phawkins/p/jax/jax/_src/lax/control_flow.py", line 2431, in combine
    c = fn(a, b)
  File "/Users/phawkins/p/jax/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/phawkins/p/jax/jax/api.py", line 1197, in batched_fun
    out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
  File "/Users/phawkins/p/jax/jax/interpreters/batching.py", line 35, in batch
    return batched_fun.call_wrapped(*in_vals)
  File "/Users/phawkins/p/jax/jax/linear_util.py", line 160, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/phawkins/t/issue5164/tests/lax_control_flow_test.py", line 2486, in fn
    return m1 + m2, jsp.linalg.solve(m1, v2) + jsp.linalg.solve(m2, v1)
  File "/Users/phawkins/p/jax/jax/_src/scipy/linalg.py", line 193, in solve
    return _solve(a, b, sym_pos, lower)
  File "/Users/phawkins/p/jax/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/phawkins/p/jax/jax/api.py", line 382, in f_jitted
    return cpp_jitted_f(*args, **kwargs)
  File "/Users/phawkins/p/jax/jax/api.py", line 278, in cache_miss
    out_flat = xla.xla_call(
  File "/Users/phawkins/p/jax/jax/core.py", line 1229, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/phawkins/p/jax/jax/core.py", line 1220, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/Users/phawkins/p/jax/jax/core.py", line 1232, in process
    return trace.process_call(self, fun, tracers, params)
  File "/Users/phawkins/p/jax/jax/interpreters/batching.py", line 163, in process_call
    vals_out = call_primitive.bind(f, *vals, **params)
  File "/Users/phawkins/p/jax/jax/core.py", line 1229, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/phawkins/p/jax/jax/core.py", line 1220, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/Users/phawkins/p/jax/jax/core.py", line 1232, in process
    return trace.process_call(self, fun, tracers, params)
  File "/Users/phawkins/p/jax/jax/core.py", line 598, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/Users/phawkins/p/jax/jax/interpreters/xla.py", line 569, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/Users/phawkins/p/jax/jax/linear_util.py", line 251, in memoized_fun
    ans = call(fun, *args)
  File "/Users/phawkins/p/jax/jax/interpreters/xla.py", line 645, in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
  File "/Users/phawkins/p/jax/jax/interpreters/partial_eval.py", line 1230, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/Users/phawkins/p/jax/jax/interpreters/partial_eval.py", line 1211, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/Users/phawkins/p/jax/jax/linear_util.py", line 160, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/phawkins/p/jax/jax/_src/scipy/linalg.py", line 168, in _solve
    return np_linalg.solve(a, b)
  File "/Users/phawkins/p/jax/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/phawkins/p/jax/jax/api.py", line 382, in f_jitted
    return cpp_jitted_f(*args, **kwargs)
  File "/Users/phawkins/p/jax/jax/api.py", line 278, in cache_miss
    out_flat = xla.xla_call(
  File "/Users/phawkins/p/jax/jax/core.py", line 1229, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/phawkins/p/jax/jax/core.py", line 1220, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/Users/phawkins/p/jax/jax/core.py", line 1232, in process
    return trace.process_call(self, fun, tracers, params)
  File "/Users/phawkins/p/jax/jax/interpreters/batching.py", line 163, in process_call
    vals_out = call_primitive.bind(f, *vals, **params)
  File "/Users/phawkins/p/jax/jax/core.py", line 1229, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/phawkins/p/jax/jax/core.py", line 1220, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/Users/phawkins/p/jax/jax/core.py", line 1232, in process
    return trace.process_call(self, fun, tracers, params)
  File "/Users/phawkins/p/jax/jax/interpreters/partial_eval.py", line 1085, in process_call
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
  File "/Users/phawkins/p/jax/jax/interpreters/partial_eval.py", line 1211, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/Users/phawkins/p/jax/jax/linear_util.py", line 160, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/phawkins/p/jax/jax/_src/numpy/linalg.py", line 450, in solve
    return lax_linalg._solve(a, b)
  File "/Users/phawkins/p/jax/jax/_src/lax/linalg.py", line 263, in _solve
    return custom_solve(b)
  File "/Users/phawkins/p/jax/jax/_src/lax/control_flow.py", line 2211, in custom_linear_solve
    out_flat = linear_solve_p.bind(
  File "/Users/phawkins/p/jax/jax/core.py", line 271, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/Users/phawkins/p/jax/jax/interpreters/batching.py", line 149, in process_primitive
    val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
  File "/Users/phawkins/p/jax/jax/_src/lax/control_flow.py", line 2296, in _linear_solve_batching_rule
    solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
  File "/Users/phawkins/p/jax/jax/interpreters/batching.py", line 411, in batch_jaxpr
    jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
  File "/Users/phawkins/p/jax/jax/interpreters/partial_eval.py", line 1201, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/Users/phawkins/p/jax/jax/interpreters/partial_eval.py", line 1211, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/Users/phawkins/p/jax/jax/linear_util.py", line 160, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/phawkins/p/jax/jax/core.py", line 141, in jaxpr_as_fun
    return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
  File "/Users/phawkins/p/jax/jax/core.py", line 352, in eval_jaxpr
    ans = eqn.primitive.bind(*(subfuns + in_vals), **bind_params)
  File "/Users/phawkins/p/jax/jax/core.py", line 1229, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/phawkins/p/jax/jax/core.py", line 1220, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/Users/phawkins/p/jax/jax/core.py", line 1232, in process
    return trace.process_call(self, fun, tracers, params)
  File "/Users/phawkins/p/jax/jax/interpreters/batching.py", line 163, in process_call
    vals_out = call_primitive.bind(f, *vals, **params)
  File "/Users/phawkins/p/jax/jax/core.py", line 1229, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/phawkins/p/jax/jax/core.py", line 1220, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/Users/phawkins/p/jax/jax/core.py", line 1232, in process
    return trace.process_call(self, fun, tracers, params)
  File "/Users/phawkins/p/jax/jax/interpreters/partial_eval.py", line 1085, in process_call
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
  File "/Users/phawkins/p/jax/jax/interpreters/partial_eval.py", line 1211, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/Users/phawkins/p/jax/jax/linear_util.py", line 160, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/phawkins/p/jax/jax/core.py", line 352, in eval_jaxpr
    ans = eqn.primitive.bind(*(subfuns + in_vals), **bind_params)
  File "/Users/phawkins/p/jax/jax/core.py", line 271, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/Users/phawkins/p/jax/jax/interpreters/batching.py", line 149, in process_primitive
    val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
  File "/Users/phawkins/p/jax/jax/_src/lax/lax.py", line 4324, in _gather_batching_rule
    return gather(operand, start_indices, dimension_numbers=dnums,
  File "/Users/phawkins/p/jax/jax/_src/lax/lax.py", line 868, in gather
    return gather_p.bind(
  File "/Users/phawkins/p/jax/jax/core.py", line 271, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/Users/phawkins/p/jax/jax/interpreters/partial_eval.py", line 1073, in process_primitive
    out_avals = primitive.abstract_eval(*avals, **params)
  File "/Users/phawkins/p/jax/jax/_src/lax/lax.py", line 1989, in standard_abstract_eval
    shapes, dtypes = shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)
  File "/Users/phawkins/p/jax/jax/_src/lax/lax.py", line 4223, in _gather_shape_rule
    raise TypeError(f"Slice size at index {i} in gather op is out of range, "
TypeError: Slice size at index 0 in gather op is out of range, must be within [0, 1), got 1.

I haven't yet had time to decode what this means.
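One possible reading of the message, offered as an assumption rather than a diagnosis of this issue: a valid range of `[0, 1)` for the slice size suggests that the gather operand had size 0 along dimension 0, so a slice of size 1 could not fit. The sketch below triggers the same kind of shape check directly with `lax.gather`; the shapes, the index values, and the `message` variable are made up for illustration, and the exact wording of the error may differ between JAX versions.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical operand with size 0 along dimension 0.
operand = jnp.zeros((0, 3))
# One start index into dimension 0 (the value itself does not matter here).
indices = jnp.zeros((1, 1), dtype=jnp.int32)

dnums = lax.GatherDimensionNumbers(
    offset_dims=(1,),           # output dim 1 holds the gathered row
    collapsed_slice_dims=(0,),  # dimension 0 is sliced with size 1 and dropped
    start_index_map=(0,),       # indices address dimension 0 of the operand
)

try:
    # Asking for a slice of size 1 along a dimension of size 0.
    lax.gather(operand, indices, dimension_numbers=dnums, slice_sizes=(1, 3))
    message = None
except TypeError as err:
    message = str(err)

print(message)
```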

AdrienCorenflos commented 3 years ago

Hi,

It's an error. I have some code that implements Kalman filtering using prefix-sum operations (see this). It used to work like a charm, but when I came back to it recently it failed somewhat unpredictably (it seems to be related to batching). I've narrowed it down to the combination of associative_scan and LU permutations, hence the test added in the PR.

Adrien
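For readers without the PR at hand, the failing pattern can be sketched from the frames in the traceback above: a combine function that calls `jsp.linalg.solve`, passed to `lax.associative_scan` over a pair `(ms, vs)`. The `jax.vmap` decorator is inferred from the `batched_fun` frame, and the array values and the length `n` are assumptions, not the contents of the PR test. The combine function is only there to exercise the batching path; it is not meant to be a meaningful associative operator. On a JAX version with the fix this should run without raising.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax import scipy as jsp

n = 4  # hypothetical sequence length
ms = jnp.tile(jnp.eye(2), (n, 1, 1))  # n identity matrices, shape (n, 2, 2)
vs = jnp.ones((n, 2))                 # n vectors, shape (n, 2)


@jax.vmap
def fn(a, b):
    # Each argument is a (matrix, vector) pair; vmap maps over the leading axis.
    m1, v1 = a
    m2, v2 = b
    return m1 + m2, jsp.linalg.solve(m1, v2) + jsp.linalg.solve(m2, v1)


# The scan slices the inputs recursively and calls fn on batches of elements.
out_ms, out_vs = lax.associative_scan(fn, elems=(ms, vs))
print(out_ms.shape, out_vs.shape)
```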

AdrienCorenflos commented 3 years ago

@hawkinsp thanks. If you disable jit in your env, you'll see that the stack stops at the LU solving bit. I'm not 100% sure whether it's the LU that's to blame or associative_scan, but looking at the file history it seemed more plausible for it to be associative_scan.
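For anyone following along, one way to disable jit for a block of code is the `jax.disable_jit()` context manager, which makes jitted functions run op by op so the traceback is shorter. The `solve` wrapper below is a hypothetical stand-in, not code from the PR.

```python
import jax
import jax.numpy as jnp
from jax import scipy as jsp


@jax.jit
def solve(m, v):
    # Hypothetical jitted wrapper around the LU-based solver.
    return jsp.linalg.solve(m, v)


with jax.disable_jit():
    # Inside this block the function body runs eagerly instead of being compiled.
    x = solve(jnp.eye(2), jnp.ones(2))

print(x)
```

Setting the `JAX_DISABLE_JIT=1` environment variable before starting Python has the same effect for the whole process.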

AdrienCorenflos commented 3 years ago

@hawkinsp @mattjj by the way, this is why I was using prefix sums:

https://twitter.com/simosarkka/status/1347521322486812675

I also patted myself on the back in the JAX discussions tab (what's the point of doing something if you can't brag about it :smile:): https://github.com/google/jax/discussions/5353

mattjj commented 3 years ago

So awesome! Thanks for sharing that. I used to work on Bayesian smoothing methods, and I own both of Prof Sarkka's books (Bayesian Filtering and Smoothing, and his recent one on SDEs)!