EdanToledo / Stoix

🏛️A research-friendly codebase for fast experimentation of single-agent reinforcement learning in JAX • End-to-End JAX RL
Apache License 2.0
197 stars, 16 forks

[FEATURE] integrating envpool #50

Closed: renos closed this issue 5 months ago

renos commented 5 months ago

Please describe the purpose of the feature. Is it related to a problem?


The current EnvPool integration in CleanRL is buggy and doesn't use pmap, so integrating EnvPool with Stoix would be useful.

Describe the solution you'd like


Implement EnvPool support in Stoix.


EdanToledo commented 5 months ago

Hey, I was planning on doing this at some point. I know EnvPool can be jitted, but I'm not fully aware of its pmap capabilities. I will get around to this soon.

EdanToledo commented 5 months ago

Hello, I was playing around with EnvPool, and unfortunately it seems to have some limitations with regard to XLA compatibility. For example, although it can be jitted, it cannot be vmapped because its XLA primitives don't define batching rules. I imagine that consequently means there are also issues with pmapping, at least in the way Stoix currently does it.

There are solutions to this. For example, you could run the environment loop separately, collect the data, reshape it into the correct per-device shapes, and then pmap only the SGD steps. However, that partly defeats the purpose and general design paradigm of Stoix, where both environment data collection and SGD steps happen on all devices. If you give it a go, I'd be more than happy to review it and then adapt it for Stoix's use case.
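The workaround described above (collect data outside JAX, reshape it, pmap only the SGD step) might look roughly like the sketch below. This is an illustrative assumption, not Stoix or EnvPool code: `host_rollout` is a hypothetical NumPy stand-in for a CPU-side EnvPool loop, and the linear model, squared-error loss, and learning rate are placeholders chosen only to show the data flow. The key step is `shard`, which gives the host-collected batch a leading device axis so `jax.pmap` can split it across devices, with gradients averaged via `lax.pmean`.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

NUM_DEVICES = jax.local_device_count()
OBS_DIM = 4
BATCH_PER_DEVICE = 8


def host_rollout(rng: np.random.Generator, num_steps: int):
    """Hypothetical stand-in for a CPU-side EnvPool loop (runs outside JAX)."""
    obs = rng.normal(size=(num_steps, OBS_DIM)).astype(np.float32)
    targets = obs.sum(axis=-1).astype(np.float32)  # toy regression target
    return obs, targets


def shard(x: np.ndarray) -> np.ndarray:
    """Reshape a flat host batch to (num_devices, per_device, ...) for pmap."""
    return x.reshape((NUM_DEVICES, -1) + x.shape[1:])


def loss_fn(params, obs, targets):
    preds = obs @ params["w"] + params["b"]
    return jnp.mean((preds - targets) ** 2)


@functools.partial(jax.pmap, axis_name="devices")
def sgd_step(params, obs, targets):
    """Only the update is pmapped; the data arrives already collected."""
    loss, grads = jax.value_and_grad(loss_fn)(params, obs, targets)
    # Average across devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="devices")
    loss = jax.lax.pmean(loss, axis_name="devices")
    params = jax.tree_util.tree_map(lambda p, g: p - 0.05 * g, params, grads)
    return params, loss


# Replicate parameters along a leading device axis.
params = {"w": jnp.zeros(OBS_DIM), "b": jnp.zeros(())}
params = jax.tree_util.tree_map(
    lambda x: jnp.broadcast_to(x, (NUM_DEVICES,) + x.shape), params
)

rng = np.random.default_rng(0)
for _ in range(200):
    obs, targets = host_rollout(rng, NUM_DEVICES * BATCH_PER_DEVICE)
    params, loss = sgd_step(params, shard(obs), shard(targets))
```

Each iteration still moves a fresh batch from host to device, which is the overhead that Stoix's all-on-device design avoids.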

renos commented 5 months ago

@EdanToledo I suppose that makes sense. My use case would be training Atari models, but I think end-to-end JAX RL should be enough, considering the limiting factor seems to be CPU-GPU communication anyway, since EnvPool still runs the environment itself on the CPU.
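For contrast, the end-to-end pattern keeps the environment itself in JAX, so the rollout can be vmapped over environments, scanned over time, and pmapped over devices without observations ever crossing the CPU-GPU boundary. The sketch below is a minimal illustration under stated assumptions: `env_reset` and `env_step` define a hypothetical toy environment (a 1-D point pushed left or right), not a real Stoix, Gymnax, or Jumanji environment, and a random policy stands in for a learned one. In a full agent, the SGD step would be placed inside the same pmapped function.

```python
import jax
import jax.numpy as jnp

NUM_ENVS = 16
ROLLOUT_LEN = 32


def env_reset(key):
    """Hypothetical toy env: the state is a scalar position in [-1, 1]."""
    return jax.random.uniform(key, (), minval=-1.0, maxval=1.0)


def env_step(pos, action):
    """Action 1 moves right, action 0 moves left; reward is -|position|."""
    pos = jnp.clip(pos + jnp.where(action == 1, 0.1, -0.1), -1.0, 1.0)
    return pos, -jnp.abs(pos)


def rollout(key):
    """Collect a (ROLLOUT_LEN, NUM_ENVS) trajectory entirely on-device."""
    reset_key, act_key = jax.random.split(key)
    positions = jax.vmap(env_reset)(jax.random.split(reset_key, NUM_ENVS))

    def step_fn(positions, step_key):
        # Random policy as a placeholder for a learned actor.
        actions = jax.random.bernoulli(step_key, 0.5, (NUM_ENVS,)).astype(jnp.int32)
        next_positions, rewards = jax.vmap(env_step)(positions, actions)
        return next_positions, (next_positions, actions, rewards)

    step_keys = jax.random.split(act_key, ROLLOUT_LEN)
    _, traj = jax.lax.scan(step_fn, positions, step_keys)
    return traj


# One rollout per device; every device simulates its own batch of envs.
device_keys = jax.random.split(jax.random.PRNGKey(0), jax.local_device_count())
next_obs, actions, rewards = jax.pmap(rollout)(device_keys)
```

Because the environment is pure JAX, `jax.vmap` and `jax.pmap` compose with it directly, which is exactly what EnvPool's XLA primitives currently lack.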