pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] Too much overhead in the environment classes? #1247

Closed skandermoalla closed 8 months ago

skandermoalla commented 1 year ago

Describe the bug

I quickly adapted the benchmark for batched environments to compare against native gymnasium environment classes and got drastically worse performance on CPU.

Single envs and SerialEnvs are up to 50x slower and ParallelEnv is up to 10x slower 😬

Could anyone look into this?

To Reproduce

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

logging.basicConfig(level=logging.ERROR)
logging.captureWarnings(True)
import pandas as pd

pd.set_option("display.max_columns", 100)
pd.set_option("display.width", 1000)
from torch.utils.benchmark import Timer
from torchrl.envs import ParallelEnv, SerialEnv
from torchrl.envs.libs.gym import GymEnv

import gymnasium as gym

N_STEPS = 512
ENV = "CartPole-v1"

def factory():
    return GymEnv(ENV)

def create_single(num_workers):
    env = factory()
    env.rollout(policy=None, max_steps=5)  # Warm-up
    # Tag the env so run_env scales the step count to match the batched envs
    env.__is_single = True
    env.__num_workers = num_workers
    return env

def create_single_gym(num_workers):
    env = gym.make(ENV)
    env.__is_single = True
    env.__num_workers = num_workers
    env.reset()
    return env

def create_serial(num_workers):
    env = SerialEnv(num_workers=num_workers, create_env_fn=factory)
    env.rollout(policy=None, max_steps=5)  # Warm-up
    return env

def create_serial_gym(num_workers):
    env = gym.vector.make(ENV, num_envs=num_workers, asynchronous=False)
    env.reset()
    return env

def create_parallel(num_workers):
    env = ParallelEnv(num_workers=num_workers, create_env_fn=factory)
    env.rollout(policy=None, max_steps=5)  # Warm-up
    return env

def create_parallel_gym(num_workers):
    env = gym.vector.make(ENV, num_envs=num_workers, asynchronous=True)
    env.reset()
    return env

def run_env(env):
    if hasattr(env, "__is_single") and env.__is_single:
        n_steps = N_STEPS * env.__num_workers
    else:
        n_steps = N_STEPS
    env.rollout(policy=None, max_steps=n_steps, break_when_any_done=False)

def run_env_gym(env):
    if not hasattr(env, "num_envs"):
        n_steps = N_STEPS * env.__num_workers
    else:
        n_steps = N_STEPS

    observation, info = env.reset()
    for _ in range(n_steps):
        action = env.action_space.sample()  # random action in place of a learned policy
        observation, reward, terminated, truncated, info = env.step(action)

        if not hasattr(env, "num_envs") and (terminated or truncated):
            observation, info = env.reset()

if __name__ == "__main__":
    res = {}
    for num_workers in [1, 4, 8]:
        print(f"With num_workers={num_workers}")
        print("-------TorchRL-------")
        print("Single...")
        env_single = create_single(num_workers)
        res_sing = Timer(
            stmt="run_env(env)",
            setup="from __main__ import run_env",
            globals={"env": env_single},
        )
        time_sing = res_sing.blocked_autorange().mean

        print("Serial...")
        env_serial = create_serial(num_workers)
        res_serial = Timer(
            stmt="run_env(env)",
            setup="from __main__ import run_env",
            globals={"env": env_serial},
        )
        time_serial = res_serial.blocked_autorange().mean

        print("Parallel...")
        env_parallel = create_parallel(num_workers)
        res_parallel = Timer(
            stmt="run_env(env)",
            setup="from __main__ import run_env",
            globals={"env": env_parallel},
        )
        time_parallel = res_parallel.blocked_autorange().mean
        print(time_sing, time_serial, time_parallel)

        print("-------Gymnasium-------")
        print("Single...")
        env_single_gym = create_single_gym(num_workers)
        res_sing_gym = Timer(
            stmt="run_env_gym(env)",
            setup="from __main__ import run_env_gym",
            globals={"env": env_single_gym},
        )
        time_sing_gym = res_sing_gym.blocked_autorange().mean

        print("Serial...")
        env_serial_gym = create_serial_gym(num_workers)
        res_serial_gym = Timer(
            stmt="run_env_gym(env)",
            setup="from __main__ import run_env_gym",
            globals={"env": env_serial_gym},
        )
        time_serial_gym = res_serial_gym.blocked_autorange().mean

        print("Parallel...")
        env_parallel_gym = create_parallel_gym(num_workers)
        res_parallel_gym = Timer(
            stmt="run_env_gym(env)",
            setup="from __main__ import run_env_gym",
            globals={"env": env_parallel_gym},
        )
        time_parallel_gym = res_parallel_gym.blocked_autorange().mean
        print(time_sing_gym, time_serial_gym, time_parallel_gym)

        res[f"num_workers_{num_workers}"] = {
            "Single, s": time_sing,
            "Serial, s": time_serial,
            "Parallel, s": time_parallel,
            "Single_gym, s": time_sing_gym,
            "Serial_gym, s": time_serial_gym,
            "Parallel_gym, s": time_parallel_gym,
        }
    df = pd.DataFrame(res).round(3)
    par_sing_time = df.loc["Parallel, s"] / df.loc["Single, s"]
    par_ser_time = df.loc["Parallel, s"] / df.loc["Serial, s"]
    df.loc["relative_time parallel/single, %", :] = (par_sing_time * 100).round(1)
    df.loc["relative_time parallel/serial, %", :] = (par_ser_time * 100).round(1)
    sing_gym_time = df.loc["Single_gym, s"] / df.loc["Single, s"]
    ser_gym_time = df.loc["Serial_gym, s"] / df.loc["Serial, s"]
    par_gym_time = df.loc["Parallel_gym, s"] / df.loc["Parallel, s"]
    df.loc["relative_time single_gym/single, %", :] = (sing_gym_time * 100).round(1)
    df.loc["relative_time serial_gym/serial, %", :] = (ser_gym_time * 100).round(1)
    df.loc["relative_time parallel_gym/parallel, %", :] = (par_gym_time * 100).round(1)
    print(df)
    df.to_csv("batched_benchmark.csv")

Output on macOS with an Apple M1:

With num_workers=1
-------TorchRL-------
Single...
[W ParallelNative.cpp:230] Warning: Cannot set number of intraop threads after parallel work has started or after set_num_threads call when using native parallel backend (function set_num_threads)
Serial...
Parallel...
0.16622966699999608 0.23579145900001208 0.258086416999987
-------Gymnasium-------
Single...
Serial...
Parallel...
0.0035619283508775395 0.006423218750002846 0.027073776000001715
With num_workers=4
-------TorchRL-------
Single...
Serial...
Parallel...
0.6535387920000062 0.4890966660000231 0.4010159999999985
-------Gymnasium-------
Single...
Serial...
Parallel...
0.014470199357140652 0.013568580533338566 0.060714156499997785
With num_workers=8
-------TorchRL-------
Single...
Serial...
Parallel...
1.318861792000007 0.8485894579999922 0.6550024590000021
-------Gymnasium-------
Single...
Serial...
Parallel...
0.02902346428572043 0.023225078555561315 0.1061309795000085
                                        num_workers_1  num_workers_4  num_workers_8
Single, s                                       0.166          0.654          1.319
Serial, s                                       0.236          0.489          0.849
Parallel, s                                     0.258          0.401          0.655
Single_gym, s                                   0.004          0.014          0.029
Serial_gym, s                                   0.006          0.014          0.023
Parallel_gym, s                                 0.027          0.061          0.106
relative_time parallel/single, %              155.400         61.300         49.700
relative_time parallel/serial, %              109.300         82.000         77.100
relative_time single_gym/single, %              2.400          2.100          2.200
relative_time serial_gym/serial, %              2.500          2.900          2.700
relative_time parallel_gym/parallel, %         10.500         15.200         16.200

Output on Ubuntu with an Intel(R) Xeon(R) Gold 6240:

With num_workers=1
-------TorchRL-------
Single...
Serial...
Parallel...
0.48479735804721713 0.6477549062110484 0.8586247656494379
-------Gymnasium-------
Single...
Serial...
Parallel...
0.008634783618617803 0.019357481082393366 0.06684140597159664
With num_workers=4
-------TorchRL-------
Single...
Serial...
Parallel...
1.8779972507618368 1.224436460994184 1.1977882566861808
-------Gymnasium-------
Single...
Serial...
Parallel...
0.03954721265472472 0.03926779151273271 0.13064894126728177
With num_workers=8
-------TorchRL-------
Single...
Serial...
Parallel...
3.652070662006736 2.082924742717296 1.8401135145686567
-------Gymnasium-------
Single...
Serial...
Parallel...
0.08220831056435902 0.07548716928188999 0.20778966415673494
                                        num_workers_1  num_workers_4  num_workers_8
Single, s                                       0.485          1.878          3.652
Serial, s                                       0.648          1.224          2.083
Parallel, s                                     0.859          1.198          1.840
Single_gym, s                                   0.009          0.040          0.082
Serial_gym, s                                   0.019          0.039          0.075
Parallel_gym, s                                 0.067          0.131          0.208
relative_time parallel/single, %              177.100         63.800         50.400
relative_time parallel/serial, %              132.600         97.900         88.300
relative_time single_gym/single, %              1.900          2.100          2.200
relative_time serial_gym/serial, %              2.900          3.200          3.600
relative_time parallel_gym/parallel, %          7.800         10.900         11.300

System info

Describe the characteristics of your environment:

macOS with M1

import torch, torchrl, numpy, sys
print(torch.__version__, torchrl.__version__, numpy.__version__, sys.version, sys.platform)
2.0.1 0.1.1+3fb2d0d 1.24.3 3.10.11 | packaged by conda-forge | (main, May 10 2023, 19:01:19) [Clang 14.0.6 ] darwin

Ubuntu with Intel(R) Xeon(R) Gold 6240

import torch, torchrl, numpy, sys
print(torch.__version__, torchrl.__version__, numpy.__version__, sys.version, sys.platform)
2.0.1 0.1.1+3fb2d0d 1.24.3 3.9.16 | packaged by conda-forge | (main, Feb  1 2023, 21:39:03) [GCC 11.3.0] linux


vmoens commented 1 year ago

Thanks, that's very thorough.

We could definitely do a better job on CPU. Things should be better on CUDA; let me try to reproduce.

If it's of interest, you can also try the Multi(a)SyncDataCollector, which is comparatively faster than ParallelEnv. ParallelEnv is especially useful when you have large models that you want to execute over a batch of envs, but for simple envs and simple models, dispatching each env to a separate collector usually works faster.
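A minimal sketch of what that could look like, reusing factory and N_STEPS from the script above (policy=None is assumed to fall back to a random policy, as with env.rollout; the synchronized sibling MultiSyncDataCollector can be swapped in with the same arguments):

from torchrl.collectors import MultiaSyncDataCollector

def run_multiasync(num_workers):
    # One env per worker process; batches arrive as soon as any worker is ready.
    collector = MultiaSyncDataCollector(
        create_env_fn=[factory] * num_workers,
        policy=None,  # assumed to default to random actions
        frames_per_batch=N_STEPS,
        total_frames=N_STEPS * num_workers,
    )
    for batch in collector:
        pass  # a real benchmark would time this loop rather than discard the batches
    collector.shutdown()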

I'll keep you posted!

skandermoalla commented 1 year ago

Yes, it's a fair point that this doesn't include the overhead of moving data between CPU and GPU; maybe the TorchRL classes are better suited for that. I'll wait for the benchmark on GPU!

vmoens commented 1 year ago

There is certainly room for improvement, but there will always be the cost of building tensordicts out of the gym envs. It's part of the price to pay for having envs that you can recycle across gym, dm_control, brax and many others, with or without rendering, with or without transforms: formatting the data into a common structure comes at a certain cost.
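As a quick illustration of that common structure (a sketch, not from the original report): each reset/step through GymEnv packs the raw gym outputs into a TensorDict.

from torchrl.envs.libs.gym import GymEnv

env = GymEnv("CartPole-v1")
td = env.reset()        # TensorDict holding the initial observation and done flag
td = env.rand_step(td)  # samples a random action, steps the env, and writes the
                        # resulting observation/reward/done back into a TensorDict
print(td)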

We're quite competitive when it comes to executing more demanding environments, e.g. the ones that require rendering, and when there are transforms to be executed.
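For instance, a rendering-based setup where that bookkeeping is easier to amortize could look like this (a sketch assuming ale-py/Pong is installed, not a benchmark from this thread):

from torchrl.envs import TransformedEnv, Compose, ToTensorImage, GrayScale, Resize
from torchrl.envs.libs.gym import GymEnv

env = TransformedEnv(
    GymEnv("ALE/Pong-v5", from_pixels=True),                # pixel observations
    Compose(ToTensorImage(), GrayScale(), Resize(84, 84)),  # transforms run on the env side
)
td = env.rollout(max_steps=10)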

For the simplest Pendulum or CartPole, it's hard to do better than gym...

vmoens commented 1 year ago

There's more improvement to be achieved, but I managed to reduce the compute time by roughly 1/3-1/2 for CartPole rollouts with the 2 PRs linked above. Before (on 1 and 4 envs):

0.6000917710000007 0.8977178509999995 0.9628986140000002
2.3650796359999973 2.1196543889999973 1.398552716999987

After:

0.3220986410000002 0.5811318410000013 0.567181644999998
1.4485275350000038 1.648879229000002 0.9843755919999992

We're far from the 10x improvement we could have wished for but it is some progress already!

On GPU, the perf is a bit less impressive. Here's a benchmark on GPU with Pong-v5. Before:

1.1242192424833775 1.490046987310052 1.7628506254404783
4.569527314975858 3.7096991054713726 2.416346298530698

After:

0.8273551240563393 1.2029384840279818 1.3915523868054152
3.3085566088557243 3.4223811104893684 2.093131694942713

From what I can see, the bottlenecks are mostly tensordict-related (cloning, set, get, checking values, etc.).
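Those operations can be timed in isolation with the same Timer used in the reproduction script; a rough sketch (assuming the standalone tensordict package is installed):

import torch
from tensordict import TensorDict
from torch.utils.benchmark import Timer

td = TensorDict({"observation": torch.zeros(4), "reward": torch.zeros(1)}, batch_size=[])
obs = torch.ones(4)
for stmt in ['td.set("observation", obs)', 'td.get("observation")', "td.clone()"]:
    t = Timer(stmt=stmt, globals={"td": td, "obs": obs})
    print(stmt, t.blocked_autorange().mean)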

Be assured that we'll work on making these more efficient! As usual: any help is welcome :)

Side note:

In the paper, the speed results (still to be copied here) are achieved via MultiaSyncDataCollector, not ParallelEnv. For most off-policy algorithms this is a valid way of collecting data. For on-policy ones, MultiSyncDataCollector can be used.

vmoens commented 1 year ago

After some more improvements, CartPole looks like this:

0.28064374699999917 0.5035237049999992 0.7829816540000003
1.275079895999994 1.4358267960000006 0.8376919339999986
2.4235935930000068 2.9860773900000055 1.2749536729999988

On my machine, this compares with gym like this:

0.029683207571428078 0.0680098183333347 0.11986857750000013
0.05640612449999871 0.0676216029999992 0.20421299000000204
0.11158799900000815 0.10956462599999384 0.5347069550000043

So the parallel version with 8 workers is "only" twice as slow.

Have a look at the benchmarks too:

https://pytorch-labs.github.io/tensordict/dev/bench/ https://pytorch.org/rl/dev/bench/

skandermoalla commented 1 year ago

I'm impressed by the responsiveness! Thanks a lot!

I also reran the benchmark on my side on an M1 chip at commit ea6f872 and got the results below.

The major points I found were:

This progress is already awesome, so feel free to close the issue!

                                        num_workers_1  num_workers_4  num_workers_8
Single, s                                       0.055          0.223          0.472
Serial, s                                       0.104          0.281          0.676
Parallel, s                                     0.113          0.218          0.399
MultiSync, s                                    0.075          0.079          0.165
Single_gym, s                                   0.004          0.015          0.031
Serial_gym, s                                   0.007          0.016          0.024
Parallel_gym, s                                 0.025          0.073          0.103
relative_time parallel/single, %              205.500         97.800         84.500
relative_time parallel/serial, %              108.700         77.600         59.000
relative_time multisync/parallel, %            66.400         36.200         41.400
relative_time single_gym/single, %              7.300          6.700          6.600
relative_time serial_gym/serial, %              6.700          5.700          3.600
relative_time parallel_gym/parallel, %         22.100         33.500         25.800
vmoens commented 1 year ago

There's still room for improvement, so I'll leave it open as a reminder that we should aim for the same speed as gym even for simple envs! Glad to read that you observed the same improvements.

vmoens commented 8 months ago

There have been multiple iterations on this, so I'm closing the issue for now. If we observe a regression, we can reopen it.