openai / gym

A toolkit for developing and comparing reinforcement learning algorithms.
https://www.gymlibrary.dev

[Proposal] Functional API #2954

Open RedTachyon opened 2 years ago

RedTachyon commented 2 years ago

Proposal

This is a fairly loose proposal for a feature that imo could be very useful, but it doesn't have to be done anytime soon.

Currently, gym uses a rather messy stateful OOP approach, which, to be fair, is sometimes necessary (e.g. for Atari or Unity).

The current model is extremely object-oriented: the state is maintained within the gym.Env and updated every time we take a step. We can get extra information about the environment via a few exposed methods, but crucially, we never really look at the raw Markovian state.

My proposal is officially exposing a functional structure, which would likely be similar to how e.g. built-in brax environments are implemented. Instead of an object in a classical sense, the environment would be defined as a collection of functions.

An episode rollout would look something like this:

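```python
# Sketch of an episode rollout under the proposed functional API.
# `env`, `policy` and `num_episodes` are assumed to be defined elsewhere.
for episode in range(num_episodes):
    state = env.initial()
    time = 0
    while True:
        action = policy(env.embedding(time, state))
        next_state = env.transition(time, state, action)  # returns a new state
        reward = env.reward(time, state, action, next_state)
        img = env.render(time, state, mode=...)  # render mode left unspecified
        if env.terminal(next_state) or env.timeout(time):
            break

        time += 1
        state = next_state
```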

Note that all the env.something methods are in fact static and pure.
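To make the loop above concrete, here is a minimal self-contained sketch. It assumes a hypothetical toy environment `LineWorld` (the agent walks along a line toward a goal) with a hypothetical `State` NamedTuple; neither is part of gym. Every method is a `staticmethod` that returns new values instead of mutating anything, and rendering is left out.

```python
from typing import NamedTuple


class State(NamedTuple):
    # Hypothetical Markovian state: the agent's integer position on a line.
    position: int


class LineWorld:
    """Hypothetical functional environment: every method is static and pure."""

    GOAL = 5
    MAX_STEPS = 20

    @staticmethod
    def initial() -> State:
        return State(position=0)

    @staticmethod
    def embedding(time: int, state: State) -> tuple:
        # The observation handed to the policy; here simply the position.
        return (state.position,)

    @staticmethod
    def transition(time: int, state: State, action: int) -> State:
        # action is -1 or +1; a NEW state is returned, the old one is untouched.
        return State(position=max(0, state.position + action))

    @staticmethod
    def reward(time: int, state: State, action: int, next_state: State) -> float:
        return 1.0 if next_state.position == LineWorld.GOAL else 0.0

    @staticmethod
    def terminal(state: State) -> bool:
        return state.position == LineWorld.GOAL

    @staticmethod
    def timeout(time: int) -> bool:
        return time >= LineWorld.MAX_STEPS - 1


def rollout(env, policy) -> tuple:
    """Run one episode with the loop from the proposal; returns (steps, total_reward)."""
    state = env.initial()
    time = 0
    total_reward = 0.0
    while True:
        action = policy(env.embedding(time, state))
        next_state = env.transition(time, state, action)
        total_reward += env.reward(time, state, action, next_state)
        if env.terminal(next_state) or env.timeout(time):
            return time + 1, total_reward
        time += 1
        state = next_state


steps, total = rollout(LineWorld, policy=lambda obs: +1)
```

Because the environment is just a namespace of functions, the class itself is passed as `env`; no instance and no hidden state is involved.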

There are several benefits to this approach. It makes vector environments trivial and makes it easy to bypass Python multiprocessing entirely. It would also make many things much more transparent and bring the API closer to the theoretical POMDP-style model.
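A rough illustration of why vectorisation becomes trivial, assuming a hypothetical toy `transition` on a `State` NamedTuple (not a gym API): since no environment object holds hidden state, a batch of environments is just a batch of states, and a vectorised step is the single step mapped over it.

```python
from typing import List, NamedTuple


class State(NamedTuple):
    position: int  # hypothetical Markovian state of a toy line-walking env


def transition(time: int, state: State, action: int) -> State:
    # Pure transition: same inputs always give the same output, nothing is mutated.
    return State(position=max(0, state.position + action))


def batched_transition(time: int, states: List[State], actions: List[int]) -> List[State]:
    # A "vector env" step is the single-env step mapped over a batch of states.
    # No subprocesses are needed because there is no per-env object to keep in sync.
    return [transition(time, s, a) for s, a in zip(states, actions)]


states = [State(0), State(3), State(7)]
next_states = batched_transition(0, states, [+1, -1, +1])
```

Under JAX, the list comprehension would typically be replaced by `jax.vmap(transition, in_axes=(None, 0, 0))` over array-valued states, provided `transition` is written with `jax.numpy` operations (e.g. `jnp.maximum` instead of Python's `max`).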

This matters especially given the plan to use jax/brax environments as a replacement for e.g. box2d: a functional API works much better with them, and it would be possible to e.g. jit the entire collection loop, whereas with the current API that might be problematic and we only jit parts of the step method.
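Jitting the whole collection loop usually means replacing the data-dependent `break` with a fixed-length loop plus a `done` mask. The sketch below shows that restructuring with a pure-Python `scan` shaped like `jax.lax.scan(f, init, xs=None, length=...)`, where `f(carry, x)` returns `(new_carry, y)`. The toy dynamics, `GOAL` and `HORIZON` are hypothetical; under JAX one would swap in `jax.lax.scan`, replace the Python conditionals with `jnp.where`, and wrap `rollout` in `jax.jit` (none of which is exercised here).

```python
from typing import Callable, List, Tuple

GOAL, HORIZON = 5, 20  # hypothetical toy-env parameters


def scan(f: Callable, init, length: int):
    """Pure-Python stand-in shaped like jax.lax.scan(f, init, xs=None, length=...)."""
    carry, ys = init, []
    for _ in range(length):
        carry, y = f(carry, None)
        ys.append(y)
    return carry, ys


def rollout(policy: Callable[[int], int]) -> Tuple[float, List[float]]:
    """A whole episode as one pure function: fixed length, no data-dependent `break`."""

    def body(carry, _):
        position, done, total = carry
        next_position = max(0, position + policy(position))
        # After termination, freeze the state and emit zero reward:
        # masking replaces the `break` of the original loop.
        reward = 0.0 if done else (1.0 if next_position == GOAL else 0.0)
        position = position if done else next_position
        done = done or position == GOAL
        return (position, done, total + reward), reward

    (_, _, total), rewards = scan(body, (0, False, 0.0), HORIZON)
    return total, rewards


total, rewards = rollout(lambda position: +1)
```

The cost of this design is that every episode runs for the full `HORIZON`, with wasted (masked) steps after termination; the benefit is a loop with a static shape that a tracing compiler can handle.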

As for implementation, I have some ideas on conversions back and forth. It might not be possible to fit all environments into a functional API, but e.g. classic_control envs would be extremely simple to convert, given that they already maintain a self.state.
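One direction of such a conversion, functional to object-oriented, could look roughly like the sketch below. `LineWorldFunctional` and `FunctionalToObjectEnv` are hypothetical names; the wrapper just stores `(time, state)` on `self`, the way classic_control envs keep `self.state`, and loosely follows the Gym 0.26-style return signatures (`reset` returns `(obs, info)`, `step` returns `(obs, reward, terminated, truncated, info)`) without importing gym.

```python
from typing import Any, Dict, NamedTuple, Optional, Tuple


class State(NamedTuple):
    position: int  # hypothetical Markovian state


class LineWorldFunctional:
    """Hypothetical pure functional env: static methods only, no stored state."""

    GOAL, MAX_STEPS = 3, 10

    @staticmethod
    def initial() -> State:
        return State(0)

    @staticmethod
    def embedding(time: int, state: State) -> int:
        return state.position

    @staticmethod
    def transition(time: int, state: State, action: int) -> State:
        return State(max(0, state.position + action))

    @staticmethod
    def reward(time: int, state: State, action: int, next_state: State) -> float:
        return 1.0 if next_state.position == LineWorldFunctional.GOAL else 0.0

    @staticmethod
    def terminal(state: State) -> bool:
        return state.position == LineWorldFunctional.GOAL

    @staticmethod
    def timeout(time: int) -> bool:
        return time >= LineWorldFunctional.MAX_STEPS - 1


class FunctionalToObjectEnv:
    """Hypothetical converter: keeps (time, state) on self and delegates
    all environment logic to the pure functions."""

    def __init__(self, func_env: Any):
        self.func_env = func_env
        self.state: Optional[State] = None
        self.time = 0

    def reset(self, seed: Optional[int] = None) -> Tuple[Any, Dict]:
        # `seed` is accepted for API shape only; this toy env is deterministic.
        self.state, self.time = self.func_env.initial(), 0
        return self.func_env.embedding(self.time, self.state), {}

    def step(self, action: int) -> Tuple[Any, float, bool, bool, Dict]:
        fe = self.func_env
        next_state = fe.transition(self.time, self.state, action)
        reward = fe.reward(self.time, self.state, action, next_state)
        terminated = fe.terminal(next_state)
        truncated = fe.timeout(self.time)
        self.state, self.time = next_state, self.time + 1
        return fe.embedding(self.time, next_state), reward, terminated, truncated, {}


env = FunctionalToObjectEnv(LineWorldFunctional)
obs, info = env.reset()
obs, reward, terminated, truncated, info = env.step(+1)
```

All the mutation lives in the thin wrapper, so the functional core can still be tested, batched or jitted on its own.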

As I stated at the beginning, this is a rather loose suggestion, which I think would be very beneficial overall, but I'd like to hear what people generally think about it.

Motivation

Objects are out, functions are in (see: Jax, Brax, functorch, numba)

Pitch

Add an alternative API for gym environments, which is functional instead of object-oriented

Alternatives

Making a separate library, perhaps tied to gym in one way or another


balisujohn commented 2 years ago

This is kind of similar to how the current Jax BlackJack draft is set up, where jit_step and jit_reset can be used in a purely functional manner, while the standard API step and reset functions set the env state using the results of calls to jit_step and jit_reset. Maybe we could have a secondary functional API where possible? I think it would work for many environments in gym, but I could see there being issues with environments that wrap physical robots, for example.

balisujohn commented 2 years ago

My thinking is that jit_reset and jit_step should be public functions offered as part of the "Jittable environment" API alongside the conventional step and reset.

RedTachyon commented 2 years ago

Yeah, that's the ballpark of what I'm thinking. I would keep it more general than just jax jitting, but that would definitely be one of the possibilities. Then we could rewrite the jax blackjack env primarily in the functional API (so users can use the full environment in a jitted context), and expose an object-oriented API on top of it with a simple converter.

pseudo-rnd-thoughts commented 2 years ago

In the next day or so, I'm planning to release a PR with a jax environment specification that closely follows your functional API. The primary difference is that I don't expose the environment state to the user: there are public step / reset functions that follow the Env version, plus hidden jittable step / reset functions. There will also be a new Jax-based vectorise class that uses the jittable step and reset functions.

RedTachyon commented 2 years ago

@pseudo-rnd-thoughts I'm not sure I understand. If it doesn't expose the state, then it's still very different from what I'm proposing. As I understand it, the actual step that the user would call is not jittable; instead, it calls a jittable function internally, yes? So it's still the same object API, which happens to be powered by a functional environment. I propose having a functional API for those functional environments.

RedTachyon commented 2 years ago

This actually makes me think it might be worthwhile to add a consistent behind-the-scenes functional API for the upcoming new envs, and only later make it more public or extend it to other envs.

pseudo-rnd-thoughts commented 2 years ago

Agreed, I would propose that all new gym environments follow it. While it is possible to achieve similar vectorisation performance with Envpool, it is significantly more complex to do.

YouJiacheng commented 2 years ago

If we have a state <=> object bijection, then a functional API is equivalent to an object-oriented API. JAX actually supports object-oriented code in this way: users only need to implement their envs as a pytree, which means that the state of the environment can be extracted and used to perfectly reconstruct the environment. Since an object method in Python is just an ordinary function without any magic (unlike e.g. `this` in C++/Java), the line between functional and OO code blurs.
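A rough sketch of that bijection, assuming a hypothetical OO env `PendulumLike` with toy linearized dynamics. The `tree_flatten` / `tree_unflatten` pair follows the protocol that `jax.tree_util.register_pytree_node_class` expects (children plus hashable aux data); the decorator is omitted so the sketch runs without JAX. `functional_step` (also hypothetical) shows how an ordinary mutating method becomes a pure function of the extracted state.

```python
from typing import List, Tuple


class PendulumLike:
    """Hypothetical OO env whose entire state lives in its attributes."""

    def __init__(self, angle: float, velocity: float, gravity: float = 9.8):
        self.angle = angle          # dynamic state -> pytree children
        self.velocity = velocity
        self.gravity = gravity      # static config -> pytree aux data

    def step(self, torque: float, dt: float = 0.05) -> None:
        # Ordinary mutating method with toy linearized pendulum dynamics.
        self.velocity += (torque - self.gravity * self.angle) * dt
        self.angle += self.velocity * dt

    # Protocol expected by jax.tree_util.register_pytree_node_class.
    def tree_flatten(self) -> Tuple[List[float], Tuple[float]]:
        return [self.angle, self.velocity], (self.gravity,)

    @classmethod
    def tree_unflatten(cls, aux_data, children) -> "PendulumLike":
        return cls(*children, *aux_data)


def functional_step(children: List[float], aux: Tuple[float], torque: float) -> List[float]:
    # Rebuild a fresh object from the state, mutate only that object, extract the
    # new state: the caller's `children` list is never modified.
    env = PendulumLike.tree_unflatten(aux, children)
    env.step(torque)
    return env.tree_flatten()[0]
```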

pseudo-rnd-thoughts commented 2 years ago

That might be true, but jax jit acceleration requires the step and reset functions to be pure, which is currently not the case in gym.

YouJiacheng commented 2 years ago

If users jit the entire collection loop rather than step and reset themselves, it's indeed okay for step and reset to be impure (a pure function can have impure subroutines). I'd like to have a pure-function API since it is simple and clear, but it is not necessary for jax jit.

RedTachyon commented 2 years ago

Ultimately I think it's a mix of both. It's possible to express any computation in an object-oriented way (in fact, I've pitched a few times a trivial OOP->functional conversion where you use the Env object as the state), but a pure functional API gives us a bit of extra sanity. I can also sort of imagine how an env-pytree hybrid could be jittable, but it does feel like a bit of a hack. So the sanity benefits compound, and adopting a functional API for jax stuff (and potentially other envs too) should pay off significantly overall.
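The "use the Env object as the state" conversion can be sketched in a few lines, assuming a hypothetical stateful `CounterEnv`. Each `transition` deep-copies the state object before stepping, so the caller's object is never mutated and the function behaves purely even though `CounterEnv.step` is not. This is only a sketch: deep copies can be expensive, and they will not work for envs backed by external processes or hardware.

```python
import copy
from typing import Tuple


class CounterEnv:
    """Hypothetical stateful, object-oriented env (step mutates self)."""

    def __init__(self, limit: int = 3):
        self.limit = limit
        self.count = 0

    def reset(self) -> int:
        self.count = 0
        return self.count

    def step(self, action: int) -> Tuple[int, float, bool]:
        self.count += action
        done = self.count >= self.limit
        return self.count, (1.0 if done else 0.0), done


def initial(limit: int = 3) -> CounterEnv:
    # The "state" is simply a freshly reset env object.
    env = CounterEnv(limit)
    env.reset()
    return env


def transition(state: CounterEnv, action: int) -> Tuple[CounterEnv, int, float, bool]:
    # Copy first, then mutate only the copy, so `state` itself is left unchanged.
    next_state = copy.deepcopy(state)
    obs, reward, done = next_state.step(action)
    return next_state, obs, reward, done


s0 = initial()
s1, obs, reward, done = transition(s0, 2)
```

A side benefit is that old states stay valid, so one can branch several rollouts from the same `s0`.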

jkterry1 commented 1 year ago

Hey, we just launched gymnasium, a fork of Gym by the maintainers of Gym for the past 18 months, where all maintenance and improvements will happen moving forward, including this 1.0 roadmap.

We have an announcement post here: https://farama.org/Announcing-The-Farama-Foundation