edwhu opened this issue 3 months ago
I unfortunately will not be able to take a deep look at this today, but I created a cross-q feature branch. At a glance, have you tried passing the training flag as a static argument in `jax.jit` (using `static_argnames` or `static_argnums`)? It's definitely not happy about the branching if statement in the `BatchRenorm` module.
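Something along these lines is what I have in mind; this is a rough sketch with made-up names (`BNQNetwork`, `q_apply`), not the code in the branch:

```python
from functools import partial

import jax
import flax.linen as nn

# Hypothetical batch-normalized critic, just to illustrate the idea.
class BNQNetwork(nn.Module):
    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(256)(x)
        # `not training` is Python control flow, so `training` must be a
        # concrete bool at trace time, i.e. static under jit.
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        return nn.Dense(1)(x)

@partial(jax.jit, static_argnames=('training',))
def q_apply(variables, x, training: bool):
    if training:
        # Train mode mutates the running statistics, so they come back
        # as an extra output alongside the Q values.
        return BNQNetwork().apply(variables, x, training,
                                  mutable=['batch_stats'])
    return BNQNetwork().apply(variables, x, training)
```

Marking the flag static means jit compiles one trace per value of `training` instead of trying to trace the bool itself.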
Apologies for the super late response; I'm just finishing my month-long general exam. Have there been any updates on this?
Hi, I've also been busy with other things. Batch normalization in JAX seems non-trivial to do correctly, so I think I may just give up on this for now. Still, I think CrossQ might be an interesting route to explore for further improvements, since it only requires 2 Q networks.
I'm interested in implementing the CrossQ critic update, which only requires two Q networks and no target networks. This could speed up TDMPC2 a decent amount. A key part of the method is using BatchNorm correctly.
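For context, the part of CrossQ that removes the target networks is running the current and next state-action pairs through the critic as one joint batch, so the BatchNorm statistics cover both distributions. Roughly like this sketch, where `critic` stands in for whatever batch-normalized Q module we end up with and `a_next` is sampled from the current policy:

```python
import jax
import jax.numpy as jnp

def crossq_critic_loss(critic, variables, s, a, s_next, a_next, r, gamma):
    # One joint forward pass: BatchNorm statistics are computed over both
    # the current and next state-action batches, which is what lets
    # CrossQ drop the target networks entirely.
    sa = jnp.concatenate([s, a], axis=-1)
    sa_next = jnp.concatenate([s_next, a_next], axis=-1)
    q_all, updates = critic.apply(
        variables, jnp.concatenate([sa, sa_next], axis=0),
        True, mutable=['batch_stats'])
    q, q_next = jnp.split(q_all, 2, axis=0)
    # Bootstrap target still uses stop_gradient; r[:, None] broadcasts
    # a (B,) reward against (B, 1) Q values.
    td_target = jax.lax.stop_gradient(r[:, None] + gamma * q_next)
    return jnp.mean((q - td_target) ** 2), updates
```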
I'm rather new to Flax, and implementing BatchNorm is a bit annoying. We have to carry the batch statistics as part of the training state. I made some preliminary progress towards this, but got a bit stuck with the batch normalization details.
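Concretely, here's the standard Flax pattern I've been trying to follow, with the batch statistics carried on a custom train state (the `BNTrainState` name and the toy regression loss are mine):

```python
from typing import Any

import jax
from flax.training import train_state

class BNTrainState(train_state.TrainState):
    # Running BatchNorm statistics; not trainable, so kept outside params.
    batch_stats: Any

def train_step(state, batch):
    def loss_fn(params):
        preds, updates = state.apply_fn(
            {'params': params, 'batch_stats': state.batch_stats},
            batch['x'], True, mutable=['batch_stats'])
        loss = ((preds - batch['y']) ** 2).mean()
        return loss, updates

    (loss, updates), grads = jax.value_and_grad(
        loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)
    # The running statistics are carried forward separately from the
    # gradient update.
    return state.replace(batch_stats=updates['batch_stats']), loss
```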
A few requests and questions:

1) Could you make a cross-q branch in the official repo so I can open a PR? We can continue the discussion from there. I don't think CrossQ ever needs to be merged into main, but having it as a variant could be useful to others.
2) Onto the actual bug: first, I made a batch-normalized version of the Q function in `mlp.py` that can take a `training` boolean to specify train or evaluation mode. However, it seems like once that Q function is initialized and traced, passing in the boolean later throws an error.

So currently, just to get the code to run, I don't pass in the `training` boolean, so it always runs with `training=True` by default. But in some cases I think I want `training=False`, like when I'm using the Q function for planning. Either way, attempting to pass `training=...` as an argument throws the error, which suggests to me that the JAX tracing isn't recognizing `training` as a dynamic argument for some reason (a toy reproduction of the failure is below). You can see the hack / hotfix here: https://github.com/edwhu/tdmpc2-jax/blob/495d09657b64b1d82298eb184a555921bd9e383e/tdmpc2_jax/world_model.py#L291

Would love to hear your thoughts on this; I've been stuck on it for a couple of hours over the past few days. Thanks!
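For reference, here's the toy reproduction of what I think is happening (illustration only, not the actual tdmpc2-jax module):

```python
from functools import partial

import jax

@jax.jit
def f(x, training):
    if training:   # `training` is a tracer here -> trace-time error
        return 2 * x
    return x

# f(1.0, True)  # fails: jit can't branch Python control flow on a tracer

@partial(jax.jit, static_argnames=('training',))
def g(x, training):
    if training:   # fine: `training` is a concrete bool at trace time
        return 2 * x
    return x

# g(1.0, True); g(1.0, False)  # each flag value compiles its own trace
```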