danijar / dreamerv3

Mastering Diverse Domains through World Models
https://danijar.com/dreamerv3
MIT License

Class hierarchy JaxAgent and Agent #104

Closed: sai-prasanna closed this issue 4 months ago

sai-prasanna commented 8 months ago

Hi, I am confused about why JaxAgent inherits from the embodied.Agent class. Isn't embodied.Agent actually the class being wrapped by JaxAgent?

https://github.com/danijar/dreamerv3/blob/8fa35f83eee1ce7e10f3dee0b766587d0a713a60/dreamerv3/jaxagent.py#L24

Also, is there a reason you used a decorator? The function signatures change and new methods appear in the JaxAgent class, and JaxAgent itself creates the agent object ("self.agent = agent_cls()"). Would it be cleaner to simply construct the JaxAgent explicitly? Or is it meant to work with other "Agent" classes as well?
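A minimal sketch of the decorator pattern the question describes, in plain Python. All names here (`EmbodiedAgent`, `JaxAgentBase`, `wrapper`) are hypothetical stand-ins, not the dreamerv3 source. It illustrates why the decorated name ends up subclassing the embodied interface: the decorator replaces the inner class with a wrapper subclass whose constructor builds the inner agent itself.

```python
# Hypothetical, simplified names; not the actual dreamerv3 code.

class EmbodiedAgent:
    """Stand-in for the embodied.Agent interface (assumed: NumPy in, NumPy out)."""
    def policy(self, obs, state=None):
        raise NotImplementedError


class JaxAgentBase(EmbodiedAgent):
    """Wrapper that owns an inner agent instance (cf. "self.agent = agent_cls()")."""
    def __init__(self, agent_cls, *args, **kwargs):
        self.agent = agent_cls(*args, **kwargs)

    def policy(self, obs, state=None):
        # A real wrapper would convert arrays and call jitted functions here.
        return self.agent.policy(obs, state)


def wrapper(agent_cls):
    """Class decorator: returns a JaxAgentBase subclass that constructs agent_cls."""
    class Wrapped(JaxAgentBase):
        inner = agent_cls  # keep a handle on the undecorated class

        def __init__(self, *args, **kwargs):
            super().__init__(agent_cls, *args, **kwargs)

    # Keep the original name so the decorated class still reads as "Agent".
    Wrapped.__name__ = agent_cls.__name__
    Wrapped.__qualname__ = agent_cls.__qualname__
    return Wrapped


@wrapper
class Agent:
    def __init__(self, scale):
        self.scale = scale

    def policy(self, obs, state=None):
        return {"action": obs["x"] * self.scale}, state
```

With this pattern, `Agent(scale=2)` returns the wrapper, so callers only see the embodied interface while the inner instance lives at `agent.agent`. Calling `JaxAgentBase(InnerAgent, ...)` explicitly would be an equivalent alternative; the decorator mainly keeps call sites unchanged.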

danijar commented 4 months ago

Embodied (embodied.Agent) requires agents to receive and return NumPy arrays. The Dreamer agent, dreamerv3.Agent, receives and returns JAX arrays, so we wrap it in JaxAgent, which handles all JAX-related logic and exposes a NumPy API to the outside.
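A minimal sketch of the NumPy/JAX boundary described in the answer, assuming a pure policy function compiled with `jax.jit`. `NumpyFacade` and `_policy_fn` are hypothetical names, not the dreamerv3 API. Inputs are converted to JAX arrays on entry, and results are converted back to NumPy arrays before they leave the wrapper.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def _policy_fn(params, obs):
    # Hypothetical inner policy: operates purely on JAX arrays.
    return {"action": jnp.tanh(obs["x"] * params["scale"])}


class NumpyFacade:
    """Hypothetical wrapper playing the role the answer ascribes to JaxAgent."""

    def __init__(self, params):
        # Keep parameters as JAX arrays on the device.
        self.params = jax.tree_util.tree_map(jnp.asarray, params)

    def policy(self, obs):
        # NumPy -> JAX at the boundary.
        obs = jax.tree_util.tree_map(jnp.asarray, obs)
        out = _policy_fn(self.params, obs)
        # JAX -> NumPy before returning to the embodied side.
        return jax.tree_util.tree_map(np.asarray, out)
```

Example: `NumpyFacade({"scale": np.float32(0.5)}).policy({"x": np.array([0.0, 2.0], np.float32)})` returns a dict whose `"action"` entry is a plain `np.ndarray`, so code on the embodied side never handles JAX arrays directly.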