ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

RLLib issue with making the program deterministic #27292

Closed utkarshp closed 2 years ago

utkarshp commented 2 years ago

What happened + What you expected to happen

I am trying to make my code deterministic. I have set every seed I could think of to a fixed value, but I still cannot make the runs reproducible. I have narrowed the issue down to this: I think Ray/RLlib is using the global random state in a non-deterministic way somewhere during every call to train(). I was able to reproduce the issue in the attached script. Note that this is not my code; it is copied from test_external_env.py, with edits to generate random numbers.

This has been particularly frustrating when debugging my own code, which crashes or gives unexpected output in some runs but runs fine when I try to debug it.

Here is the output I see from the attached code:

❯ python src/test_random.py
Starting tests
2022-07-29 18:25:04,690 INFO services.py:1338 -- View the Ray dashboard at http://127.0.0.1:8265
Starting tests
framework=tf
2022-07-29 18:25:28,200 INFO trainer.py:722 -- Your framework setting is 'tf', meaning you are using static-graph mode. Set framework='tf2' to enable eager execution with tf2.x. You may also want to then set `eager_tracing=True` in order to reach similar execution speed as with static-graph mode.
2022-07-29 18:25:28,200 INFO dqn.py:141 -- In multi-agent mode, policies will be optimized sequentially by the multi-GPU optimizer. Consider setting simple_optimizer=True if this doesn't work for you.
2022-07-29 18:25:28,200 INFO trainer.py:743 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.
2022-07-29 18:25:45,027 WARNING deprecation.py:45 -- DeprecationWarning: `SampleBatch['is_training']` has been deprecated. Use `SampleBatch.is_training` instead. This will raise an error in the future!
2022-07-29 18:25:45,463 INFO trainable.py:124 -- Trainable.setup took 17.264 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.
Iteration 0, reward 17.50877192982456, timesteps 1000, rnum 2191835438
Iteration 1, reward 13.05, timesteps 2000, rnum 3479080518
Iteration 2, reward 15.61, timesteps 3000, rnum 2426226194
Iteration 3, reward 21.28, timesteps 4000, rnum 2342326304
Iteration 4, reward 27.15, timesteps 5000, rnum 3122304301
Iteration 5, reward 36.52, timesteps 6000, rnum 910541280
Iteration 6, reward 44.63, timesteps 7000, rnum 910541280
Iteration 7, reward 53.59, timesteps 8000, rnum 3327464432
Iteration 8, reward 61.15, timesteps 9000, rnum 1644020759
Iteration 9, reward 68.07, timesteps 10000, rnum 910541280
Iteration 10, reward 75.88, timesteps 11000, rnum 910541280
Iteration 11, reward 84.43, timesteps 12000, rnum 880162936
framework=torch
2022-07-29 18:26:25,133 WARNING deprecation.py:45 -- DeprecationWarning: `convert_to_non_torch_type` has been deprecated. Use `ray/rllib/utils/numpy.py::convert_to_numpy` instead. This will raise an error in the future!
Iteration 0, reward 10.826086956521738, timesteps 1000, rnum 1581442016
Iteration 1, reward 11.99, timesteps 2000, rnum 1280829708
Iteration 2, reward 15.25, timesteps 3000, rnum 1012006804
Iteration 3, reward 22.11, timesteps 4000, rnum 3430153998
Iteration 4, reward 31.44, timesteps 5000, rnum 1664651095
Iteration 5, reward 39.48, timesteps 6000, rnum 1286226248
Iteration 6, reward 48.45, timesteps 7000, rnum 910541280
Iteration 7, reward 57.75, timesteps 8000, rnum 3327464432
Iteration 8, reward 66.63, timesteps 9000, rnum 2750021556
Iteration 9, reward 75.1, timesteps 10000, rnum 910541280
Iteration 10, reward 84.41, timesteps 11000, rnum 3327464432

As you can see, the random numbers differ between iterations. If the code were deterministic, we would get the same random number in every iteration, since the seed is reset before each train() call. Note that even a constant number of calls to random.random() inside train() would still guarantee that the rnums are all the same. This behavior therefore indicates a non-deterministic number of calls to random.random().
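To make the reasoning above concrete, here is a minimal sketch (not RLlib code; `train_stub` is a hypothetical stand-in for train()) showing that when the seed is reset before each call, the value of the subsequent randint depends only on how many draws the call consumed from the global stream:

```python
import random

def train_stub(num_internal_draws):
    # Stand-in for train(): consumes some number of draws
    # from Python's *global* random stream.
    for _ in range(num_internal_draws):
        random.random()

def rnum_after(draws):
    random.seed(1)          # mirrors random.seed(1) before each train() call
    train_stub(draws)
    return random.randint(0, 2**32)

# Same number of internal draws -> identical rnum;
# a different number of draws -> a different rnum.
print(rnum_after(3) == rnum_after(3))
print(rnum_after(3) == rnum_after(5))
```

So identical rnums across iterations only require a *constant* number of internal draws, which is exactly the property the output above shows is violated.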

Versions / Dependencies

Output of conda list:

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main
_openmp_mutex             4.5                       1_gnu
absl-py                   0.13.0           py38h06a4308_0
aiohttp                   3.7.4.post0              pypi_0    pypi
aiohttp-cors              0.7.0                    pypi_0    pypi
aioredis                  1.3.1                    pypi_0    pypi
aiosignal                 1.2.0                    pypi_0    pypi
alsa-lib                  1.2.3                h516909a_0    conda-forge
astunparse                1.6.3                    pypi_0    pypi
async-timeout             3.0.1                    pypi_0    pypi
atk-1.0                   2.36.0               h28cd5cc_0
attrs                     21.2.0             pyhd3eb1b0_0
blas                      1.0                         mkl
blessed                   1.19.1                   pypi_0    pypi
blessings                 1.7                      pypi_0    pypi
blinker                   1.4              py38h06a4308_0
bottleneck                1.3.2            py38heb32a55_1
brotlipy                  0.7.0           py38h27cfd23_1003
bzip2                     1.0.8                h7b6447c_0
c-ares                    1.17.1               h27cfd23_0
ca-certificates           2022.4.26            h06a4308_0
cachetools                4.2.2              pyhd3eb1b0_0
cairo                     1.16.0               hf32fb01_1
certifi                   2022.6.15        py38h06a4308_0
cffi                      1.14.6           py38h400218f_0
chardet                   4.0.0                    pypi_0    pypi
charset-normalizer        2.0.4              pyhd3eb1b0_0
chrpath                   0.16              h7f98852_1002    conda-forge
clang                     5.0                      pypi_0    pypi
click                     8.0.1              pyhd3eb1b0_0
cloudpickle               1.6.0              pyhd3eb1b0_0
colorful                  0.5.4                    pypi_0    pypi
coverage                  5.5              py38h27cfd23_2
cryptography              3.4.7            py38hd23ed53_0
cudatoolkit               11.3.1               h2bc3f7f_2
cycler                    0.10.0                   pypi_0    pypi
cython                    0.29                     pypi_0    pypi
dataclasses               0.8                pyh6d0b6a4_7
dbus                      1.13.18              hb2f20db_0
distlib                   0.3.4                    pypi_0    pypi
dm-tree                   0.1.6                    pypi_0    pypi
expat                     2.4.1                h2531618_2
ffmpeg                    4.2.2                h20bf706_0
filelock                  3.7.1                    pypi_0    pypi
flatbuffers               1.12                     pypi_0    pypi
font-ttf-dejavu-sans-mono 2.37                 hd3eb1b0_0
font-ttf-inconsolata      2.001                hcb22688_0
font-ttf-source-code-pro  2.030                hd3eb1b0_0
font-ttf-ubuntu           0.83                 h8b1ccd4_0
fontconfig                2.13.1               h6c09931_0
fonts-anaconda            1                    h8fa9717_0
fonts-conda-ecosystem     1                    hd3eb1b0_0
freetype                  2.10.4               h5ab3b9f_0
fribidi                   1.0.10               h7b6447c_0
frozenlist                1.3.0                    pypi_0    pypi
future                    0.18.2                   py38_1
gast                      0.4.0                    pypi_0    pypi
gdk-pixbuf                2.42.6               h04a7f16_0    conda-forge
gettext                   0.21.0               hf68c758_0
giflib                    5.2.1                h7b6447c_0
glib                      2.68.2               h36276a3_0
gmp                       6.2.1                h2531618_2
gnuplot                   5.4.1                hec6539f_2    conda-forge
gnutls                    3.6.15               he1e5248_0
gobject-introspection     1.68.0           py38h2109141_1
google-api-core           1.31.2                   pypi_0    pypi
google-auth               1.35.0                   pypi_0    pypi
google-auth-oauthlib      0.4.6                    pypi_0    pypi
google-pasta              0.2.0                    pypi_0    pypi
googleapis-common-protos  1.53.0                   pypi_0    pypi
gprof2dot                 2021.2.21                pypi_0    pypi
gpustat                   1.0.0b1                  pypi_0    pypi
gputil                    1.4.0              pyh9f0ad1d_0    conda-forge
graphite2                 1.3.14               h23475e2_0
grpcio                    1.40.0                   pypi_0    pypi
gst-plugins-base          1.18.4               hf529b03_2    conda-forge
gstreamer                 1.18.4               h76c114f_2    conda-forge
gtk2                      2.24.33              h539f30e_1    conda-forge
gym                       0.19.0           py38he5a9106_0    conda-forge
h5py                      3.1.0                    pypi_0    pypi
harfbuzz                  2.8.1                h83ec7ef_0    conda-forge
hiredis                   2.0.0                    pypi_0    pypi
icu                       68.1                 h2531618_0
idna                      3.2                pyhd3eb1b0_0
imageio                   2.9.0                    pypi_0    pypi
importlib-metadata        4.8.1            py38h06a4308_0
iniconfig                 1.1.1                    pypi_0    pypi
intel-openmp              2021.3.0          h06a4308_3350
joblib                    1.1.0              pyhd3eb1b0_0
jpeg                      9d                   h7f8727e_0
jsonschema                3.2.0                    pypi_0    pypi
keras                     2.6.0                    pypi_0    pypi
keras-preprocessing       1.1.2                    pypi_0    pypi
kiwisolver                1.3.2                    pypi_0    pypi
krb5                      1.19.2               hac12032_0
lame                      3.100                h7b6447c_0
lcms2                     2.12                 h3be6417_0
ld_impl_linux-64          2.35.1               h7274673_9
libclang                  11.1.0          default_ha53f305_1    conda-forge
libedit                   3.1.20210714         h7f8727e_0
libevent                  2.1.10               hcdb4288_3    conda-forge
libffi                    3.3                  he6710b0_2
libgcc-ng                 9.3.0               h5101ec6_17
libgd                     2.3.2                h78a0170_0    conda-forge
libgfortran-ng            7.5.0               ha8ba4b0_17
libgfortran4              7.5.0               ha8ba4b0_17
libglib                   2.68.2               h3e27bee_0    conda-forge
libgomp                   9.3.0               h5101ec6_17
libiconv                  1.16                 h516909a_0    conda-forge
libidn2                   2.3.2                h7f8727e_0
libllvm11                 11.1.0               h3826bc1_0
libogg                    1.3.5                h27cfd23_1
libopus                   1.3.1                h7b6447c_0
libpng                    1.6.37               hbc83047_0
libpq                     13.3                 hd57d9b9_0    conda-forge
libprotobuf               3.17.2               h4ff587b_1
libstdcxx-ng              9.3.0               hd4cf53a_17
libtasn1                  4.16.0               h27cfd23_0
libtiff                   4.2.0                h85742a9_0
libunistring              0.9.10               h27cfd23_0
libuuid                   1.0.3                h1bed415_2
libuv                     1.40.0               h7b6447c_0
libvorbis                 1.3.7                h7b6447c_0
libvpx                    1.7.0                h439df22_0
libwebp                   1.2.0                h89dd481_0
libwebp-base              1.2.0                h27cfd23_0
libxcb                    1.14                 h7b6447c_0
libxkbcommon              1.0.3                he3ba5ed_0    conda-forge
libxml2                   2.9.12               h72842e0_0    conda-forge
lz4                       3.1.3                    pypi_0    pypi
lz4-c                     1.9.3                h295c915_1
markdown                  3.3.4            py38h06a4308_0
matplotlib                3.4.2                    pypi_0    pypi
mkl                       2021.3.0           h06a4308_520
mkl-service               2.4.0            py38h7f8727e_0
mkl_fft                   1.3.0            py38h42c9631_2
mkl_random                1.2.2            py38h51133e4_0
msgpack                   1.0.2                    pypi_0    pypi
multidict                 5.1.0            py38h27cfd23_2
mysql-common              8.0.25               ha770c72_0    conda-forge
mysql-libs                8.0.25               h935591d_0    conda-forge
ncurses                   6.2                  he6710b0_1
nettle                    3.7.3                hbbd107a_1
networkx                  2.6.3                    pypi_0    pypi
ninja                     1.10.2               hff7bd54_1
nspr                      4.30                 h9c3ff4c_0    conda-forge
nss                       3.67                 hb5efdd6_0    conda-forge
numexpr                   2.7.3            py38h22e1b3c_1
numpy                     1.19.5                   pypi_0    pypi
numpy-base                1.20.3           py38h74d4b33_0
nvidia-ml-py3             7.352.0                  pypi_0    pypi
oauthlib                  3.1.1              pyhd3eb1b0_0
olefile                   0.46               pyhd3eb1b0_0
opencensus                0.7.13                   pypi_0    pypi
opencensus-context        0.1.2                    pypi_0    pypi
openh264                  2.1.0                hd408876_0
openssl                   1.1.1o               h7f8727e_0
opt-einsum                3.3.0                    pypi_0    pypi
packaging                 21.0                     pypi_0    pypi
pandas                    1.3.2            py38h8c16a72_0
pango                     1.48.5               hb8ff022_0    conda-forge
pcre                      8.45                 h295c915_0
pillow                    8.3.1            py38h5aabda8_0
pip                       21.2.2           py38h06a4308_0
pixman                    0.40.0               h7b6447c_0
platformdirs              2.5.2                    pypi_0    pypi
pluggy                    1.0.0                    pypi_0    pypi
prometheus-client         0.11.0                   pypi_0    pypi
protobuf                  3.17.3                   pypi_0    pypi
psutil                    5.8.0                    pypi_0    pypi
py                        1.10.0                   pypi_0    pypi
py-spy                    0.3.9                    pypi_0    pypi
pyasn1                    0.4.8              pyhd3eb1b0_0
pyasn1-modules            0.2.8                    pypi_0    pypi
pycparser                 2.20                       py_2
pyglet                    1.5.16           py38h578d9bd_0    conda-forge
pyjwt                     2.1.0            py38h06a4308_0
pyopenssl                 20.0.1             pyhd3eb1b0_1
pyparsing                 2.4.7                    pypi_0    pypi
pyrsistent                0.18.0                   pypi_0    pypi
pysocks                   1.7.1            py38h06a4308_0
pytest                    6.2.5                    pypi_0    pypi
python                    3.8.5                h7579374_1
python-dateutil           2.8.2              pyhd3eb1b0_0
python_abi                3.8                      2_cp38    conda-forge
pytorch                   1.10.2          py3.8_cuda11.3_cudnn8.2.0_0    pytorch
pytorch-mutex             1.0                        cuda    pytorch
pytz                      2021.1             pyhd3eb1b0_0
pywavelets                1.1.1                    pypi_0    pypi
pyyaml                    5.4.1                    pypi_0    pypi
qt                        5.12.9               hda022c4_4    conda-forge
ray                       1.9.2                    pypi_0    pypi
readline                  8.1                  h27cfd23_0
redis                     3.5.3                    pypi_0    pypi
requests                  2.26.0             pyhd3eb1b0_0
requests-oauthlib         1.3.0                      py_0
rsa                       4.7.2              pyhd3eb1b0_1
scikit-image              0.18.3                   pypi_0    pypi
scikit-learn              1.0.2            py38h51133e4_1
scipy                     1.7.1            py38h292c36d_2
setuptools                58.0.4           py38h06a4308_0
six                       1.15.0                   pypi_0    pypi
smart-open                5.2.1                    pypi_0    pypi
sortedcontainers          2.4.0                    pypi_0    pypi
sqlite                    3.36.0               hc218d9a_0
tabulate                  0.8.9                    pypi_0    pypi
tensorboard               2.6.0                    pypi_0    pypi
tensorboard-data-server   0.6.1                    pypi_0    pypi
tensorboard-plugin-wit    1.8.0                    pypi_0    pypi
tensorboardx              2.4                      pypi_0    pypi
tensorflow                2.6.0                    pypi_0    pypi
tensorflow-estimator      2.6.0                    pypi_0    pypi
termcolor                 1.1.0                    pypi_0    pypi
threadpoolctl             2.2.0              pyh0d69192_0
tifffile                  2021.8.30                pypi_0    pypi
tk                        8.6.10               hbc83047_0
toml                      0.10.2                   pypi_0    pypi
torchaudio                0.10.2               py38_cu113    pytorch
torchvision               0.11.3               py38_cu113    pytorch
typing-extensions         3.7.4.3                  pypi_0    pypi
typing_extensions         3.10.0.2           pyh06a4308_0
urllib3                   1.26.6             pyhd3eb1b0_1
virtualenv                20.15.0                  pypi_0    pypi
wcwidth                   0.2.5                    pypi_0    pypi
werkzeug                  2.0.1              pyhd3eb1b0_0
wheel                     0.37.0             pyhd3eb1b0_1
wrapt                     1.12.1                   pypi_0    pypi
x264                      1!157.20191217       h7b6447c_0
xdot                      1.2                      pypi_0    pypi
xorg-kbproto              1.0.7             h7f98852_1002    conda-forge
xorg-libice               1.0.10               h7f98852_0    conda-forge
xorg-libsm                1.2.2                h470a237_5    conda-forge
xorg-libx11               1.7.2                h7f98852_0    conda-forge
xorg-libxext              1.3.4                h7f98852_1    conda-forge
xorg-libxrender           0.9.10            h7f98852_1003    conda-forge
xorg-libxt                1.2.1                h7f98852_2    conda-forge
xorg-renderproto          0.11.1            h7f98852_1002    conda-forge
xorg-xextproto            7.3.0             h7f98852_1002    conda-forge
xorg-xproto               7.0.31            h27cfd23_1007
xz                        5.2.5                h7b6447c_0
yarl                      1.6.3            py38h27cfd23_0
zipp                      3.5.0              pyhd3eb1b0_0
zlib                      1.2.11               h7b6447c_3
zstd                      1.4.9                haebb681_0

CUDA Version: 11.6
OS: Linux 4.19.0-18-amd64 #1 SMP Debian 4.19.208-1 (2021-09-29) x86_64 GNU/Linux

Reproduction script

import random
import unittest

import gym
import numpy as np
import ray
import torch
from ray.rllib.agents.dqn import DQNTrainer as Dqn
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.utils.test_utils import framework_iterator
from ray.tune.registry import register_env

def make_simple_serving(multiagent, superclass):
    class SimpleServe(superclass):
        def __init__(self, env):
            superclass.__init__(self, env.action_space, env.observation_space)
            self.env = env

        def run(self):
            eid = self.start_episode()
            obs = self.env.reset()
            while True:
                action = self.get_action(eid, obs)
                obs, reward, done, info = self.env.step(action)
                if multiagent:
                    self.log_returns(eid, reward)
                else:
                    self.log_returns(eid, reward, info=info)
                if done:
                    print("Ended episode", eid)
                    self.end_episode(eid, obs)
                    obs = self.env.reset()
                    eid = self.start_episode()

    return SimpleServe

# generate & register SimpleServing class
SimpleServing = make_simple_serving(False, ExternalEnv)

class PartOffPolicyServing(ExternalEnv):
    def __init__(self, env, off_pol_frac):
        ExternalEnv.__init__(self, env.action_space, env.observation_space)
        self.env = env
        self.off_pol_frac = off_pol_frac
        self.rs = np.random.RandomState(seed=1)

    def run(self):
        eid = self.start_episode()
        obs = self.env.reset()
        while True:
            if self.rs.random() < self.off_pol_frac:
                action = self.env.action_space.sample()
                self.log_action(eid, obs, action)
            else:
                action = self.get_action(eid, obs)
            obs, reward, done, info = self.env.step(action)
            self.log_returns(eid, reward, info=info)
            if done:
                self.end_episode(eid, obs)
                obs = self.env.reset()
                eid = self.start_episode()

class SimpleOffPolicyServing(ExternalEnv):
    def __init__(self, env, fixed_action):
        ExternalEnv.__init__(self, env.action_space, env.observation_space)
        self.env = env
        self.fixed_action = fixed_action

    def run(self):
        eid = self.start_episode()
        obs = self.env.reset()
        while True:
            action = self.fixed_action
            self.log_action(eid, obs, action)
            obs, reward, done, info = self.env.step(action)
            self.log_returns(eid, reward, info=info)
            if done:
                self.end_episode(eid, obs)
                obs = self.env.reset()
                eid = self.start_episode()

class TestExternalEnv(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init(ignore_reinit_error=True)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_train_cartpole_off_policy(self):
        print("Starting tests")
        register_env(
            "test3",
            lambda _: PartOffPolicyServing(gym.make("CartPole-v0"), off_pol_frac=0.2),
        )
        config = {
            "num_workers": 0,
            "exploration_config": {"epsilon_timesteps": 100},
            "seed": 1
        }
        torch.manual_seed(1)
        np.random.seed(1)
        for _ in framework_iterator(config, frameworks=("tf", "torch")):
            dqn = Dqn(env="test3", config=config)
            reached = False
            for i in range(50):
                random.seed(1)
                result = dqn.train()
                r_int = random.randint(0, 2**32)   # I would expect this to be the same integer in every run if everything is deterministic
                print(
                    "Iteration {}, reward {}, timesteps {}, rnum {}".format(
                        i, result["episode_reward_mean"], result["timesteps_total"], r_int
                    )
                )
                if result["episode_reward_mean"] >= 80:
                    reached = True
                    break
            if not reached:
                raise Exception("failed to improve reward")

if __name__ == "__main__":
    TestExternalEnv.setUpClass()
    t = TestExternalEnv()
    t.test_train_cartpole_off_policy()
    TestExternalEnv.tearDownClass()

Issue Severity

High: It blocks me from completing my task.

kouroshHakha commented 2 years ago

We are aware of the non-determinism issue with RLlib, and it is on our todo list to figure out. Thanks for pointing it out.

utkarshp commented 2 years ago

Thanks! Do you know whether it is just a matter of the code using the global random state rather than a separate random state in every class/file? I can try this out myself on my end (and I am happy to contribute a fix if allowed), but I suspect there is something more going on.
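For reference, the distinction being asked about can be sketched in a few lines: a class-local `np.random.RandomState` (as `PartOffPolicyServing` already uses in the reproduction script) is isolated from whatever other code does to Python's global stream:

```python
import random

import numpy as np

# A private, class-local random state, seeded explicitly.
rs = np.random.RandomState(seed=1)
first = rs.random_sample()

# Meanwhile, other code perturbs Python's *global* stream...
random.seed(1)
random.random()
random.random()

# ...but a freshly seeded private state still reproduces the same draw:
rs2 = np.random.RandomState(seed=1)
print(rs2.random_sample() == first)  # True
```

If every component drew from its own seeded state like this, extra draws elsewhere could not desynchronize it; code that shares the global state has no such guarantee.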

kouroshHakha commented 2 years ago

Hi @utkarshp, I actually investigated the problem a little bit, and it turns out reproducibility should no longer be an issue as long as you use tune (across all the different resource specifications, i.e., gpu, cpu, num_worker > 0, etc.). RLlib should support reproducible experimentation as long as the environment is deterministic. You can check out https://github.com/ray-project/ray/blob/master/rllib/examples/deterministic_training.py to see the example.

I have also tested DQN on a small deterministic cartpole example and it works fine. The code is shared below:

import unittest

import gym
import ray
from ray.tune.registry import register_env
from ray import tune

from ray.rllib.algorithms.dqn import DQNConfig, DQN

class DeterministicCartPole(gym.Env):

    def __init__(self, seed=0):
        self.env = gym.make("CartPole-v0")
        self.env.seed(seed)
        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space

    def reset(self):
        return self.env.reset()

    def step(self, action):
        return self.env.step(action)

seed = 0
print(f"Starting tests with seed = {seed}")
register_env(
    "deterministic_env",
    lambda _: DeterministicCartPole(seed=seed),
)
config = (
    DQNConfig()
    .environment(env="deterministic_env")
    .resources(num_gpus=1)
    .debugging(seed=seed)
    .rollouts(num_rollout_workers=0)
    .framework("torch")
)

tune.run(
    DQN,
    name="DQN_DETERMINISTIC_CARTPOLE",
    config=config.to_dict(),
    stop={"timesteps_total": 1e4},
)

utkarshp commented 2 years ago

This is amazing! I just tried using tune in my example and things seem to be deterministic! Just out of curiosity, what is it about using tune that makes things deterministic like this? Does it somehow force RLlib to use some other random state? Thanks a lot for your help @kouroshHakha. I am not sure whether I should close this issue, so I am leaving it open for now.

For anyone who finds this issue later and, like me, hasn't used tune before: I made the following changes after the definition of the SimpleOffPolicyServing class in my example to get a similar run:

from typing import Dict, List

from ray.tune import Callback  # Trial is only referenced in string annotations

class MyCallback(Callback):
    def on_trial_result(self, iteration: int, trials: List["Trial"],
                        trial: "Trial", result: Dict, **info):
        r_int = random.randint(0, 2 ** 32)
        print(
            "Iteration {}, reward {}, timesteps {}, rnum {}".format(
                iteration, result["episode_reward_mean"], result["timesteps_total"], r_int
            )
        )
        random.seed(1)

def stopper(_, result):
    return result["episode_reward_mean"] >= 80 or result["training_iteration"] >= 50

class TestExternalEnv(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init(ignore_reinit_error=True)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_train_cartpole_off_policy(self):
        print("Starting tests")
        register_env(
            "test3",
            lambda _: PartOffPolicyServing(gym.make("CartPole-v0"), off_pol_frac=0.2),
        )
        config = {
            "num_workers": 0,
            "exploration_config": {"epsilon_timesteps": 100},
            "env": "test3",
            "seed": 1
        }
        torch.manual_seed(1)
        np.random.seed(1)
        for _ in framework_iterator(config, frameworks=("tf", "torch")):
            # dqn = Dqn(env="test3", config=config)
            tune.run("DQN", config=config, callbacks=[MyCallback()], stop=stopper)
            # if result["episode_reward_mean"] < 80:
            #     raise Exception("failed to improve reward")

The code after this is unchanged. I see that the generated random integers are always the same. There are some differences in the generated output, namely that results are printed every 3 iterations and that a lot more metrics are printed. I suppose I need to tune (lol) the arguments a bit more to get output that exactly matches the original.

kouroshHakha commented 2 years ago

So I have tried this with .train() as well, and it is still reproducible. Here is the exact code:

import unittest

import gym
import ray
from ray.tune.registry import register_env
from ray import tune

from ray.rllib.algorithms.dqn import DQNConfig, DQN

class DeterministicCartPole(gym.Env):

    def __init__(self, seed=0):
        self.env = gym.make("CartPole-v0")
        self.env.seed(seed)
        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space

    def reset(self):
        return self.env.reset()

    def step(self, action):
        return self.env.step(action)

seed = 0
print(f"Starting tests with seed = {seed}")
register_env(
    "deterministic_env",
    lambda _: DeterministicCartPole(seed=seed),
)
config = (
    DQNConfig()
    .environment(env="deterministic_env")
    .resources(num_gpus=1)
    .debugging(seed=seed)
    .rollouts(num_rollout_workers=8) # this for me has caused repro issues
    .reporting(min_time_s_per_iteration=0) # This line is very important
    .framework("torch")
)

# train() call 
algo = config.build()
for i in range(3):
    print(f'//// iteration {i}')
    results = algo.train()
print(results['episode_reward_mean'])

# tune.run call
tune.run(
    DQN,
    name="DQN_DET_CARTPOLE",
    config=config.to_dict(),
    stop={"training_iteration": 3},
)

In the above code both methods should produce the same episode_reward_mean after three iterations. Your questions actually brought up some good points that I want to clarify here:

If you care about reproducibility, you have to make sure that no stopping condition is based on wall-clock time. For example, the above code snippet would not have worked if `min_time_s_per_iteration` had been left at its default value of 1 (the default is set in SimpleQ's config object, which DQN inherits from). With that setting, the algorithm waits at least one second per iteration before moving on, even if `min_train_timesteps_per_iteration` or `min_sample_timesteps_per_iteration` has already been reached. That extra wall-clock wait perturbs the random state at some point during training, so you may still see differences between runs.
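The effect described above can be illustrated with a toy sketch (not RLlib code; `run_iteration` is a hypothetical stand-in, and the simulated `step_time_s` stands in for real elapsed time, which varies with machine load between runs): a time-based stopping condition changes how many steps, and hence how many RNG draws, an iteration consumes.

```python
import random

def run_iteration(min_time_s, step_time_s, min_steps):
    """Stand-in for one training iteration: keeps sampling (and drawing
    from the RNG) until BOTH the step budget and the time budget are met."""
    elapsed, steps = 0.0, 0
    while steps < min_steps or elapsed < min_time_s:
        random.random()          # each sampled step consumes the RNG
        steps += 1
        elapsed += step_time_s   # simulated clock; a real run uses wall time

random.seed(1)
run_iteration(min_time_s=0.0, step_time_s=0.01, min_steps=100)
a = random.randint(0, 2**32)     # stopped purely on the step budget

random.seed(1)
run_iteration(min_time_s=2.0, step_time_s=0.01, min_steps=100)
b = random.randint(0, 2**32)     # the time budget forced extra draws

print(a == b)  # False: the random states have diverged
```

In a real run the per-step time fluctuates, so a wall-clock condition makes the draw count, and therefore the random state, differ from run to run even with identical seeds, which is why step-count-based stopping conditions are needed for reproducibility.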