shenzebang opened 3 years ago
One issue I see in the code is the mixing of pure JAX function transformations (`jax.vmap` and `jvp`) with Objax's object-oriented approach. Usually this either does not work at all or works incorrectly.

Another issue: why do you need to manually assign `StateVar`s in the variable collection (like `critic.vars().subset(is_a=StateVar).assign(svars_t)`)?

So I would suggest refactoring the code so that all of it uses Objax wrappers instead of pure JAX transformations. `jax.vmap` should be replaced with `objax.Vectorize`, as sketched below. We don't have a built-in replacement for `jvp`, but it should be easy to add one similar to `objax.GradValues`.
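For example, a minimal sketch of the `objax.Vectorize` replacement; the `net` module, shapes, and names here are illustrative, not taken from the code in question:

```python
import jax.numpy as jnp
import objax

# Illustrative single-example module (a stand-in for the actual model).
net = objax.nn.Sequential([objax.nn.Linear(4, 2), objax.functional.relu])

# Pure-JAX style, which can misbehave with Objax modules:
#   batched = jax.vmap(net)
# Objax style: Vectorize carries net's variables through the transform.
batched_net = objax.Vectorize(net, batch_axis=(0,))

x = jnp.ones((8, 4))   # a batch of 8 examples with 4 features each
y = batched_net(x)     # net applied per example; y has shape (8, 2)
```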
Hi Alexey,

Thank you for the information. How should I compute the gradient with respect to the input of the module using `objax.Grad`? E.g. I want to compute the Jacobian-vector product $\nabla_x\left(v^\top \nabla_x f(t, x)\, v\right)$. Here $f$ is a module that takes an $n$-dimensional vector $x$ and a time $t$ as input and outputs an $n$-dimensional vector, and $v$ is a fixed $n$-dimensional vector.
I tried the following code (`net` is the function $f$):

```python
x_vc = objax.VarCollection()
x_vc["x"] = objax.StateVar(x)

def _net_v_prod(_v):
    return jnp.dot(net(_t, x_vc["x"], training=False), _v)

d_net_v_prod = objax.Grad(_net_v_prod, x_vc)

@objax.Function.with_vars(x_vc + d_net_v_prod.vars())
def div_net_fn(_v):
    return jnp.dot(d_net_v_prod(_v), _v)

grad_div_fn = objax.Grad(div_net_fn, x_vc)
dscore_1 = -grad_div_fn(v)
```
but I receive this error:
```
Traceback (most recent call last):
  File "/home/zebang/PycharmProjects/ode_diffusion_jax/main.py", line 450, in <module>
    critic = train_critic(key)
  File "/home/zebang/PycharmProjects/ode_diffusion_jax/main.py", line 239, in train_critic
    _loss = train_op(data, subkey)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/module.py", line 258, in __call__
    output, changes = self._call(self.vc.tensors(), kwargs, *args)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/api.py", line 332, in cache_miss
    out_flat = xla.xla_call(
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 1402, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 1393, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 1405, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 600, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/xla.py", line 576, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/linear_util.py", line 260, in memoized_fun
    ans = call(fun, *args)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/xla.py", line 652, in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args, transform_name="jit")
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1209, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1188, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/module.py", line 248, in jit
    return f(*args, **kwargs), self.vc.tensors()
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/module.py", line 184, in __call__
    return self.wrapped(*args, **kwargs)
  File "/home/zebang/PycharmProjects/ode_diffusion_jax/main.py", line 221, in train_op
    g, v = critic_gv_fn(x, key)
  File "/home/zebang/PycharmProjects/ode_diffusion_jax/main.py", line 206, in critic_gv_fn
    result = odeint(ode_func, state_init, tspace, atol=tolerance, rtol=tolerance)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/experimental/ode.py", line 172, in odeint
    converted, consts = custom_derivatives.closure_convert(func, y0, t[0], *args)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/custom_derivatives.py", line 825, in closure_convert
    return _closure_convert_for_avals(fun, in_tree, in_avals)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/util.py", line 185, in wrapper
    return cached(bool(config.x64_enabled), *args, **kwargs)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/util.py", line 178, in cached
    return f(*args, **kwargs)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/custom_derivatives.py", line 830, in _closure_convert_for_avals
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1178, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1188, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/zebang/PycharmProjects/ode_diffusion_jax/main.py", line 179, in ode_func
    dscore_1 = - grad_div_fn(v)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/gradient.py", line 121, in __call__
    return super().__call__(*args, **kwargs)[0]
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/gradient.py", line 82, in __call__
    g, (outputs, changes) = self._call(inputs + self.vc.subset(TrainVar).tensors(),
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/api.py", line 752, in grad_f_aux
    (_, aux), g = value_and_grad_f(*args, **kwargs)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/api.py", line 811, in value_and_grad_f
    ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/api.py", line 1882, in _vjp
    out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/ad.py", line 116, in vjp
    out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/ad.py", line 101, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 498, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/gradient.py", line 58, in f_func
    outputs = f(*list_args, **kwargs)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/module.py", line 184, in __call__
    return self.wrapped(*args, **kwargs)
  File "/home/zebang/PycharmProjects/ode_diffusion_jax/main.py", line 167, in div_net_fn
    return jnp.dot(d_net_v_prod(_v), _v)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/gradient.py", line 121, in __call__
    return super().__call__(*args, **kwargs)[0]
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/gradient.py", line 82, in __call__
    g, (outputs, changes) = self._call(inputs + self.vc.subset(TrainVar).tensors(),
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/api.py", line 752, in grad_f_aux
    (_, aux), g = value_and_grad_f(*args, **kwargs)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/api.py", line 811, in value_and_grad_f
    ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/api.py", line 1882, in _vjp
    out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/ad.py", line 116, in vjp
    out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/ad.py", line 101, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 498, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/gradient.py", line 58, in f_func
    outputs = f(*list_args, **kwargs)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/module.py", line 184, in __call__
    return self.wrapped(*args, **kwargs)
  File "/home/zebang/PycharmProjects/ode_diffusion_jax/main.py", line 161, in _net_v_prod
    return jnp.dot(net(_t, x_vc["x"], training=False), _v)
  File "/home/zebang/PycharmProjects/ode_diffusion_jax/model/neural_ode_model.py", line 111, in __call__
    h1 = self.conv1(x)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/nn/layers.py", line 198, in __call__
    y = lax.conv_general_dilated(x, self.w.value, self.strides, self.padding,
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 625, in conv_general_dilated
    return conv_general_dilated_p.bind(
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 258, in bind
    tracers = map(top_trace.full_raise, args)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/util.py", line 40, in safe_map
    return list(map(f, *args))
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 365, in full_raise
    return self.pure(val)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1012, in new_const
    aval = raise_to_shaped(get_aval(val), weak_type=dtypes.is_weakly_typed(val))
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 927, in get_aval
    return concrete_aval(x)
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 918, in concrete_aval
    return concrete_aval(x.__jax_array__())
  File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 919, in concrete_aval
    raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
TypeError: Value Traced<ShapedArray(float32[32,1,28,28])>with<DynamicJaxprTrace(level=1/1)> with type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> is not a valid JAX type

Process finished with exit code 1
```
`objax.Grad(div_net_fn, x_vc)` is a correct way to compute gradients of `div_net_fn` with respect to `x_vc`. The error which you see is caused by something else.
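As a side note, `objax.Grad` also accepts `input_argnums`, which is another way to differentiate with respect to a module input directly. A minimal sketch, with an illustrative `Linear` stand-in rather than the actual `net`:

```python
import jax.numpy as jnp
import objax

net = objax.nn.Linear(3, 3)  # illustrative stand-in module

def g(x):
    # Scalar-valued function of the module input.
    return jnp.sum(net(x) ** 2)

# Differentiate w.r.t. the first positional argument as well as net's
# trainable variables; gradients for input_argnums come first in the result.
grad_g = objax.Grad(g, net.vars(), input_argnums=(0,))
grads = grad_g(jnp.ones(3))
dx = grads[0]  # gradient of g with respect to the input x
```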
Recently we introduced duck typing of Objax variables, which allows using Objax variables in JAX expressions without typing `.value`. However, one of the recent changes in JAX broke it, so it does not always work as expected. The error you are seeing is a result of this regression.

A temporary workaround is to add `.value` every time you access an Objax variable.
For example, `jnp.dot(net(_t, x_vc["x"], training=False), _v)` would change to `jnp.dot(net(_t, x_vc["x"].value, training=False), _v)`. I assume here that `_t` and `_v` are JAX arrays and not Objax variables; if they are Objax variables, then you also have to use `_t.value` and `_v.value`.
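Applied to the snippet above, the first function would become (a sketch reusing your names):

```python
def _net_v_prod(_v):
    # Read the StateVar through .value so JAX sees a plain array,
    # sidestepping the duck-typing regression described above.
    return jnp.dot(net(_t, x_vc["x"].value, training=False), _v)
```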
Another possible workaround is to roll back to an earlier version of JAX; specifically, JAX version 0.2.10 should work.
Hello everyone, I am new to JAX and I am encountering the following problem. Can someone please help me resolve it?

When I add the `StateVar` to the `vc` argument of `objax.Jit`, it reports an error. Here is the relevant code: