RobertTLange / gymnax

RL Environments in JAX 🌍
Apache License 2.0

[Feature] GymnaxToBrax conversion wrapper #35

Closed DavidSlayback closed 2 years ago

DavidSlayback commented 2 years ago

For after #34

Took a stab at implementing the Brax conversion wrapper and added tests to make sure it's working. A couple of notes:

  1. Uses a "lookalike" Brax state. The only difference is that state.qp is no longer a QP object; it holds whatever the EnvState for a particular environment is.
  2. Uses Brax's state.info field to maintain and update the rng key and env_params, so that the step and reset APIs can match Brax.
  3. Since action_size and observation_size are properties in Brax, they have to use the default environment parameters.
  4. Compatible with all Brax wrappers like VmapWrapper, EvalWrapper, and EpisodeWrapper. Using AutoResetWrapper is not recommended, though.
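
To make notes 1 to 3 concrete, here is a minimal sketch of the pattern in plain Python. It is not the code from this PR and it does not import gymnax or Brax: `ToyEnv`, `ToyParams`, `ToyEnvState`, `LookalikeState`, `ToBraxSketch`, the `split` helper, and the `"_rng"` / `"_env_params"` info keys are all hypothetical stand-ins. The only things taken from the thread are the ideas that `qp` carries the wrapped environment state, that `info` carries the rng key and env_params, and that the size properties fall back to default parameters.

```python
from dataclasses import dataclass, field, replace
from typing import Any, Dict, Tuple


@dataclass(frozen=True)
class ToyParams:
    """Hypothetical stand-in for an environment's parameters."""
    max_steps: int = 3


@dataclass(frozen=True)
class ToyEnvState:
    """Hypothetical stand-in for an environment-specific EnvState."""
    position: int = 0
    time: int = 0


def split(key: int) -> Tuple[int, int]:
    """Toy stand-in for splitting a PRNG key into two new keys."""
    return (key * 2 + 1) % 2**31, (key * 2 + 2) % 2**31


class ToyEnv:
    """Toy environment with a functional reset/step API that takes key and params."""
    default_params = ToyParams()

    def num_actions(self, params: ToyParams) -> int:
        return 2

    def reset(self, key: int, params: ToyParams):
        state = ToyEnvState()
        return state.position, state

    def step(self, key: int, state: ToyEnvState, action: int, params: ToyParams):
        new_state = ToyEnvState(state.position + action, state.time + 1)
        done = new_state.time >= params.max_steps
        return new_state.position, new_state, float(action), done, {}


@dataclass(frozen=True)
class LookalikeState:
    """Brax-like state: `qp` holds the wrapped EnvState instead of a QP object."""
    qp: Any
    obs: Any
    reward: float
    done: bool
    info: Dict[str, Any] = field(default_factory=dict)


class ToBraxSketch:
    """Exposes reset(rng) -> state and step(state, action) -> state."""

    def __init__(self, env: ToyEnv):
        self.env = env

    @property
    def action_size(self) -> int:
        # A property takes no arguments, so it can only consult default params.
        return self.env.num_actions(self.env.default_params)

    def reset(self, rng: int) -> LookalikeState:
        params = self.env.default_params
        rng, reset_key = split(rng)
        obs, env_state = self.env.reset(reset_key, params)
        # The rng key and params travel inside `info` from here on.
        info = {"_rng": rng, "_env_params": params}
        return LookalikeState(env_state, obs, 0.0, False, info)

    def step(self, state: LookalikeState, action: int) -> LookalikeState:
        rng, step_key = split(state.info["_rng"])
        params = state.info["_env_params"]
        obs, env_state, reward, done, _ = self.env.step(
            step_key, state.qp, action, params
        )
        info = {**state.info, "_rng": rng}  # store the advanced key
        return replace(state, qp=env_state, obs=obs, reward=reward, done=done, info=info)


env = ToBraxSketch(ToyEnv())
state = env.reset(0)
for _ in range(3):
    state = env.step(state, 1)
```

Because every call returns a fresh state and nothing is stored on the wrapper between steps, this shape is what lets generic wrappers that only touch the state object sit on top of it.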

RobertTLange commented 2 years ago

Hi @DavidSlayback, thank you for all the hard work. Can you make sure that the test suite is passing before I merge the PRs? You will need to add `from typing import Dict as DictType` and then use `DictType` for type checking.
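
For readers wondering what that looks like in practice, here is a small sketch. The `Dict` class below is a hypothetical stand-in; the assumption is that the alias is wanted because the codebase already has a class of its own named `Dict`, which the thread does not state explicitly. What is certain is that subscripting the built-in `dict` (as in `dict[str, int]`) only works from Python 3.9 on, whereas `typing.Dict` works on older versions.

```python
from typing import Dict as DictType


class Dict:
    """Hypothetical class that happens to be named `Dict`, hence the alias above."""

    def __init__(self, spaces: DictType[str, int]):
        # `DictType[str, int]` is fine on older Pythons; `dict[str, int]` is not.
        self.spaces = spaces

    def sizes(self) -> DictType[str, int]:
        # Return a copy so callers cannot mutate the internal mapping.
        return dict(self.spaces)


space = Dict({"position": 2, "velocity": 2})
sizes = space.sizes()
```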

DavidSlayback commented 2 years ago

I updated both PRs to make them compatible with Python 3.7 and tested everything on my end! I notice that the Seaquest tests still fail and are ignored by the pre-commit workflow. Would it be useful for me to try to fix this?

RobertTLange commented 2 years ago

Heyo, is it possible to make Brax an optional import? It comes with quite a few dependencies and I want to keep the footprint of gymnax low. Maybe using a try/except import? Otherwise you need to add it to the dependencies.

try:
    from brax import envs
except ImportError:
    raise ImportError(
        "You need to install `brax` to use the brax wrapper."
    )

DavidSlayback commented 2 years ago

Done for both the wrapper itself and the test (which auto-skips if Brax is not available).
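
One way such an auto-skip can be written, using only the standard library, is sketched below. This is not necessarily how the PR's test does it; the `HAS_BRAX` flag and the `TestBraxWrapper` class are hypothetical names for illustration. The check looks for the package without importing it, and the real import is deferred into the test body so that collecting the test file never fails when Brax is missing.

```python
import importlib.util
import unittest

# True only if a top-level package named "brax" can be located; nothing is imported yet.
HAS_BRAX = importlib.util.find_spec("brax") is not None


@unittest.skipUnless(HAS_BRAX, "brax is not installed")
class TestBraxWrapper(unittest.TestCase):
    def test_brax_envs_importable(self):
        # Deferred import: only reached when the skip condition passed.
        from brax import envs

        self.assertIsNotNone(envs)


# Run the suite programmatically; a skipped test still counts as a successful run.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestBraxWrapper)
result = unittest.TestResult()
suite.run(result)
```

With pytest, `pytest.importorskip("brax")` at the top of the test module achieves the same effect in one line.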

RobertTLange commented 2 years ago

Codecov Report

Merging #35 (b18f699) into main (6682836) will decrease coverage by 0.60%. The diff coverage is 68.24%.

@@            Coverage Diff             @@
##             main      #35      +/-   ##
==========================================
- Coverage   84.73%   84.13%   -0.61%     
==========================================
  Files          44       46       +2     
  Lines        2640     2779     +139     
==========================================
+ Hits         2237     2338     +101     
- Misses        403      441      +38     

| Impacted Files | Coverage Δ | |
|---|---|---|
| gymnax/environments/conversions/brax.py | 0.00% <0.00%> (ø) | |
| gymnax/environments/environment.py | 77.35% <ø> (ø) | |
| gymnax/environments/minatar/seaquest.py | 33.33% <ø> (ø) | |
| gymnax/registration.py | 94.11% <ø> (ø) | |
| gymnax/environments/spaces.py | 75.71% <75.00%> (-4.68%) | :arrow_down: |
| gymnax/environments/conversions/gym.py | 96.38% <96.38%> (ø) | |
| gymnax/environments/misc/rooms.py | 94.33% <0.00%> (+8.49%) | :arrow_up: |

DavidSlayback commented 2 years ago

Just wanted to check, is there something more you need me to do for this? I'm not really clear on the coverage issues; I suppose I could add some tests for Gymnax -> Gym spaces?

RobertTLange commented 2 years ago

It is merged. Thank you! And sorry for the slow reply.