ShaneFlandermeyer / tdmpc2-jax

Jax/Flax Implementation of TD-MPC2

Help with CrossQ implementation #8

Open edwhu opened 3 months ago

edwhu commented 3 months ago

I'm interested in implementing the CrossQ critic update, which only requires two Q networks, and no target networks. This could speed up TDMPC2 a decent amount. A key part of the method is using BatchNorm correctly.

I'm rather new to Flax, and implementing BatchNorm is a bit annoying. We have to carry the batch statistics as part of the training state. I made some preliminary progress towards this, but got a bit stuck with the batch normalization details.

A few requests and questions:

1) Could you make a cross-q branch in the official repo so I can make a PR? We can continue the discussion from there. I don't think CrossQ ever needs to be merged into main, but having it as a variant could be useful to others.

2) Onto the actual bug: First, I made a batch normalized version of the Q function in mlp.py that can take in a training boolean to specify train or evaluation mode.

However, it seems like once that Q function is initialized and traced, when I later try to pass in the boolean, I get an error.

  @jax.jit
  def Q(self, z: jax.Array, a: jax.Array, params: Dict, key: PRNGKeyArray, train: bool
        ) -> Tuple[jax.Array, jax.Array, Dict]:
    z = jnp.concatenate([z, a], axis=-1)
    # TODO: figure out why including train as an argument breaks things.
    # logits, updates = self.value_model.apply_fn(
    #     {'params': params, 'batch_stats': self.value_model.batch_stats}, z, rngs={'dropout': key}, mutable=['batch_stats'])
    logits, updates = self.value_model.apply_fn(
        {'params': params, 'batch_stats': self.value_model.batch_stats},
        z, train, rngs={'dropout': key}, mutable=['batch_stats'])

    Q = two_hot_inv(logits, self.symlog_min, self.symlog_max, self.num_bins)
    return Q, logits, updates

So I get this error:

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/edward/projects/tdmpc2-jax/tdmpc2_jax/train.py", line 217, in train
    agent, train_info = agent.update(
  File "/home/edward/projects/tdmpc2-jax/tdmpc2_jax/tdmpc2.py", line 318, in update
    (encoder_grads, dynamics_grads, value_grads, reward_grads, continue_grads), model_info = jax.grad(
  File "/home/edward/projects/tdmpc2-jax/tdmpc2_jax/tdmpc2.py", line 264, in world_model_loss_fn
    all_q, all_q_logits, updates = self.model.Q(jnp.concat([zs[:-1], next_z]), jnp.concat([actions, next_action]), value_params, key=Q_key, train=True)
  File "/home/edward/projects/tdmpc2-jax/tdmpc2_jax/world_model.py", line 291, in Q
    logits, updates = self.value_model.apply_fn(
  File "/home/edward/projects/tdmpc2-jax/tdmpc2_jax/networks/ensemble.py", line 19, in __call__
    return ensemble()(*args, **kwargs)
  File "/home/edward/miniconda3/envs/tdmpc2jax/lib/python3.10/site-packages/flax/linen/combinators.py", line 105, in __call__
    outputs = self.layers[0](*args, **kwargs)
  File "/home/edward/projects/tdmpc2-jax/tdmpc2_jax/networks/mlp.py", line 215, in __call__
    x = self.norm(dtype=self.dtype, use_running_average = not train)(x)
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function Q at /home/edward/projects/tdmpc2-jax/tdmpc2_jax/world_model.py:286 for jit. This concrete value was not available in Python because it depends on the value of the argument train.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

So currently, just to get the code to run, I don't pass in the training boolean, so it always runs with training=True by default. But in some cases I think I want training=False, e.g. when using the Q function for planning. Either way, passing training=... as an argument throws the error, which suggests that Jax tracing is treating train as a traced (dynamic) value rather than a static Python bool. You can see the hack / hotfix here: https://github.com/edwhu/tdmpc2-jax/blob/495d09657b64b1d82298eb184a555921bd9e383e/tdmpc2_jax/world_model.py#L291

Would love to hear your thoughts on this, I've been stuck on this for a couple hours over the past few days. Thanks!

ShaneFlandermeyer commented 3 months ago

I unfortunately will not be able to take a deep look at this today, but I created a cross-q feature branch. At a glance, have you tried passing the training flag as a static argument in jax.jit (using static_argnames or static_argnums)? It's definitely not happy about the branching if statement in the BatchRenorm module.
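For reference, the suggestion above would look roughly like this minimal standalone example (hypothetical function `scale`, not the repo's code). Marking `train` static means the Python `if` (or `not train` inside BatchNorm) is resolved at trace time, and jit simply retraces once per distinct value of `train`:

```python
from functools import partial
import jax
import jax.numpy as jnp

# Without static_argnames, branching on `train` inside jit raises
# jax.errors.TracerBoolConversionError; with it, the branch is
# resolved at trace time.
@partial(jax.jit, static_argnames=('train',))
def scale(x, train: bool):
  if train:
    return x * 2.0
  return x

scale(jnp.ones(3), train=True)   # traced with train=True
scale(jnp.ones(3), train=False)  # retraced with train=False
```

Note that static arguments must be passed by keyword when using `static_argnames`, and each new value triggers a recompile, which is fine for a two-valued flag like `train`.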

ShaneFlandermeyer commented 2 months ago

Apologies for the super late response; I'm just finishing my month-long general exam. Have there been any updates on this?

edwhu commented 2 months ago

Hi, I've also been busy with other things. Batch normalization in Jax seems non-trivial to get right, so I think I'll give up on this for now. Still, I think CrossQ might be an interesting route to explore for further improvements, since it only requires two Q networks.