Dxyk opened this issue 2 years ago (status: Open)
Update: I was able to get it to work by jitting the CG solver and marking the `A` and `M` operators as static.
This was inspired by this discussion https://github.com/google/jax/discussions/9804, but I don't quite understand why marking the functions as static resolves the `TypeError` about `Zero(ShapedArray())`.
```python
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.scipy.sparse.linalg  # make jsp.sparse.linalg available

@jax.jit
def step():
    # Other operations that involve x and b ...

    def A_operator(x: jnp.ndarray) -> jnp.ndarray:
        # compute res using convolution. Left unchanged
        return res

    def M_operator(b: jnp.ndarray) -> jnp.ndarray:
        # Set as identity matrix
        return b

    # original version that fails grad with TypeError
    # jsp.sparse.linalg.cg(A=A_operator, b=b, M=M_operator)
    # new version that works with grad and jit
    jax.jit(jsp.sparse.linalg.cg, static_argnames=('A', 'M'))(A=A_operator, b=b, M=M_operator)
```
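A self-contained sketch of the workaround above, for experimenting outside the simulator. The 1-D operator (a `[-1, 2, -1]` convolution plus the identity), the size `N`, and `loss` are hypothetical stand-ins for the issue's convolution-based operator, not the author's code; the pattern shown is wrapping `jax.scipy.sparse.linalg.cg` in `jax.jit` with `static_argnames=("A", "M")` and differentiating through it.

```python
import jax
import jax.numpy as jnp
import jax.scipy.sparse.linalg as jspl

N = 16  # hypothetical problem size


def A_operator(x: jnp.ndarray) -> jnp.ndarray:
    # Hypothetical matrix-free SPD operator: a 1-D [-1, 2, -1] stencil applied
    # by convolution (zero boundary), plus the identity.
    return jnp.convolve(x, jnp.array([-1.0, 2.0, -1.0]), mode="same") + x


def M_operator(b: jnp.ndarray) -> jnp.ndarray:
    # Identity preconditioner, as in the issue.
    return b


# The workaround: jit the solver itself and treat the callables as static
# (hashable) arguments rather than traced ones.
cg_static = jax.jit(jspl.cg, static_argnames=("A", "M"))


@jax.jit
def loss(scale):
    b = scale * jnp.ones(N)
    x, _ = cg_static(A=A_operator, b=b, M=M_operator)
    return jnp.sum(x ** 2)


g = jax.grad(loss)(1.0)
```

Because the solution scales linearly with `scale`, the loss is quadratic in it, which makes the gradient easy to sanity-check against the loss value.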
`Zero(ShapedArray())` is likely related to the fact that you are taking a `grad`. If your original problem just did a `jit(step)` without `static_argnames`, I wonder if you'd get a clearer error message?
Yes, I suspect it is because I'm taking a `grad`.
In my original problem, the only place I used `jit` is `jit(step, static_argnums=(n,))`, where `n` is the index of my simulation parameters that are constants (ints, floats, etc.). I didn't think it contributed to the error, so I omitted it from the code block above.
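A minimal sketch of the kind of `static_argnums` setup described here, for readers unfamiliar with it. The function `step`, its damping update, and the parameter names are hypothetical, not the simulator's: the argument at index 2 is a Python `int` that drives a loop, so it is marked static while the arrays stay traced.

```python
import jax
import jax.numpy as jnp


def step(velocity: jnp.ndarray, dt: float, n_substeps: int) -> jnp.ndarray:
    # n_substeps drives a Python loop, so it must be concrete at trace time.
    for _ in range(n_substeps):
        velocity = velocity * (1.0 - dt)  # hypothetical damping update
    return velocity


# Index 2 (n_substeps) is static; velocity and dt are traced as usual.
step_jit = jax.jit(step, static_argnums=(2,))

v = step_jit(jnp.ones(4), 0.1, 3)
```

Each distinct value of a static argument triggers a fresh trace and compilation, so marking floats that change between calls as static can be costly.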
By a clearer error message, are you referring to the stack trace? If so, I collapsed it in the issue description. Sorry for the confusion. Here's the full stack trace.
```text
Traceback (most recent call last):
  File ".../FluidSim/demos/simulation_grad_bug_demo.py", line 61, in <module>
    grad = loss_fn_grad(smoke, inflow, velocity, forces, time_steps, config,
  File ".../FluidSim/demos/simulation_grad_bug_demo.py", line 19, in loss_fn
    smoke, _ = simulation.simulate(
  File ".../FluidSim/sim/simulation.py", line 113, in simulate
    smoke, velocity, pressure = step(smoke, velocity, forces, config)
  File ".../FluidSim/sim/simulation.py", line 62, in step
    velocity, pressure = projection.project(velocity,
  File ".../FluidSim/sim/physics/projection.py", line 113, in project
    pressure_flat, _ = jsp.sparse.linalg.cg(A=A_operator,
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/scipy/sparse/linalg.py", line 287, in cg
    return _isolve(_cg_solve,
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/scipy/sparse/linalg.py", line 227, in _isolve
    x = lax.custom_linear_solve(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: Value Zero(ShapedArray(float32[2000])) with type <class 'jax._src.ad_util.Zero'> is not a valid JAX type

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File ".../FluidSim/demos/simulation_grad_bug_demo.py", line 61, in <module>
    grad = loss_fn_grad(smoke, inflow, velocity, forces, time_steps, config,
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/api.py", line 919, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/api.py", line 1001, in value_and_grad_f
    g = vjp_py(lax_internal._one(ans))
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/tree_util.py", line 287, in __call__
    return self.fun(*args, **kw)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/api.py", line 2364, in _vjp_pullback_wrapper
    ans = fun(*args)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/tree_util.py", line 287, in __call__
    return self.fun(*args, **kw)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/interpreters/ad.py", line 137, in unbound_vjp
    arg_cts = backward_pass(jaxpr, reduce_axes, True, consts, dummy_args, cts)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/interpreters/ad.py", line 232, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/interpreters/ad.py", line 605, in call_transpose
    out_flat = primitive.bind(fun, *all_args, **params)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 1765, in bind
    return call_bind(self, fun, *args, **params)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 1781, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 678, in process_call
    return primitive.impl(f, *tracers, **params)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 182, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/linear_util.py", line 285, in memoized_fun
    ans = call(fun, *args)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 230, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 272, in lower_xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1893, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1865, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/interpreters/ad.py", line 238, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 2608, in _linear_solve_transpose_rule
    cotangent_b_full = linear_solve_p.bind(
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 2142, in bind
    return self.bind_with_trace(top_trace, args, params)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 326, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/_src/util.py", line 47, in safe_map
    return list(map(f, *args))
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 415, in full_raise
    return self.pure(val)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1504, in new_const
    aval = raise_to_shaped(get_aval(c), weak_type=dtypes.is_weakly_typed(c))
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 1100, in get_aval
    return concrete_aval(x)
  File ".../FluidSim/venv/lib/python3.9/site-packages/jax/core.py", line 1092, in concrete_aval
    raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Value Zero(ShapedArray(float32[2000])) with type <class 'jax._src.ad_util.Zero'> is not a valid JAX type

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File ".../FluidSim/demos/simulation_grad_bug_demo.py", line 61, in <module>
    grad = loss_fn_grad(smoke, inflow, velocity, forces, time_steps, config,
TypeError: Value Zero(ShapedArray(float32[2000])) with type <class 'jax._src.ad_util.Zero'> is not a valid JAX type
```
Hi,
For what it's worth, I think I've figured out what was causing the error.
In my original CG solver, when computing the product `A @ x = b` (`A_operator`), I was reshaping `x` to 2-D and flattening the result `b` back to 1-D. This seems to be the cause of the error. Later I realized that the reshaping is unnecessary, and removing these calls made taking the `grad` of the CG solver work.
```python
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.scipy.sparse.linalg  # make jsp.sparse.linalg available

@jax.jit
def step():
    # Other operations that involve x and b ...

    # A_operator that causes the error.
    # x is a (nx * ny,) array, and the function returns a (nx * ny,) array
    def A_operator_error(x: jnp.ndarray) -> jnp.ndarray:
        x = x.reshape((nx, ny))
        # compute res using convolution with reshaping
        # ...
        res = res.flatten()
        return res

    # A_operator that works (taking grad does not cause error).
    # x is a (nx, ny) array, and the function returns a (nx, ny) array
    def A_operator(x: jnp.ndarray) -> jnp.ndarray:
        # compute res using convolution without reshaping
        # ...
        return res

    # Original call (no static_argnames, M_operator as before): with the
    # reshape-free A_operator and b of shape (nx, ny), grad no longer fails.
    jsp.sparse.linalg.cg(A=A_operator, b=b, M=M_operator)
```
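A self-contained sketch of the reshape-free approach. It assumes a 5-point Laplacian stencil applied with `jax.scipy.signal.convolve2d` on a hypothetical 8x8 grid with zero boundary values; the simulator's actual operator is not shown in the thread. `cg` accepts `b` of any shape as long as `A` maps arrays of that shape to the same shape, so the operator can work on `(nx, ny)` arrays directly.

```python
import jax
import jax.numpy as jnp
import jax.scipy.sparse.linalg as jspl
from jax.scipy.signal import convolve2d

nx, ny = 8, 8  # hypothetical grid size

# 5-point negative Laplacian; the kernel is symmetric, so with zero (Dirichlet)
# boundary values the operator below is symmetric positive definite.
LAPLACIAN = jnp.array([[0.0, -1.0, 0.0],
                       [-1.0, 4.0, -1.0],
                       [0.0, -1.0, 0.0]])


def A_operator(x: jnp.ndarray) -> jnp.ndarray:
    # x has shape (nx, ny) and so does the result: no reshape/flatten needed.
    return convolve2d(x, LAPLACIAN, mode="same")


@jax.jit
def loss(source: jnp.ndarray) -> jnp.ndarray:
    # The right-hand side and the solution are both (nx, ny) arrays.
    pressure, _ = jspl.cg(A_operator, source)
    return jnp.sum(pressure ** 2)


source = jnp.zeros((nx, ny)).at[nx // 2, ny // 2].set(1.0)
g = jax.grad(loss)(source)
```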
Hello, I have been using JAX to build a differentiable fluid simulator, which involves solving a sparse linear system with the conjugate gradient algorithm (`jax.scipy.sparse.linalg.cg`). When I tried to take the gradient of the jitted simulator (`grad(jit(step))`, where `step` is one step of the simulation containing one CG solve), I encountered the following `TypeError`.
Interestingly, without jitting (`grad(step)` only) and keeping everything else the same, the `TypeError` is not thrown and the gradient is computed correctly.
When searching for similar issues, I found #1536, which reports the same error and might be related.
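A minimal sketch of the two compositions being compared, on a hypothetical 1-D problem: `step` here is a stand-in containing a single CG solve with a `[-1, 3, -1]` convolution operator, not the simulator's `step`. It illustrates that `grad(step)` and `grad(jit(step))` are expected to agree; it is not a reproducer of the error.

```python
import jax
import jax.numpy as jnp
import jax.scipy.sparse.linalg as jspl


def A_operator(x: jnp.ndarray) -> jnp.ndarray:
    # Hypothetical SPD operator: tridiagonal [-1, 3, -1] stencil via convolution.
    return jnp.convolve(x, jnp.array([-1.0, 3.0, -1.0]), mode="same")


def step(b: jnp.ndarray) -> jnp.ndarray:
    # Stand-in for one simulation step containing one CG solve.
    x, _ = jspl.cg(A_operator, b)
    return jnp.sum(x ** 2)


b = jnp.linspace(0.0, 1.0, 16)
g_eager = jax.grad(step)(b)          # grad(step): works in the issue
g_jit = jax.grad(jax.jit(step))(b)   # grad(jit(step)): the failing case in the issue
```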
Unfortunately, I was unable to come up with a minimal example that reproduces this error. If necessary, I'd be glad to provide access to my repo.
Here are my package versions: `jax==0.3.13`, `jaxlib==0.3.10`.
Please see the stack trace attached below. Thanks in advance for your help!
Stack Trace
```text
Traceback (most recent call last):
  File ".../FluidSim/demos/simulation_grad_bug_demo.py", line 61, in