danijar / dreamerv3

Mastering Diverse Domains through World Models
https://danijar.com/dreamerv3
MIT License

Question about `policy(...)` implementation. #63

Closed: edwhu closed this issue 1 year ago

edwhu commented 1 year ago

I am interested in adding more behaviors.

In the `policy` function of agent.py, every behavior's policy is called before the mode is even checked: https://github.com/danijar/dreamerv3/blob/423291a9875bb9af43b6db7150aaa972ba889266/dreamerv3/agent.py#L51-L60 This doesn't seem like it will scale if we add many different behaviors.

Is this because we want JAX to trace all variables used in the train function, i.e. in JaxAgent, where we get the initial variables of the train function for optimization? https://github.com/danijar/dreamerv3/blob/423291a9875bb9af43b6db7150aaa972ba889266/dreamerv3/jaxagent.py#L228

danijar commented 1 year ago

Yes, that's the reason.
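
For readers following along, here is a minimal sketch of why skipping a behavior during the initial trace could leave its variables uncreated. It is not the repository's code: `LazyLinear`, the shared `params` dict, and the `call_all` flag are hypothetical stand-ins, and it assumes that modules create their variables lazily on first call.

```python
import jax
import jax.numpy as jnp


class LazyLinear:
  """Hypothetical module that creates its weight the first time it is called."""

  def __init__(self, name, params):
    self.name = name
    self.params = params  # Shared dict standing in for the agent's variables.

  def __call__(self, x):
    if self.name not in self.params:
      key = jax.random.PRNGKey(len(self.params))
      self.params[self.name] = jax.random.normal(key, (x.shape[-1], 2))
    return x @ self.params[self.name]


def policy(behaviors, x, mode, call_all):
  if call_all:
    # Call every behavior first, then pick one by mode.
    outs = {name: behavior(x) for name, behavior in behaviors.items()}
    return outs[mode]
  # Only call the selected behavior; the others never create variables.
  return behaviors[mode](x)


def init(call_all):
  params = {}
  behaviors = {name: LazyLinear(name, params) for name in ('task', 'expl')}
  policy(behaviors, jnp.ones((1, 4)), 'task', call_all)
  return params


params_all = init(call_all=True)
params_lazy = init(call_all=False)
print(sorted(params_all))
print(sorted(params_lazy))
```

In this toy setup, the call-everything variant ends up with variables for both behaviors, while the variant that only calls the selected behavior is missing the `expl` weight after initialization.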

edwhu commented 1 year ago

Thanks. It seems straightforward to replace the `if` statements with `jax.lax.select` or `jax.lax.cond` logic and move the policy calls into the branches. Is there any particular reason why you chose to use Python `if` statements instead?

danijar commented 1 year ago

A `cond` has runtime overhead. It also wouldn't work with the current API, because the mode is passed as a string and JAX/GPUs don't support string types: https://github.com/google/jax/issues/3045
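
To illustrate the difference discussed in this thread, here is a rough sketch with two toy policies. The names `task_policy`, `expl_policy`, `policy_static`, and `policy_cond` are made up for this example and are not from the repository. With a Python `if` on a static string argument, the branch is resolved at trace time and each mode gets its own compiled function. With `jax.lax.cond`, both branches are part of the compiled program and the predicate has to be an array value rather than a string.

```python
import functools

import jax
import jax.numpy as jnp


def task_policy(x):
  return x * 2.0


def expl_policy(x):
  return x + 1.0


# The string mode is marked static, so the Python `if` runs while tracing
# and only the chosen branch ends up in the compiled function.
@functools.partial(jax.jit, static_argnames=('mode',))
def policy_static(x, mode):
  if mode == 'explore':
    return expl_policy(x)
  return task_policy(x)


# The predicate is a traced boolean scalar; both branches are compiled
# and one is selected at runtime.
@jax.jit
def policy_cond(x, explore):
  return jax.lax.cond(explore, expl_policy, task_policy, x)


x = jnp.arange(3.0)
out_static_explore = policy_static(x, mode='explore')
out_static_train = policy_static(x, mode='train')
out_cond_explore = policy_cond(x, jnp.asarray(True))
out_cond_train = policy_cond(x, jnp.asarray(False))
```

Both variants return the same values here. The `cond` version would need the string mode replaced by a boolean or integer flag (or `jax.lax.switch` with an integer index for more than two behaviors), which is an API change compared to passing `mode` as a string.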