RobertTLange / gymnax

RL Environments in JAX 🌍
Apache License 2.0

Notebook links missing? #23

Closed kenjyoung closed 1 year ago

kenjyoung commented 2 years ago

Thanks so much for releasing this repo, it looks great!

The top two links in the examples section of the README give 404 errors for me, e.g.:

The 'RobertTLange/gymnax' repository doesn't contain the 'notebooks/getting_started.ipynb' path in 'main'.

RobertTLange commented 2 years ago

Thank you for the kind words and for raising this. The notebook links should be addressed by PR #25. Your work has obviously inspired parts of this package. I recently saw that you also have an Asterix implementation in your MCTS demo, which looks great.

On a different note: I am still struggling with translating Seaquest. The logic is more involved, nested if-else statements are hard to work around, and there can be many new enemies/objects, which the NumPy MinAtar implementation stores in unbounded lists. This doesn't play well with JAX's requirement for static shapes. For Asterix I worked around this by using a maximum number of possible entities together with a counter and fixing the shapes accordingly. For Seaquest this appears less elegant :) Let me know if you have a different idea for how to work around this, or generally if you want to have a chat. Cheers, Rob
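A minimal sketch of the fixed-capacity idea, assuming a hypothetical `MAX_ENTITIES` bound and an `active` mask in place of the counter (the count is just the sum of the mask). The names `init_buffer`, `spawn`, and `step` are made up for illustration and are not the gymnax or MinAtar code:

```python
import jax
import jax.numpy as jnp

MAX_ENTITIES = 8  # hypothetical upper bound on simultaneously live entities


def init_buffer():
    # Fixed-shape storage: one (x, y) row per slot plus a mask of live slots.
    return {
        "pos": jnp.zeros((MAX_ENTITIES, 2), dtype=jnp.int32),
        "active": jnp.zeros((MAX_ENTITIES,), dtype=jnp.bool_),
    }


@jax.jit
def spawn(buf, xy):
    # Write the new entity into the first free slot.
    # argmax returns 0 when nothing is free, so the write is guarded by has_free.
    free = ~buf["active"]
    slot = jnp.argmax(free)
    has_free = jnp.any(free)
    pos = jnp.where(has_free, buf["pos"].at[slot].set(xy), buf["pos"])
    active = jnp.where(has_free, buf["active"].at[slot].set(True), buf["active"])
    return {"pos": pos, "active": active}


@jax.jit
def step(buf, dx, width):
    # Move only live entities, then free the slots of those that left the screen.
    x = buf["pos"][:, 0] + jnp.where(buf["active"], dx, 0)
    on_screen = (x >= 0) & (x < width)
    pos = buf["pos"].at[:, 0].set(x)
    return {"pos": pos, "active": buf["active"] & on_screen}
```

Array shapes never change, so both functions compile once under `jax.jit`; branching on "is there a free slot" is expressed with `jnp.where` rather than a Python `if`.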

kenjyoung commented 2 years ago

I can definitely see how the Seaquest implementation would be uniquely annoying to port to JAX. I can't think of a better idea than bounding the number of each entity type. This may be a little inelegant, but the number of possible entities in the original implementation can itself be bounded, so imposing a sufficiently high bound shouldn't impact the game dynamics.

Everything moves across the screen, and enemies spawn at a deterministic rate, so one can bound the total number of enemies on screen by the time it takes each one to cross multiplied by the number of new enemies that spawn in that time (similarly for divers and bullets). For enemies, I believe you can bound the number on screen by `10*move_speed/e_spawn_speed`. This changes as difficulty ramps, but a quick calculation says it is no more than 20 when difficulty is highest, which isn't too bad.
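The bound can be written out as a small helper. This is only a sketch: it assumes `move_speed` is the number of frames an enemy needs to advance one cell, `e_spawn_speed` is the number of frames between spawns, and the grid is 10 cells wide. The function name and the sample values are hypothetical and are not taken from MinAtar:

```python
GRID_WIDTH = 10  # assumed number of cells an enemy has to cross


def max_enemies_on_screen(move_speed: int, e_spawn_speed: int) -> int:
    # Frames one enemy spends on screen while crossing the grid.
    frames_to_cross = GRID_WIDTH * move_speed
    # Number of spawns that fit into that window, rounded up (integer ceiling).
    return -(-frames_to_cross // e_spawn_speed)


# Hypothetical values, for illustration only.
print(max_enemies_on_screen(move_speed=5, e_spawn_speed=20))
```

Evaluating this at the hardest difficulty setting would give the static capacity to allocate for the enemy array.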

Aside from that, the game logic is probably annoying but not impossible to implement in JAX. In hindsight, I probably could have made the logic much simpler without changing the game much. I can try to have a more detailed look at it sometime in the next couple of weeks :)