jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Taking gradient against jitted CG solver throws TypeError #11213

Open Dxyk opened 2 years ago

Dxyk commented 2 years ago

Hello, I have been using JAX to build a differentiable fluid simulator, which involves solving a sparse linear system using the Conjugate Gradient algorithm (jax.scipy.sparse.linalg.cg). When I tried to take the gradient of the jitted simulator (grad(jit(step)), where step is one step of the simulation that contains one CG solve), I encountered the following TypeError.

TypeError: Value Zero(ShapedArray(float32[2000])) with type <class 'jax._src.ad_util.Zero'> is not a valid JAX type

Interestingly, without jitting (grad(step) only) and keeping everything else the same, the TypeError will not be thrown, and the gradient is computed correctly.
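
For readers who want to try the two call patterns side by side, here is a minimal sketch of grad(step) versus grad(jit(step)) around a single CG solve. It is not a reproduction of the error: the `step` function, the periodic stencil in `A_operator`, and the input `b` are hypothetical stand-ins for the simulator, chosen only so the system is symmetric positive definite.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def step(b):
    """Hypothetical stand-in for one simulation step containing a CG solve."""

    def A_operator(x):
        # Periodic 1-D stencil (3, -1, -1): symmetric positive definite.
        return 3.0 * x - jnp.roll(x, 1) - jnp.roll(x, -1)

    x, _ = cg(A=A_operator, b=b, tol=1e-6)
    return jnp.sum(x ** 2)


b = jnp.arange(1.0, 9.0)
grad_eager = jax.grad(step)(b)             # grad(step)
grad_jitted = jax.grad(jax.jit(step))(b)   # grad(jit(step))
```

If the two gradients disagree, or only the jitted one raises, that narrows the problem to the interaction between jit and the transpose of the linear solve, which is where the stack trace points.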

When searching for similar issues, I found #1536, which reports the same error and might be related.

Unfortunately, I was unable to come up with a minimal example that reproduces this error. If necessary, I'd be glad to provide access to my repo.

Here are my package versions: jax==0.3.13, jaxlib==0.3.10.

Please see the stack trace attached below. Thanks for your help in advance!

Stack Trace

```text
Traceback (most recent call last):
  File ".../FluidSim/demos/simulation_grad_bug_demo.py", line 61, in <module>
    grad = loss_fn_grad(smoke, inflow, velocity, forces, time_steps, config,
  File ".../FluidSim/demos/simulation_grad_bug_demo.py", line 19, in loss_fn
    smoke, _ = simulation.simulate(
  File ".../FluidSim/sim/simulation.py", line 113, in simulate
    smoke, velocity, pressure = step(smoke, velocity, forces, config)
  File ".../FluidSim/sim/simulation.py", line 62, in step
    velocity, pressure = projection.project(velocity,
  File ".../FluidSim/sim/physics/projection.py", line 113, in project
    pressure_flat, _ = jsp.sparse.linalg.cg(A=A_operator,
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/scipy/sparse/linalg.py", line 287, in cg
    return _isolve(_cg_solve,
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/scipy/sparse/linalg.py", line 227, in _isolve
    x = lax.custom_linear_solve(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: Value Zero(ShapedArray(float32[2000])) with type <class 'jax._src.ad_util.Zero'> is not a valid JAX type

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File ".../FluidSim/demos/simulation_grad_bug_demo.py", line 61, in <module>
    grad = loss_fn_grad(smoke, inflow, velocity, forces, time_steps, config,
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/api.py", line 919, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/api.py", line 1001, in value_and_grad_f
    g = vjp_py(lax_internal._one(ans))
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/tree_util.py", line 287, in __call__
    return self.fun(*args, **kw)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/api.py", line 2364, in _vjp_pullback_wrapper
    ans = fun(*args)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/tree_util.py", line 287, in __call__
    return self.fun(*args, **kw)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/interpreters/ad.py", line 137, in unbound_vjp
    arg_cts = backward_pass(jaxpr, reduce_axes, True, consts, dummy_args, cts)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/interpreters/ad.py", line 232, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/interpreters/ad.py", line 605, in call_transpose
    out_flat = primitive.bind(fun, *all_args, **params)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 1765, in bind
    return call_bind(self, fun, *args, **params)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 1781, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 678, in process_call
    return primitive.impl(f, *tracers, **params)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 182, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/linear_util.py", line 285, in memoized_fun
    ans = call(fun, *args)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 230, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 272, in lower_xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1893, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1865, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/interpreters/ad.py", line 238, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 2608, in _linear_solve_transpose_rule
    cotangent_b_full = linear_solve_p.bind(
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 2142, in bind
    return self.bind_with_trace(top_trace, args, params)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 326, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/util.py", line 47, in safe_map
    return list(map(f, *args))
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 415, in full_raise
    return self.pure(val)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1504, in new_const
    aval = raise_to_shaped(get_aval(c), weak_type=dtypes.is_weakly_typed(c))
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 1100, in get_aval
    return concrete_aval(x)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 1092, in concrete_aval
    raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Value Zero(ShapedArray(float32[2000])) with type <class 'jax._src.ad_util.Zero'> is not a valid JAX type

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File ".../FluidSim/demos/simulation_grad_bug_demo.py", line 61, in <module>
    grad = loss_fn_grad(smoke, inflow, velocity, forces, time_steps, config,
TypeError: Value Zero(ShapedArray(float32[2000])) with type <class 'jax._src.ad_util.Zero'> is not a valid JAX type
```
Dxyk commented 2 years ago

Update - I was able to get it to work by jitting the CG solver and marking the A and M operators as static.

This was inspired by the discussion at https://github.com/google/jax/discussions/9804, but I don't quite understand why marking the functions as static resolves the TypeError about Zero(ShapedArray()).

import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import jit

@jit
def step():
    # Other operations that involve x and b ...

    def A_operator(x: jnp.ndarray) -> jnp.ndarray:
        # compute res using convolution. Left unchanged
        return res

    def M_operator(b: jnp.ndarray) -> jnp.ndarray:
        # Identity preconditioner: returns b unchanged
        return b

    # original version that fails grad with TypeError
    # jsp.sparse.linalg.cg(A=A_operator, b=b, M=M_operator)

    # new version that works with grad and jit
    jax.jit(jsp.sparse.linalg.cg, static_argnames=('A', 'M'))(A=A_operator, b=b, M=M_operator)
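
The mechanics of this workaround can be tried in isolation. The sketch below assumes a hypothetical periodic stencil for `A_operator` in place of the convolution and a small made-up `b`. It shows only how `static_argnames` lets Python callables be passed to a jitted `cg`; it does not explain why this avoids the Zero error.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def A_operator(x):
    # Hypothetical SPD operator standing in for the convolution-based one.
    return 3.0 * x - jnp.roll(x, 1) - jnp.roll(x, -1)


def M_operator(b):
    # Identity preconditioner, as in the snippet above.
    return b


# A and M are Python callables, not arrays, so they cannot be traced as
# ordinary jit arguments. Marking them static makes jit hash them and
# treat them as compile-time constants.
cg_jitted = jax.jit(cg, static_argnames=("A", "M"))

b = jnp.arange(1.0, 9.0)
x, _ = cg_jitted(A=A_operator, b=b, M=M_operator)
residual = jnp.max(jnp.abs(A_operator(x) - b))
```

Static arguments must be hashable, and plain functions are hashed by identity, so a closure that is re-created on every call counts as a new static value each time and is retraced.
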
zhangqiaorjc commented 2 years ago

Zero(ShapedArray()) is likely related to the fact that you are taking a grad.

If your original problem just did a jit(step) without static_argnames, I wonder if you'd get a clearer error message?

Dxyk commented 2 years ago

Yes, I suspect it is because I'm taking a grad.

In my original problem, the only place I used jit was jit(step, static_argnums=(n,)), where n is the index of my simulation parameters that are constants (ints, floats, etc.). I didn't think it contributed to the error, so I omitted it from the code block above.
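
As an illustration of that pattern, here is a minimal sketch of jit with `static_argnums` pointing at a constant-parameter argument. The `Config` fields, the body of `step`, and the argument order are hypothetical and not taken from the simulator.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Config(NamedTuple):
    # Hypothetical simulation constants; a static argument must be hashable.
    dt: float
    n_substeps: int


def step(velocity, forces, config):
    # config is static, so its fields are plain Python values during tracing
    # and can drive Python control flow such as this loop.
    for _ in range(config.n_substeps):
        velocity = velocity + config.dt * forces
    return velocity


step_jit = jax.jit(step, static_argnums=(2,))  # 2 is the index of `config`

velocity = jnp.zeros(4)
forces = jnp.ones(4)
config = Config(dt=0.5, n_substeps=3)
out = step_jit(velocity, forces, config)
grad_forces = jax.grad(lambda f: jnp.sum(step_jit(velocity, f, config)))(forces)
```

Each distinct static value triggers a separate compilation, so this works best when the constants change rarely.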

By a clearer error message, are you referring to the stack trace? If so, I collapsed it in the issue description. Sorry for the confusion. Here's the full stack trace.

Traceback (most recent call last):
  File ".../FluidSim/demos/simulation_grad_bug_demo.py", line 61, in <module>
    grad = loss_fn_grad(smoke, inflow, velocity, forces, time_steps, config,
  File ".../FluidSim/demos/simulation_grad_bug_demo.py", line 19, in loss_fn
    smoke, _ = simulation.simulate(
  File ".../FluidSim/sim/simulation.py", line 113, in simulate
    smoke, velocity, pressure = step(smoke, velocity, forces, config)
  File ".../FluidSim/sim/simulation.py", line 62, in step
    velocity, pressure = projection.project(velocity,
  File ".../FluidSim/sim/physics/projection.py", line 113, in project
    pressure_flat, _ = jsp.sparse.linalg.cg(A=A_operator,
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/scipy/sparse/linalg.py", line 287, in cg
    return _isolve(_cg_solve,
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/scipy/sparse/linalg.py", line 227, in _isolve
    x = lax.custom_linear_solve(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: Value Zero(ShapedArray(float32[2000])) with type <class 'jax._src.ad_util.Zero'> is not a valid JAX type

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File ".../FluidSim/demos/simulation_grad_bug_demo.py", line 61, in <module>
    grad = loss_fn_grad(smoke, inflow, velocity, forces, time_steps, config,
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/api.py", line 919, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/api.py", line 1001, in value_and_grad_f
    g = vjp_py(lax_internal._one(ans))
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/tree_util.py", line 287, in __call__
    return self.fun(*args, **kw)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/api.py", line 2364, in _vjp_pullback_wrapper
    ans = fun(*args)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/tree_util.py", line 287, in __call__
    return self.fun(*args, **kw)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/interpreters/ad.py", line 137, in unbound_vjp
    arg_cts = backward_pass(jaxpr, reduce_axes, True, consts, dummy_args, cts)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/interpreters/ad.py", line 232, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/interpreters/ad.py", line 605, in call_transpose
    out_flat = primitive.bind(fun, *all_args, **params)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 1765, in bind
    return call_bind(self, fun, *args, **params)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 1781, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 678, in process_call
    return primitive.impl(f, *tracers, **params)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 182, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/linear_util.py", line 285, in memoized_fun
    ans = call(fun, *args)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 230, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 272, in lower_xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1893, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1865, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/interpreters/ad.py", line 238, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 2608, in _linear_solve_transpose_rule
    cotangent_b_full = linear_solve_p.bind(
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 2142, in bind
    return self.bind_with_trace(top_trace, args, params)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 326, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/util.py", line 47, in safe_map
    return list(map(f, *args))
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 415, in full_raise
    return self.pure(val)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1504, in new_const
    aval = raise_to_shaped(get_aval(c), weak_type=dtypes.is_weakly_typed(c))
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 1100, in get_aval
    return concrete_aval(x)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 1092, in concrete_aval
    raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Value Zero(ShapedArray(float32[2000])) with type <class 'jax._src.ad_util.Zero'> is not a valid JAX type

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File ".../FluidSim/demos/simulation_grad_bug_demo.py", line 61, in <module>
    grad = loss_fn_grad(smoke, inflow, velocity, forces, time_steps, config,
TypeError: Value Zero(ShapedArray(float32[2000])) with type <class 'jax._src.ad_util.Zero'> is not a valid JAX type
Dxyk commented 2 years ago

Hi,

For what it's worth, I've kinda figured out what was causing the error.

In my original CG solver, when computing the product A @ x (A_operator), I was reshaping x and flattening the result (b). This seems to be the cause of the error. Later I realized that the reshaping is unnecessary, and removing these calls made taking the grad of the CG solver work.

@jit
def step():
    # Other operations that involve x and b ...

    # A_operator that causes the error.
    # x is a (nx * ny,) array, and the function returns a (nx * ny,) array
    def A_operator_error(x: jnp.ndarray) -> jnp.ndarray:
        x = x.reshape((nx, ny))
        # compute res using convolution with reshaping
        # ...
        res = res.flatten()
        return res

    # A_operator that works (taking grad does not cause error).
    # x is a (nx, ny) array, and the function returns a (nx, ny) array
    def A_operator(x: jnp.ndarray) -> jnp.ndarray:
        # compute res using convolution without reshaping
        # ...
        return res

    # Same cg call as in the original version; it failed under grad with
    # A_operator_error and works with the A_operator above
    jsp.sparse.linalg.cg(A=A_operator, b=b, M=M_operator)
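
A self-contained sketch of the two operator styles is given below, under loud assumptions: the periodic 5-point stencil built from `jnp.roll` is a hypothetical replacement for the convolution, and `nx`, `ny`, and `b` are made up. It shows that `cg` accepts a right-hand side that keeps its `(nx, ny)` shape, so the reshape/flatten round trip is optional. It is not claimed to reproduce the original error.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

nx, ny = 4, 5


def A_operator(x):
    # Hypothetical periodic 5-point stencil acting directly on an (nx, ny)
    # grid: 5*x minus the four neighbours (symmetric positive definite).
    neighbours = (jnp.roll(x, 1, axis=0) + jnp.roll(x, -1, axis=0)
                  + jnp.roll(x, 1, axis=1) + jnp.roll(x, -1, axis=1))
    return 5.0 * x - neighbours


def A_operator_flat(x_flat):
    # Same operator wrapped in the reshape/flatten round trip.
    return A_operator(x_flat.reshape((nx, ny))).flatten()


def loss(b):
    # b keeps its (nx, ny) shape; no flattening is needed for cg.
    x, _ = cg(A=A_operator, b=b, tol=1e-6)
    return jnp.sum(x ** 2)


b = jnp.arange(float(nx * ny)).reshape((nx, ny))
x_grid, _ = cg(A=A_operator, b=b, tol=1e-6)
x_flat, _ = cg(A=A_operator_flat, b=b.flatten(), tol=1e-6)
grad_b = jax.grad(jax.jit(loss))(b)
```

Both solves target the same linear system, so `x_grid.flatten()` and `x_flat` should agree up to solver tolerance; the grid-shaped version simply skips two shape changes per operator application.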