RobertTLange / gymnax

RL Environments in JAX 🌍
Apache License 2.0

Potential bug in Tuple space #53

Closed: joefarrington closed this issue 1 year ago

joefarrington commented 1 year ago

Hi Rob,

Great library, thanks for all the hard work. I have been using some custom Gymnax environments in recent work (e.g.

There seem to be issues with the sample() and contains() methods for the Tuple space.

For example (jax was imported in an earlier notebook cell):

import gymnax.environments.spaces as spaces
s = spaces.Tuple([spaces.Discrete(5), spaces.Discrete(5)])
s.sample(rng=jax.random.PRNGKey(0))

TypeError                                 Traceback (most recent call last)
Cell In[95], line 3
      1 import gymnax.environments.spaces as spaces
      2 s = spaces.Tuple([spaces.Discrete(5), spaces.Discrete(5)])
----> 3 s.sample(rng=jax.random.PRNGKey(0))

File ~/miniconda3/envs/gymnax_and_libs/lib/python3.9/site-packages/gymnax/environments/, in Tuple.sample(self, rng)
    116 """Sample random action from all subspaces."""
    117 key_split = jax.random.split(rng, self.num_spaces)
    118 return tuple(
--> 119     [self.spaces[k].sample(key_split[i]) for i, k in enumerate(self.spaces)]
    120 )

File ~/miniconda3/envs/gymnax_and_libs/lib/python3.9/site-packages/gymnax/environments/, in <listcomp>(.0)
    116 """Sample random action from all subspaces."""
    117 key_split = jax.random.split(rng, self.num_spaces)
    118 return tuple(
--> 119     [self.spaces[k].sample(key_split[i]) for i, k in enumerate(self.spaces)]
    120 )

TypeError: list indices must be integers or slices, not Discrete
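
The first traceback comes down to how enumerate works: iterating over enumerate(self.spaces) yields (index, element) pairs, so k is already a subspace object, and self.spaces[k] then tries to index a list with that object. The plain-Python sketch below reproduces the same error without gymnax; FakeSpace is a hypothetical placeholder, not a gymnax class.

```python
class FakeSpace:
    """Hypothetical placeholder standing in for a gymnax subspace."""

    def __init__(self, n: int):
        self.n = n


subspaces = [FakeSpace(5), FakeSpace(5)]

# enumerate yields (index, element) pairs: k is a FakeSpace, not an int
pairs = list(enumerate(subspaces))
print([(i, type(k).__name__) for i, k in pairs])  # [(0, 'FakeSpace'), (1, 'FakeSpace')]

# Using the element itself as a list index reproduces the reported TypeError
try:
    subspaces[pairs[0][1]]
except TypeError as err:
    print(err)  # list indices must be integers or slices, not FakeSpace

# Using the element directly (as in the proposed fix) works
print([s.n for i, s in enumerate(subspaces)])  # [5, 5]
```

Iterating with `for i, s in enumerate(...)` keeps i for picking the matching PRNG key and uses s as the subspace, which is exactly the change in the proposed sample().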

Calling contains() fails as well:

s.contains((1, 1))
TypeError                                 Traceback (most recent call last)
Cell In[96], line 1
----> 1 s.contains((1, 1))

File ~/miniconda3/envs/gymnax_and_libs/lib/python3.9/site-packages/gymnax/environments/, in Tuple.contains(self, x)
    127 out_of_space = 0
    128 for space in self.spaces:
--> 129     out_of_space += 1 - space.contains(x)
    130 return out_of_space == 0

File ~/miniconda3/envs/gymnax_and_libs/lib/python3.9/site-packages/gymnax/environments/, in Discrete.contains(self, x)
     41 """Check whether specific object is within space."""
     42 # type_cond = isinstance(x, self.dtype)
     43 # shape_cond = (x.shape == self.shape)
---> 44 range_cond = jnp.logical_and(x >= 0, x < self.n)
     45 return range_cond

TypeError: '>=' not supported between instances of 'tuple' and 'int'
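
The second traceback arises because Tuple.contains passes the whole tuple x to every subspace, so Discrete.contains evaluates (1, 1) >= 0. The plain-Python sketch below shows the buggy and fixed patterns side by side; discrete_contains is a hypothetical stand-in for the range check in Discrete.contains, not a gymnax function.

```python
def discrete_contains(x, n):
    """Hypothetical stand-in for Discrete.contains: is 0 <= x < n?"""
    return x >= 0 and x < n


x = (1, 1)
n_per_space = [5, 5]

# Buggy pattern: every subspace receives the whole tuple
try:
    all(discrete_contains(x, n) for n in n_per_space)
except TypeError as err:
    print(err)  # '>=' not supported between instances of 'tuple' and 'int'

# Fixed pattern: element i is checked against subspace i only
print(all(discrete_contains(x[i], n) for i, n in enumerate(n_per_space)))  # True
print(all(discrete_contains(v, n) for v, n in zip((1, 7), n_per_space)))  # False
```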

I think the following fixes it; I'm happy to raise a pull request if you want.

from typing import Sequence, Tuple  # note: the class below shadows typing.Tuple once defined

import chex
import jax

from gymnax.environments.spaces import Space


class Tuple(Space):
    """Minimal jittable class for tuple (product) of jittable spaces."""

    def __init__(self, spaces: Sequence[Space]):
        self.spaces = spaces
        self.num_spaces = len(spaces)

    def sample(self, rng: chex.PRNGKey) -> Tuple[chex.Array]:
        """Sample random action from all subspaces."""
        # One independent key per subspace
        key_split = jax.random.split(rng, self.num_spaces)
        # Iterate over the subspace objects themselves instead of using them as indices
        return tuple(
            [s.sample(key_split[i]) for i, s in enumerate(self.spaces)]
        )

    def contains(self, x: Sequence[chex.Array]) -> bool:
        """Check whether dimensions of object are within subspace."""
        # type_cond = isinstance(x, tuple)
        # num_space_cond = len(x) != len(self.spaces)
        # Check each element of x against its own subspace
        out_of_space = 0
        for i, space in enumerate(self.spaces):
            out_of_space += 1 - space.contains(x[i])
        return out_of_space == 0
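
To try the corrected logic without installing gymnax, here is a minimal, self-contained JAX sketch. DiscreteSketch and TupleSketch are hypothetical stand-ins written for illustration, not gymnax's classes; they mirror the fixed sample()/contains() logic, except that contains combines the per-subspace checks with jnp.all instead of an integer counter (same result, different style). It assumes a JAX version that provides jax.Array (0.4 or later). Because the number of subspaces is a Python-level constant, the loop unrolls at trace time, so contains can be wrapped in jax.jit with the tuple passed in as a pytree.

```python
from typing import Sequence, Tuple

import jax
import jax.numpy as jnp


class DiscreteSketch:
    """Hypothetical stand-in for a discrete space {0, ..., n - 1}."""

    def __init__(self, n: int):
        self.n = n

    def sample(self, rng: jax.Array) -> jax.Array:
        # Uniform integer in [0, n)
        return jax.random.randint(rng, shape=(), minval=0, maxval=self.n)

    def contains(self, x: jax.Array) -> jax.Array:
        return jnp.logical_and(x >= 0, x < self.n)


class TupleSketch:
    """Hypothetical product space mirroring the fixed Tuple logic."""

    def __init__(self, spaces: Sequence[DiscreteSketch]):
        self.spaces = spaces
        self.num_spaces = len(spaces)

    def sample(self, rng: jax.Array) -> Tuple[jax.Array, ...]:
        # One key per subspace; iterate over the subspace objects directly
        keys = jax.random.split(rng, self.num_spaces)
        return tuple(s.sample(keys[i]) for i, s in enumerate(self.spaces))

    def contains(self, x: Tuple[jax.Array, ...]) -> jax.Array:
        # Element i of x is checked against subspace i only
        checks = [s.contains(x[i]) for i, s in enumerate(self.spaces)]
        return jnp.all(jnp.stack(checks))


space = TupleSketch([DiscreteSketch(5), DiscreteSketch(5)])
sample = space.sample(jax.random.PRNGKey(0))
in_space = jax.jit(space.contains)  # self is bound; only x is traced
print(bool(in_space(sample)))  # True: a sample always lies in its own space
print(bool(in_space((1, 1))), bool(in_space((1, 7))))  # True False
```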

RobertTLange commented 1 year ago

Hi @joefarrington -- that is so awesome! Thank you for sharing your publication. Parallelization of classic dynamic programming techniques seems like a great application for JAX/gymnax.

And thank you for sharing the bug and the fix. I quickly added it in commit 4c0628ee0656bc9a76ddaec9a42a69db6a04ab4a, since I want to make a release fairly soon. If you find any other bugs or have recommendations and feel like opening PRs, I would be happy to review and integrate them.

Have a good Sunday, Rob