ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0
33.29k stars 5.63k forks

[RLlib] AssertionError using Simplex with default concentration #45804

Open ema-pe opened 3 months ago

ema-pe commented 3 months ago

What happened + What you expected to happen

Hi, I'm using the Simplex space defined in RLlib as the action space for an environment. I want an action space that contains a single point with three coordinates whose values are in [0, 1] and sum to 1. The Simplex constructor takes shape and concentration parameters, so I construct the space as Simplex(shape=(1,3), concentration=np.array([1, 1, 1])).

This concentration is the same value the constructor computes by default (all ones, one per coordinate). The problem is that I cannot initialize this environment, because the Simplex constructor raises an AssertionError.
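To make the failure concrete, the following pure-Python sketch mirrors how the constructor appears to handle the concentration, using plain tuples for shapes and lists instead of NumPy arrays. It is a hedged reconstruction based on the linked simplex.py and the assertion message in the stack trace, not the RLlib code itself, and the helper name `check_concentration` is hypothetical. When concentration is omitted, the default is all ones with length shape[-1] and no assertion runs; passing that same value explicitly compares its shape against shape[:-1].

```python
from typing import Optional, Sequence, Tuple


def check_concentration(
    shape: Tuple[int, ...], concentration: Optional[Sequence[float]] = None
) -> Tuple[float, ...]:
    """Hypothetical mirror of the concentration handling in Simplex.__init__.

    Shapes are plain tuples; a 1-D concentration of length n has shape (n,).
    """
    dim = shape[-1]
    if concentration is None:
        # Default: uniform concentration, one entry per simplex coordinate,
        # i.e. shape (dim,) == shape[-1:]. No assertion runs on this path.
        return tuple([1.0] * dim)
    conc_shape = (len(concentration),)
    # The check questioned in this issue: it compares against shape[:-1].
    assert conc_shape == shape[:-1], f"{conc_shape} vs {shape[:-1]}"
    return tuple(float(c) for c in concentration)


if __name__ == "__main__":
    print(check_concentration((1, 3)))  # default path succeeds
    try:
        check_concentration((1, 3), [1, 1, 1])  # same value, passed explicitly
    except AssertionError as err:
        print("AssertionError:", err)
```

With shape=(1,3) the default path returns three ones, while the explicit call fails with the same `(3,) vs (1,)` message as in the stack trace.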

I think the problem is in the __init__ method in simplex.py. Why does it assert concentration.shape == shape[:-1]? Why [:-1]? For shape=(1,3) and concentration=np.array([1,1,1]), the assertion compares (3,) with (1,), so it fails and raises the exception, but it should not.

https://github.com/ray-project/ray/blob/0b3943b3175bdc28af39c9f249069fa94c2a151b/rllib/utils/spaces/simplex.py#L8-L36
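The [:-1] slice looks inconsistent with how a Dirichlet concentration is used. Sampling a point on a d-dimensional simplex needs one concentration value per coordinate, i.e. shape (d,) == shape[-1:], while shape[:-1] counts the independent points (here (1,)). NumPy's np.random.dirichlet(alpha, size) follows the same convention: alpha has length k and the samples have shape size + (k,). The sketch below assumes the check should compare against shape[-1:]; it is pure Python, the helper names are hypothetical, and it is not necessarily what the eventual fix does. Each point is drawn by normalizing independent Gamma(alpha_i, 1) variates, the standard way to sample a Dirichlet distribution.

```python
import math
import random
from typing import List, Sequence, Tuple


def check_concentration_fixed(
    shape: Tuple[int, ...], concentration: Sequence[float]
) -> None:
    """Hypothetical corrected check: one concentration entry per coordinate."""
    conc_shape = (len(concentration),)
    assert conc_shape == shape[-1:], f"{conc_shape} vs {shape[-1:]}"


def sample_simplex(
    shape: Tuple[int, ...], concentration: Sequence[float], rng: random.Random
) -> List[List[float]]:
    """Draw prod(shape[:-1]) independent Dirichlet points of length shape[-1].

    Batch dimensions are flattened into a single list of rows for simplicity.
    """
    check_concentration_fixed(shape, concentration)
    n_points = math.prod(shape[:-1])  # shape (1, 3) -> 1 point; shape (3,) -> 1
    points = []
    for _ in range(n_points):
        # Dirichlet(alpha) == independent Gamma(alpha_i, 1) draws, normalized.
        draws = [rng.gammavariate(alpha, 1.0) for alpha in concentration]
        total = sum(draws)
        points.append([d / total for d in draws])
    return points


if __name__ == "__main__":
    rng = random.Random(0)
    print(sample_simplex((1, 3), [1, 1, 1], rng))
```

Under this convention the concentration np.array([1, 1, 1]) from the report matches shape[-1:] == (3,), which is also the shape of the constructor's own default.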

Versions / Dependencies

Reproduction script

Save the following test script as main_ray.py and run it; it crashes with an AssertionError.

import numpy as np

import gymnasium as gym

from ray.rllib.utils.spaces.simplex import Simplex

class SimplexTestEnv(gym.Env):
    def __init__(self, env_config):
        self.action_space = Simplex(shape=(1,3), concentration=np.array([1, 1, 1]))
        self.observation_space = gym.spaces.Box(shape=(1,), low=-1, high=1)

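    def reset(self, *, seed=None, options=None):
        return np.zeros(1, dtype=np.float32), {}  # Gymnasium API: (obs, info)

    def step(self, action):
        return np.zeros(1, dtype=np.float32), 0.0, False, False, {}  # obs, reward, terminated, truncated, info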

def main():
    env = SimplexTestEnv({})

    result = env.step(np.zeros(1))

if __name__ == "__main__":
    main()

The program crashes with the following stack trace:

(.env2.23) emanuele@fedora-t4:~/ray/test-simplex$ python main_ray.py 
Traceback (most recent call last):
  File "/home/emanuele/ray/test-simplex/main_ray.py", line 25, in <module>
    main()
  File "/home/emanuele/ray/test-simplex/main_ray.py", line 20, in main
    env = SimplexTestEnv({})
  File "/home/emanuele/ray/test-simplex/main_ray.py", line 9, in __init__
    self.action_space = Simplex(shape=(1,3), concentration=np.array([1, 1, 1]))
  File "/home/emanuele/ray/.env2.23/lib64/python3.10/site-packages/ray/rllib/utils/spaces/simplex.py", line 32, in __init__
    concentration.shape == shape[:-1]
AssertionError: (3,) vs (1,)
(.env2.23) emanuele@fedora-t4:~/ray/test-simplex$
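Separately from the constructor failure in the trace, the action that main() passes to step(), np.zeros(1), would not be a point on this simplex: for shape=(1,3) a valid action is one row of three values in [0, 1] that sum to 1. The pure-Python sketch below expresses that membership condition with a small tolerance on the sum. The helper `in_simplex` is hypothetical and only illustrates the definition from the report; it is not RLlib's Simplex.contains.

```python
from typing import Sequence, Tuple


def in_simplex(
    rows: Sequence[Sequence[float]], shape: Tuple[int, int], tol: float = 1e-6
) -> bool:
    """Check that `rows` has shape (n, d) and every row lies on the simplex.

    A row is on the simplex if all entries are in [0, 1] and they sum to 1
    (up to `tol`, to absorb floating-point rounding).
    """
    n, d = shape
    if len(rows) != n or any(len(row) != d for row in rows):
        return False  # wrong number of points or coordinates
    return all(
        all(0.0 <= v <= 1.0 for v in row) and abs(sum(row) - 1.0) <= tol
        for row in rows
    )
```

For example, in_simplex([[1/3, 1/3, 1/3]], (1, 3)) is True, while in_simplex([[0.0, 0.0, 0.0]], (1, 3)) is False because the row sums to 0.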

Issue Severity

High: It blocks me from completing my task.

simonsays1980 commented 23 hours ago

@ema-pe Thanks for raising this issue and sorry that you bumped into it. I opened a PR that should fix it. Waiting for review by my colleague.