
objax.Jit reports error when StateVar is added to the vc argument #219

Open shenzebang opened 3 years ago

shenzebang commented 3 years ago

Hello everyone, I am new to JAX and I am encountering the following problem. Can someone please help me resolve it?

When I include the StateVars in the vc argument of objax.Jit, it reports:

File "/home/zebang/PycharmProjects/ode_diffusion_jax/main.py", line 232, in train_critic _loss = train_op(data, subkey) File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/module.py", line 258, in call output, changes = self._call(self.vc.tensors(), kwargs, args) File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback return fun(args, kwargs) File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/api.py", line 332, in cache_miss out_flat = xla.xla_call( File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 1402, in bind return call_bind(self, fun, *args, *params) File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 1393, in call_bind outs = primitive.process(top_trace, fun, tracers, params) File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 1405, in process return trace.process_call(self, fun, tracers, params) File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 600, in process_call return primitive.impl(f, tracers, params) File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/xla.py", line 576, in _xla_call_impl compiled_fun = _xla_callable(fun, device, backend, name, donated_invars, File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/linear_util.py", line 260, in memoized_fun ans = call(fun, args) File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/xla.py", line 652, in _xla_callable jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args, transform_name="jit") File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1209, in trace_to_jaxpr_final jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals) File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1188, in trace_to_subjaxpr_dynamic ans = fun.call_wrapped(in_tracers) File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/linear_util.py", line 179, in call_wrapped ans = gen.send(ans) File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 1363, in process_env_traces outs = map(trace.full_raise, outs) File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/util.py", line 40, in safe_map return list(map(f, *args)) File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 381, in full_raise return self.lift(val) File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1014, in new_const self.frame.tracers.append(tracer) File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1002, in frame return self.main.jaxpr_stack[-1] # pytype: disable=attribute-error IndexError: tuple index out of range

Here is the relevant code:


# Imports implied by the snippet; UNet, marginal_prob_std_fn, diffusion_coeff_fn,
# model_path, tolerance and lr are defined elsewhere in the project.
import os

import jax
import jax.numpy as jnp
import numpy as np
import objax
from jax import grad, jvp, random
from jax.experimental.ode import odeint
from objax import StateVar, TrainVar


def train_critic(key):
    key, subkey = random.split(key)
    net = UNet(marginal_prob_std_fn, subkey)
    critic = UNet(marginal_prob_std_fn, subkey)
    objax.io.load_var_collection(os.path.join(model_path, 'scorenet.npz'), net.vars())
    objax.io.load_var_collection(os.path.join(model_path, 'scorenet.npz'), critic.vars())

    def critic_gv_fn(x_init, key):
        # define v for Hutchinson's estimator: a batch of 20 random probe vectors
        key, subkey = random.split(key)
        v = random.normal(subkey, tuple([20] + list(x_init.shape)))
        # define the initial states
        t_0 = jnp.zeros(x_init.shape[0])
        score_init = net(t_0, x_init, training=False)
        critic_loss_init = jnp.zeros(1)
        critic_grad_init = [jnp.zeros_like(_var) for _var in critic.vars().subset(is_a=TrainVar)]
        state_init = [x_init, score_init, critic_loss_init, critic_grad_init]

        def ode_func(states, t):
            x = states[0]
            score = states[1]
            _t = jnp.ones([x.shape[0]]) * t
            diffusion_weight = diffusion_coeff_fn(t)
            score_pred = net(_t, x, training=False)
            dx = -.5 * (diffusion_weight ** 2) * score_pred

            f = lambda x: net(_t, x, training=False)

            def divergence_fn(_x, _v):
                # Hutchinson’s Estimator
                # computes the divergence of net at x with random vector v
                _, u = jvp(f, (_x,), (_v,))
                return jnp.sum(jnp.dot(u, _v))

            batch_div_fn = jax.vmap(divergence_fn, in_axes=[None, 0])

            def batch_div(x):
                return batch_div_fn(x, v).mean(axis=0)

            grad_div_fn = grad(batch_div)

            dscore_1 = - grad_div_fn(x)
            dscore_2 = - jvp(f, (x,), (score,))[1]  # f(x), df/dx * v = jvp(f, x, v)
            dscore = dscore_1 + dscore_2

            def dcritic_loss_fn(_x):
                critic_pred = critic(_t, _x, training=True)
                loss = ((critic_pred - score_pred) ** 2).sum(axis=(1, 2, 3)).mean()
                return loss

            dc_gv = objax.GradValues(dcritic_loss_fn, critic.vars())
            dcritic_grad, dcritic_loss = dc_gv(x)
            dcritic_loss = dcritic_loss[0][None]
            dstates = [dx, dscore, dcritic_loss, dcritic_grad]

            return dstates

        tspace = np.array((0., 1.))

        result = odeint(ode_func, state_init, tspace, atol=tolerance, rtol=tolerance)

        _g = [_var[1] for _var in result[3]]
        return _g, result[2][1], critic.vars().subset(is_a=StateVar).tensors()

    # define optimizer
    opt = objax.optimizer.Adam(critic.vars())

    # define train_op
    def train_op(x, key):
        g, v, svars_t = critic_gv_fn(x, key)
        critic.vars().subset(is_a=StateVar).assign(svars_t)
        opt(lr, g)
        return v

    train_op = objax.Jit(train_op, critic.vars().subset(is_a=TrainVar) + opt.vars())
    # reports error if I set "train_op = objax.Jit(train_op, critic.vars() + opt.vars())"

AlexeyKurakin commented 3 years ago

One issue I see in the code is mixing pure JAX function transformations (jax.vmap and jvp) with the Objax object-oriented approach. Usually this either does not work at all or works incorrectly.

Another issue: why do you need to manually assign StateVars in the variable collection (like critic.vars().subset(is_a=StateVar).assign(svars_t))?

So I would suggest refactoring the code so that all of it uses Objax wrappers instead of pure JAX transformations. jax.vmap should be replaced with objax.Vectorize. We don't have a built-in replacement for jvp, but it should be easy to add one similar to objax.GradValues.
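
For concreteness, here is a rough, untested sketch of the vmap half of this suggestion, reusing the names from the snippet above (net, divergence_fn, v). It assumes the signature objax.Vectorize(f, vc, batch_axis=...), where batch_axis plays the role of jax.vmap's in_axes. The jvp half would still need an Objax-side wrapper, as noted above.

    # batch_axis=(None, 0) broadcasts _x and maps over axis 0 of _v,
    # mirroring in_axes=[None, 0] in the jax.vmap version; passing
    # net.vars() lets Objax track the module's variables during tracing.
    batch_div_fn = objax.Vectorize(divergence_fn, net.vars(), batch_axis=(None, 0))

    def batch_div(x):
        # Average the per-probe divergence estimates, as before.
        return batch_div_fn(x, v).mean(axis=0)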

shenzebang commented 3 years ago

Hi Alexey,

Thank you for the information. How should I compute the gradient with respect to the input of the module using objax.Grad? E.g. I want to compute the gradient of the Jacobian-vector-product term, $\nabla_x \left( v^\top \nabla_x f(t, x)\, v \right)$. Here $f$ is a module that takes an n-dimensional vector and a time $t$ as input and outputs an n-dimensional vector, and $v$ is a fixed n-dimensional vector.

I tried the following code (net is the function f):

        x_vc = objax.VarCollection()
        x_vc["x"] = objax.StateVar(x)

        def _net_v_prod(_v):
            return jnp.dot(net(_t, x_vc["x"], training=False), _v)

        d_net_v_prod = objax.Grad(_net_v_prod, x_vc)

        @objax.Function.with_vars(x_vc + d_net_v_prod.vars())
        def div_net_fn(_v):
            return jnp.dot(d_net_v_prod(_v), _v)

        grad_div_fn = objax.Grad(div_net_fn, x_vc)

        dscore_1 = - grad_div_fn(v)

but I receive this error:

Traceback (most recent call last):
File "/home/zebang/PycharmProjects/ode_diffusion_jax/main.py", line 450, in <module>
    critic = train_critic(key)
File "/home/zebang/PycharmProjects/ode_diffusion_jax/main.py", line 239, in train_critic
    _loss = train_op(data, subkey)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/module.py", line 258, in __call__
    output, changes = self._call(self.vc.tensors(), kwargs, *args)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/api.py", line 332, in cache_miss
    out_flat = xla.xla_call(
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 1402, in bind
    return call_bind(self, fun, *args, **params)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 1393, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 1405, in process
    return trace.process_call(self, fun, tracers, params)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 600, in process_call
    return primitive.impl(f, *tracers, **params)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/xla.py", line 576, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/linear_util.py", line 260, in memoized_fun
    ans = call(fun, *args)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/xla.py", line 652, in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args, transform_name="jit")
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1209, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1188, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/module.py", line 248, in jit
    return f(*args, **kwargs), self.vc.tensors()
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/module.py", line 184, in __call__
    return self.wrapped(*args, **kwargs)
File "/home/zebang/PycharmProjects/ode_diffusion_jax/main.py", line 221, in train_op
    g, v = critic_gv_fn(x, key)
File "/home/zebang/PycharmProjects/ode_diffusion_jax/main.py", line 206, in critic_gv_fn
    result = odeint(ode_func, state_init, tspace, atol=tolerance, rtol=tolerance)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/experimental/ode.py", line 172, in odeint
    converted, consts = custom_derivatives.closure_convert(func, y0, t[0], *args)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/custom_derivatives.py", line 825, in closure_convert
    return _closure_convert_for_avals(fun, in_tree, in_avals)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/util.py", line 185, in wrapper
    return cached(bool(config.x64_enabled), *args, **kwargs)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/util.py", line 178, in cached
    return f(*args, **kwargs)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/custom_derivatives.py", line 830, in _closure_convert_for_avals
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1178, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1188, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/zebang/PycharmProjects/ode_diffusion_jax/main.py", line 179, in ode_func
    dscore_1 = - grad_div_fn(v)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/gradient.py", line 121, in __call__
    return super().__call__(*args, **kwargs)[0]
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/gradient.py", line 82, in __call__
    g, (outputs, changes) = self._call(inputs + self.vc.subset(TrainVar).tensors(),
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/api.py", line 752, in grad_f_aux
    (_, aux), g = value_and_grad_f(*args, **kwargs)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/api.py", line 811, in value_and_grad_f
    ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/api.py", line 1882, in _vjp
    out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/ad.py", line 116, in vjp
    out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/ad.py", line 101, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 498, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/gradient.py", line 58, in f_func
    outputs = f(*list_args, **kwargs)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/module.py", line 184, in __call__
    return self.wrapped(*args, **kwargs)
File "/home/zebang/PycharmProjects/ode_diffusion_jax/main.py", line 167, in div_net_fn
    return jnp.dot(d_net_v_prod(_v), _v)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/gradient.py", line 121, in __call__
    return super().__call__(*args, **kwargs)[0]
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/gradient.py", line 82, in __call__
    g, (outputs, changes) = self._call(inputs + self.vc.subset(TrainVar).tensors(),
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/api.py", line 752, in grad_f_aux
    (_, aux), g = value_and_grad_f(*args, **kwargs)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/api.py", line 811, in value_and_grad_f
    ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/api.py", line 1882, in _vjp
    out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/ad.py", line 116, in vjp
    out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/ad.py", line 101, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 498, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/gradient.py", line 58, in f_func
    outputs = f(*list_args, **kwargs)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/module.py", line 184, in __call__
    return self.wrapped(*args, **kwargs)
File "/home/zebang/PycharmProjects/ode_diffusion_jax/main.py", line 161, in _net_v_prod
    return jnp.dot(net(_t, x_vc["x"], training=False), _v)
File "/home/zebang/PycharmProjects/ode_diffusion_jax/model/neural_ode_model.py", line 111, in __call__
    h1 = self.conv1(x)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/objax/nn/layers.py", line 198, in __call__
    y = lax.conv_general_dilated(x, self.w.value, self.strides, self.padding,
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 625, in conv_general_dilated
    return conv_general_dilated_p.bind(
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 258, in bind
    tracers = map(top_trace.full_raise, args)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/_src/util.py", line 40, in safe_map
    return list(map(f, *args))
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 365, in full_raise
    return self.pure(val)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1012, in new_const
    aval = raise_to_shaped(get_aval(val), weak_type=dtypes.is_weakly_typed(val))
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 927, in get_aval
    return concrete_aval(x)
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 918, in concrete_aval
    return concrete_aval(x.__jax_array__())
File "/home/zebang/miniconda3/envs/torch_env/lib/python3.8/site-packages/jax/core.py", line 919, in concrete_aval
    raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
TypeError: Value Traced<ShapedArray(float32[32,1,28,28])>with<DynamicJaxprTrace(level=1/1)> with type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> is not a valid JAX type

Process finished with exit code 1

AlexeyKurakin commented 3 years ago

objax.Grad(div_net_fn, x_vc) is the correct way to compute gradients of div_net_fn with respect to x_vc.

The error you see is caused by something else. Recently we introduced duck typing of Objax variables, which allows using Objax variables in JAX expressions without typing .value. However, one of the recent changes in JAX broke it, so it does not always work as expected. The error you are seeing is a result of this regression.

A temporary workaround is to add .value every time you access an Objax variable. For example, jnp.dot(net(_t, x_vc["x"], training=False), _v) would change to jnp.dot(net(_t, x_vc["x"].value, training=False), _v). I assume here that _t and _v are JAX arrays and not Objax variables; if _t and _v are Objax variables, then you also have to use _t.value and _v.value.
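
Applied in place to the snippet from the earlier comment (same names; only the .value access changes):

        def _net_v_prod(_v):
            # Workaround for the duck-typing regression: read the tensor
            # behind the StateVar explicitly instead of passing the variable.
            return jnp.dot(net(_t, x_vc["x"].value, training=False), _v)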

Another possible workaround is to roll back to an earlier version of JAX; specifically, JAX version 0.2.10 should work (e.g. pip install jax==0.2.10).