sai-prasanna closed this issue 4 months ago
Embodied (embodied.Agent) requires agents to receive and return Numpy arrays. The Dreamer agent dreamerv3.Agent receives and returns JAX arrays, so we wrap it into JAXAgent that takes care of all JAX-related logic and exposes a Numpy API to the outside.
Hi, I am confused about why JAXAgent inherits from the embodied.Agent class. Isn't embodied.Agent actually wrapped by JAXAgent?
https://github.com/danijar/dreamerv3/blob/8fa35f83eee1ce7e10f3dee0b766587d0a713a60/dreamerv3/jaxagent.py#L24
And is there a reason for using a decorator? The function signatures change, new methods appear on the JAXAgent class, and JAXAgent itself creates the inner agent object via "self.agent = agent_cls()". Would it be cleaner to construct the JAXAgent explicitly? Or is the decorator meant to work with other "Agent" classes as well?
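To make the question concrete, here is a minimal sketch of the two alternatives as I understand them. All names (BaseAgent, InnerAgent, OuterAgent, wrapper) are hypothetical stand-ins, not the actual dreamerv3 or embodied code, and plain lists and tuples stand in for the JAX/Numpy conversion.

```python
class BaseAgent:
    """Hypothetical stand-in for the interface the training loop expects."""

    def policy(self, obs):
        raise NotImplementedError


class InnerAgent:
    """Hypothetical stand-in for the inner agent (works on lists here)."""

    def __init__(self, scale):
        self.scale = scale

    def policy(self, obs):
        return [self.scale * x for x in obs]


class OuterAgent(BaseAgent):
    """Inherits the outer interface and holds (wraps) an inner agent."""

    def __init__(self, agent_cls, *args, **kwargs):
        # The wrapper builds the inner agent itself.
        self.agent = agent_cls(*args, **kwargs)

    def policy(self, obs):
        # Convert on the way in (tuple -> list) and on the way out
        # (list -> tuple); a real wrapper would convert array types here.
        return tuple(self.agent.policy(list(obs)))


def wrapper(agent_cls):
    """Class decorator: replaces agent_cls with an OuterAgent subclass."""

    class Agent(OuterAgent):
        inner = agent_cls

        def __init__(self, *args, **kwargs):
            super().__init__(agent_cls, *args, **kwargs)

    return Agent


# Option 1: decorator. The decorated name now refers to the wrapper class.
@wrapper
class DecoratedAgent(InnerAgent):
    pass


# Option 2: explicit construction, passing the inner class by hand.
explicit = OuterAgent(InnerAgent, scale=2)
decorated = DecoratedAgent(scale=2)
```

In this sketch both options behave identically at call time. The difference is only that the decorator lets callers construct the agent with the inner agent's arguments, without knowing that a wrapper exists. Inheriting from the base interface is what lets the wrapper be passed wherever that interface is expected, even though the inner agent is held as an attribute rather than subclassed.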