edwhu closed this 1 year ago.
Yes, that's the reason.
Thanks. It seems straightforward to replace the `if` statements with `jax.lax.select` or `jax.lax.cond` logic and move the policy calls into the branches. Is there a particular reason why you chose to use Python `if` statements instead?
A cond has runtime overhead (and also wouldn't work with the current API because JAX/GPUs don't support string types: https://github.com/google/jax/issues/3045).
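To make the trade-off above concrete, here is a minimal sketch with two stand-in behaviors. `task_policy`, `expl_policy` and the `MODES` mapping are hypothetical and are not dreamerv3 code. A Python `if` on a string mode that is marked static via `jax.jit(..., static_argnames=...)` is resolved while tracing. `jax.lax.cond` and `jax.lax.select` need the mode encoded as a number, because strings cannot be traced.

```python
import functools

import jax
import jax.numpy as jnp

# Hypothetical stand-ins for two behaviors (not dreamerv3's actual policies).
def task_policy(x):
    return jnp.tanh(x)

def expl_policy(x):
    return -jnp.tanh(x)

# Option A: Python `if` on a static string. The branch is resolved at trace
# time, so only the selected behavior ends up in the compiled program.
# Each distinct `mode` value triggers its own compilation.
@functools.partial(jax.jit, static_argnames="mode")
def policy_static(x, mode):
    if mode == "train":
        return task_policy(x)
    elif mode == "explore":
        return expl_policy(x)
    raise ValueError(f"unknown mode: {mode}")

# Strings cannot be traced, so cond/select need the mode as an integer.
MODES = {"train": 0, "explore": 1}

# Option B: jax.lax.cond on a traced integer. Both branches are compiled into
# one program and the branch is chosen at runtime.
@jax.jit
def policy_cond(x, mode_id):
    return jax.lax.cond(mode_id == 0, task_policy, expl_policy, x)

# Option C: jax.lax.select evaluates both behaviors and then picks one result.
@jax.jit
def policy_select(x, mode_id):
    return jax.lax.select(mode_id == 0, task_policy(x), expl_policy(x))
```

For example, `policy_static(x, mode="explore")` and `policy_cond(x, MODES["explore"])` both return `-jnp.tanh(x)`. The static version keeps the string API but compiles once per mode; the `cond` version compiles once but needs an integer mode and decides at runtime, which is presumably where the overhead mentioned above comes from.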
I am interested in adding more behaviors.
In the `policy` function of `agent.py`, all behaviors are called before the mode is even checked:
https://github.com/danijar/dreamerv3/blob/423291a9875bb9af43b6db7150aaa972ba889266/dreamerv3/agent.py#L51-L60
This doesn't seem like it will scale if we add a lot of different behaviors. Is this because we want JAX to trace all variables in the train function, i.e. in the `JaxAgent`, where we get the initial variables of the train function for optimization?
https://github.com/danijar/dreamerv3/blob/423291a9875bb9af43b6db7150aaa972ba889266/dreamerv3/jaxagent.py#L228
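The reasoning in this question can be illustrated with a small sketch. It assumes that parameters are created lazily the first time a module asks for them. The `Params` class, the behavior functions and the parameter names are hypothetical and are not the API dreamerv3 uses. Under that assumption, an initialization pass that runs only one behavior never creates the other behaviors' parameters.

```python
import jax
import jax.numpy as jnp

# Hypothetical minimal lazy-parameter store (not dreamerv3's API): a parameter
# is created the first time some code requests it by name.
class Params:
    def __init__(self, key):
        self.values = {}
        self.key = key

    def get(self, name, shape):
        if name not in self.values:
            self.key, sub = jax.random.split(self.key)
            self.values[name] = jax.random.normal(sub, shape)
        return self.values[name]

# Two hypothetical behaviors, each owning its own weight matrix.
def task_behavior(params, x):
    return x @ params.get("task/w", (x.shape[-1], 2))

def expl_behavior(params, x):
    return x @ params.get("expl/w", (x.shape[-1], 2))

BEHAVIORS = {"train": task_behavior, "explore": expl_behavior}

def policy_eager(params, x, mode):
    # Call every behavior so that all parameters exist after a single pass,
    # then return the output for the requested mode.
    outs = {name: fn(params, x) for name, fn in BEHAVIORS.items()}
    return outs[mode]

def policy_lazy(params, x, mode):
    # Call only the selected behavior; the others' parameters are not created.
    return BEHAVIORS[mode](params, x)
```

Running `policy_eager` once with `mode="train"` leaves both `"task/w"` and `"expl/w"` in `params.values`, whereas `policy_lazy` with the same mode creates only `"task/w"`. A later `"explore"` call would then need parameters that were never part of the initial variables.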