ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLlib] Off Policy Estimation Issues with CQL and MARWIL #37818

Open · jayanthnair opened this issue 11 months ago

jayanthnair commented 11 months ago

### What happened + What you expected to happen

I am trying to train an offline RL agent on data from a custom environment. I generated the data with a PPO agent and then split it into training and evaluation datasets so I can use off-policy estimation (OPE) and keep the custom environment out of evaluation entirely.

Below is a sample experience from the training dataset:

```json
{
  "type": "SampleBatch",
  "obs": [[11946.630254387865, 9.449011087417592, 97.11496531963348]],
  "new_obs": [[11948.201805353174, 9.507054090499889, 97.127750515938]],
  "actions": [[0.8743069767951961, 0.15215331315994202]],
  "rewards": [2.975874900817871],
  "terminateds": [false],
  "truncateds": [false],
  "agent_index": [0],
  "eps_id": [896748521681811842],
  "unroll_id": [91],
  "action_prob": [0.14308243989944402],
  "prev_actions": [[-0.025052666664123, 0.046627640724182004]],
  "prev_rewards": [2.977622509002685],
  "action_logp": [-1.944334268569946]
}
```

And below is a sample batch from the evaluation dataset (edited for space; I can provide the full files if needed):

```
{"type": "SampleBatch",
 "t": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ....],
 "eps_id": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ....],
 "agent_index": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...],
 "obs": [[11729.91782426835, 9.234923124313388, 104.78036105632799], [11617.970019578941, 9.234923124313388, 103.78037393093123], [11646.095886826515, 9.172278642654497, 104.0316075086595], [11576.776653528224, 9.16328430175789, 103.41238975524925], .... ],
 "action_prob": [0.0056711747311050005, 0.168887481093406, 0.09625877439975701, 0.06684957444667801, .... ],
 "rewards": [2.917489528656006, 2.914583444595337, 2.925419330596924, 2.936786651611328, 2.9305403232574463, ... ],
 "new_obs": [[11617.970019578941, 9.234923124313388, 103.78037393093123], [11646.095886826515, 9.172278642654497, 104.0316075086595], [11576.776653528224, 9.16328430175789, 103.41238975524925], ....],
 "terminateds": [false, false, false, false, .. true],
 "truncateds": [false, false, false, false, ... ,false,],
 "unroll_id": [0, 0, 0, 0, ...],
 "prev_actions": [[0.615730047225952, 0.08490172028541501], [-0.217299610376358, -0.17831170558929402], [-0.40819126367568903, 0.31043612957000705], [0.185333654284477, -0.19886156916618303], [0.400110274553298, 0.9934983253479001], [0.22803078591823503, -0.30718606710433904], ...],
 "prev_rewards": [2.912662744522094, 2.907510757446289, 2.910818099975586, 2.904420614242553, 2.9175233840942383, 2.9095067977905273, 2.917701244354248, 2.897789478302002, 2.890471935272217, 2.900053024291992, 2.90383243560791, ...],
 "action_logp": [-1.320018529891967, -1.220567345619201, -1.022270679473877, -1.143208742141723, ...]}
```
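Importance-sampling estimators such as WIS weight each logged episode as a whole, so the split into training and evaluation data is safest when it follows episode boundaries (`eps_id`) rather than individual timesteps. The following is a minimal sketch. It assumes the logged data has been flattened into one dict per timestep with an `"eps_id"` key, which differs from the `SampleBatch` objects shown above. `split_by_episode` is a hypothetical helper, not an RLlib API.

```python
import random
from collections import defaultdict


def split_by_episode(rows, eval_fraction=0.2, seed=0):
    """Hypothetical helper: split per-timestep rows into train/eval sets
    so that every episode (eps_id) lands entirely in one of the two sets."""
    episodes = defaultdict(list)
    for row in rows:
        episodes[row["eps_id"]].append(row)

    # Shuffle episode IDs reproducibly, then hold out a fraction of them.
    ids = sorted(episodes)
    random.Random(seed).shuffle(ids)
    n_eval = max(1, round(eval_fraction * len(ids)))
    eval_ids = set(ids[:n_eval])

    train = [r for i in ids if i not in eval_ids for r in episodes[i]]
    evaluation = [r for i in ids if i in eval_ids for r in episodes[i]]
    return train, evaluation


# Usage: 10 toy episodes of 3 steps each, so 2 whole episodes are held out.
rows = [{"eps_id": e, "t": t} for e in range(10) for t in range(3)]
train, evaluation = split_by_episode(rows, eval_fraction=0.2)
```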

When I try to run CQL with the following code:

```python
from ray.rllib.algorithms import cql
from ray.rllib.offline.estimators import WeightedImportanceSampling
from ray.rllib.utils.framework import try_import_torch
import gymnasium.spaces as spaces
import numpy as np

torch, _ = try_import_torch()

# Bounds of the (unnormalized) observations and actions of the custom environment.
observation_space = spaces.Box(
    low=np.array([0, 0, 0]),
    high=np.array([30000, 200, 500]),
    shape=(3,),
    dtype=np.float32,
)
action_space = spaces.Box(
    low=np.array([0, 0]),
    high=np.array([10, 200]),
    shape=(2,),
    dtype=np.float32,
)

if __name__ == "__main__":
    config2 = (
        cql.CQLConfig()
        .training(train_batch_size=512)
        .debugging(log_level="INFO")
        .framework(framework="torch")
        .rollouts(num_rollout_workers=0, num_envs_per_worker=20)
        .reporting(min_train_timesteps_per_iteration=1000)
        # No live environment: the spaces are given explicitly.
        .environment(env=None, observation_space=observation_space, action_space=action_space)
        # Train purely from the logged (PPO-generated) experiences.
        .offline_data(
            input_="test_training_data2_unsquashed.json",
            actions_in_input_normalized=True,
        )
        .evaluation(
            evaluation_num_workers=1,
            evaluation_interval=10,
            evaluation_duration=5,
            evaluation_duration_unit="episodes",
            evaluation_parallel_to_training=False,
            # Evaluate on held-out offline data instead of the environment.
            evaluation_config={"input": "test_evaluation_data2_unsquashed.json", "explore": False},
            off_policy_estimation_methods={
                "wis": {"type": WeightedImportanceSampling},
            },
        )
    )

    algo2 = config2.build()
    for i in range(100):
        algo2.train()

    # Get policy and model.
    cql_policy = algo2.get_policy()
    cql_model = cql_policy.model
```

I get NaN values for v_target during evaluation. Stepping through the code, the new_prob value computed by the off-policy estimator appears to be 0, and I cannot figure out why this happens with CQL. With MARWIL, new_prob can also become very small but never exactly zero, so v_target still gets a value.
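An all-zero new_prob produces NaN because weighted importance sampling divides a weighted sum of returns by the sum of the weights. If every weight is 0, the estimate is 0/0. The following is a simplified, trajectory-wise WIS sketch in NumPy. RLlib's `WeightedImportanceSampling` differs in detail, so treat this as an illustration of the failure mode, not the library's code. `wis_estimate` is a hypothetical helper, and episodes are assumed to have equal length. A long product of small per-step ratios can also underflow to exactly 0.0 in float64, which gives the same symptom.

```python
import numpy as np


def wis_estimate(new_probs, old_probs, rewards, gamma=0.99):
    """Simplified trajectory-wise weighted importance sampling (hypothetical helper).

    All inputs have shape (num_episodes, T):
      new_probs -- probability (density) of each logged action under the target policy
      old_probs -- logged action_prob of the behavior policy
      rewards   -- logged rewards
    """
    new_probs, old_probs, rewards = (
        np.asarray(x, dtype=np.float64) for x in (new_probs, old_probs, rewards)
    )
    discounts = gamma ** np.arange(rewards.shape[1])
    returns = (rewards * discounts).sum(axis=1)       # discounted return per episode
    weights = np.prod(new_probs / old_probs, axis=1)  # one importance weight per episode
    with np.errstate(invalid="ignore", divide="ignore"):
        # If every weight is 0, this is 0/0, i.e. NaN.
        return float((weights * returns).sum() / weights.sum())


print(wis_estimate([[0.5, 0.5]], [[0.5, 0.5]], [[1.0, 1.0]]))  # approximately 1.99 (= 1 + 0.99)
print(wis_estimate([[0.0, 0.1]], [[0.2, 0.2]], [[1.0, 1.0]]))  # nan
```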

Note: this is a continuous-action problem with a 2-dimensional action. I have not normalized the states in the training/evaluation data, but the actions are normalized by Ray (the data was collected while training a PPO policy).
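Setting `actions_in_input_normalized=True` tells RLlib that the logged actions are already in the normalized [-1, 1] space, so it can help to check that mapping by hand. The following is a minimal sketch of a linear map between [-1, 1] and the `Box` bounds from the reproduction script (low `[0, 0]`, high `[10, 200]`). The linear form is an assumption about how normalized actions relate to the environment's action space. `squash` and `unsquash` are hypothetical helpers, not RLlib functions.

```python
import numpy as np

# Bounds copied from the action_space in the reproduction script.
LOW = np.array([0.0, 0.0])
HIGH = np.array([10.0, 200.0])


def unsquash(a_norm, low=LOW, high=HIGH):
    """Map a normalized action in [-1, 1] to [low, high] (assumed linear mapping)."""
    a_norm = np.clip(np.asarray(a_norm, dtype=np.float64), -1.0, 1.0)
    return low + (a_norm + 1.0) * (high - low) / 2.0


def squash(a_env, low=LOW, high=HIGH):
    """Inverse of `unsquash`: map an action in [low, high] back to [-1, 1]."""
    a_env = np.asarray(a_env, dtype=np.float64)
    return 2.0 * (a_env - low) / (high - low) - 1.0


# The logged training action from the sample above, in environment units (roughly [9.37, 115.22]).
env_action = unsquash([0.8743069767951961, 0.15215331315994202])
```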

Additionally, for datasets where the action_prob values do not exist, are there any guidelines on how to generate them?
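When action_prob is missing, one common approach is to fit a behavior-cloning model to the logged data and evaluate its likelihood at each logged action. The sketch below covers only that last step. It assumes, hypothetically, that the fitted model outputs a diagonal Gaussian with a per-dimension mean and log standard deviation. For continuous actions these values are densities, not probabilities, so they can exceed 1. In the training sample above, action_prob (0.1431) matches exp(action_logp) = exp(-1.9443).

```python
import math


def diag_gaussian_logp(action, mean, log_std):
    """Log-density of `action` under N(mean, diag(exp(log_std))**2).

    All arguments are same-length sequences, one entry per action dimension.
    """
    logp = 0.0
    for a, mu, ls in zip(action, mean, log_std):
        z = (a - mu) / math.exp(ls)
        logp += -0.5 * z * z - ls - 0.5 * math.log(2.0 * math.pi)
    return logp


# Hypothetical behavior-model output for one logged 2-D action.
logp = diag_gaussian_logp([0.87, 0.15], mean=[0.8, 0.1], log_std=[-0.5, -0.5])
action_prob = math.exp(logp)  # value to store alongside action_logp
```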

### Versions / Dependencies


absl-py==1.4.0
adal==1.2.7
aiohttp==3.8.4
aiohttp-cors==0.7.0
aiorwlock==1.3.0
aiosignal==1.3.1
ansicon==1.89.0
anyio==3.7.0
argcomplete==2.1.2
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
asttokens==2.2.1
async-timeout==4.0.2        
attrs==23.1.0
azure-common==1.1.28
azure-core==1.27.1
azure-graphrbac==0.61.1
azure-identity==1.13.0
azure-mgmt-authorization==3.0.0
azure-mgmt-containerregistry==10.1.0
azure-mgmt-core==1.4.0
azure-mgmt-keyvault==10.2.2
azure-mgmt-resource==21.2.1
azure-mgmt-storage==20.1.0
azure-storage-blob==12.13.0
azureml-core==1.48.0
azureml-dataprep==4.8.6
azureml-dataprep-native==38.0.0
azureml-dataprep-rslex==2.15.2
azureml-dataset-runtime==1.48.0
azureml-defaults==1.48.0
azureml-inference-server-http==0.7.7
azureml-mlflow==1.52.0
backcall==0.2.0
backports.tempfile==1.0
backports.weakref==1.0.post1
bcrypt==4.0.1
beautifulsoup4==4.12.2
bleach==6.0.0
blessed==1.20.0
blinker==1.6.2
cachetools==5.3.1
certifi @ file:///C:/b/abs_85o_6fm0se/croot/certifi_1671487778835/work/certifi
cffi==1.15.1
charset-normalizer==3.1.0
click==8.1.3
cloudpickle==2.2.1
colorama==0.4.6
colorful==0.5.5
comm==0.1.3
contextlib2==21.6.0
contourpy==1.1.0
cryptography==38.0.4
cycler==0.11.0
databricks-cli==0.17.7
debugpy==1.6.7
decorator==5.1.1
defusedxml==0.7.1
distlib==0.3.6
distro==1.8.0
dm-tree==0.1.8
docker==6.1.3
dotnetcore2==3.1.23
entrypoints==0.4
exceptiongroup==1.1.2
executing==1.2.0
fastapi==0.99.1
fastjsonschema==2.17.1
filelock==3.12.2
Flask==2.3.2
Flask-Cors==3.0.10
fonttools==4.40.0
fqdn==1.5.1
frozenlist==1.3.3
fsspec==2023.6.0
fusepy==3.0.1
gast==0.5.4
gitdb==4.0.10
GitPython==3.1.31
google-api-core==2.11.1
google-auth==2.21.0
google-auth-oauthlib==1.0.0
googleapis-common-protos==1.59.1
gpustat==1.1
grpcio==1.56.0
Gymnasium==0.26.3
gymnasium-notices==0.0.1
h11==0.14.0
humanfriendly==10.0
idna==3.4
imageio==2.31.1
importlib-metadata==6.7.0
inference-schema==1.5.1
ipykernel==6.24.0
ipython==8.14.0
ipython-genutils==0.2.0
ipywidgets==8.0.6
isodate==0.6.1
isoduration==20.11.0
itsdangerous==2.1.2
jedi==0.18.2
jeepney==0.8.0
Jinja2==3.1.2
jinxed==1.2.0
jmespath==1.0.1
joblib==1.3.1
jsonpickle==2.2.0
jsonpointer==2.4
jsonschema==4.17.3
jupyter==1.0.0
jupyter-console==6.6.3
jupyter-events==0.6.3
jupyter_client==8.3.0
jupyter_core==5.3.1
jupyter_server==2.7.0
jupyter_server_terminals==0.4.4
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.7
kiwisolver==1.4.4
knack==0.10.1
lazy_loader==0.3
lz4==4.3.2
Markdown==3.4.3
markdown-it-py==3.0.0
MarkupSafe==2.1.3
matplotlib==3.7.2
matplotlib-inline==0.1.6
mdurl==0.1.2
mistune==3.0.1
mlflow-skinny==2.4.1
mpmath==1.3.0
msal==1.22.0
msal-extensions==1.0.0
msgpack==1.0.5
msrest==0.7.1
msrestazure==0.6.4
multidict==6.0.4
nbclassic==1.0.0
nbclient==0.8.0
nbconvert==7.6.0
nbformat==5.9.0
ndg-httpsclient==0.5.1
nest-asyncio==1.5.6
networkx==3.1
notebook==6.5.4
notebook_shim==0.2.3
numpy==1.24.3
nvidia-ml-py==11.525.131
oauthlib==3.2.2
opencensus==0.11.2
opencensus-context==0.1.3
opencensus-ext-azure==1.1.9
overrides==7.3.1
packaging==21.3
pandas==2.0.2
pandocfilters==1.5.0
paramiko==2.12.0
parso==0.8.3
pathspec==0.11.1
pickleshare==0.7.5
Pillow==10.0.0
pkginfo==1.9.6
platformdirs==3.8.0
portalocker==2.7.0
prometheus-client==0.17.0
prompt-toolkit==3.0.38
protobuf==4.23.3
psutil==5.8.0
pure-eval==0.2.2
py-spy==0.3.14
pyarrow==6.0.1
pyasn1==0.5.0
pyasn1-modules==0.3.0
pycparser==2.21
pydantic==1.10.10
Pygments==2.15.1
PyJWT==2.7.0
PyNaCl==1.5.0
pyOpenSSL==22.1.0
pyparsing==3.0.9
pyreadline3==3.4.1
pyrsistent==0.19.3
PySocks==1.7.1
python-dateutil==2.8.2
python-json-logger==2.0.7
pytz==2023.3
PyWavelets==1.4.1
pywin32==306
pywinpty==2.0.10
PyYAML==6.0
pyzmq==25.1.0
qtconsole==5.4.3
QtPy==2.3.1
ray @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp310-cp310-win_amd64.whl
ray-on-aml==0.2.4
requests==2.31.0
requests-oauthlib==1.3.1
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.4.2
rsa==4.9
scikit-image==0.21.0
scikit-learn==1.3.0
scipy==1.10.1
SecretStorage==3.3.3
Send2Trash==1.8.2
six==1.16.0
smart-open==6.3.0
smmap==5.0.0
sniffio==1.3.0
soupsieve==2.4.1
sqlparse==0.4.4
stack-data==0.6.2
starlette==0.27.0
sympy==1.12
tabulate==0.9.0
tensorboard==2.13.0
tensorboard-data-server==0.7.1
tensorboardX==2.6.1
tensorflow-probability==0.20.1
terminado==0.17.1
threadpoolctl==3.1.0
tifffile==2023.4.12
tinycss2==1.2.1
torch==2.0.1
tornado==6.3.2
traitlets==5.9.0
typer==0.9.0
typing_extensions==4.7.1
tzdata==2023.3
uri-template==1.3.0
urllib3==1.26.16
uvicorn==0.22.0
virtualenv==20.21.0
waitress==2.1.2
wcwidth==0.2.6
webcolors==1.13
webencodings==0.5.1
websocket-client==1.6.1
Werkzeug==2.3.6
widgetsnbextension==4.0.7
wincertstore==0.2
wrapt==1.12.1
yarl==1.9.2
zipp==3.15.0

### Reproduction script

Provided above. Please let me know if you need the data files to reproduce this on your end.

### Issue Severity

High: It blocks me from completing my task.

jayanthnair commented 11 months ago

Update: When I normalize the observations, the NaN values go away. However, I now get extremely large values for episode_p and w_t inside the WIS estimator, shown below (obtained by adding print statements to my local Ray installation):

```
episode_p 8513.5696510565 w_t 2576.458318062613 reward 2.900228977203369 gamma 0.99 t 0 v_target 9.583427462548682
cumm_ips 192320911200.69736 timestep count 9.0
episode_p 14606767.416885767 w_t 21368990133.410816 reward 2.913301229476928 gamma 0.99 t 1 v_target 9.585398934952224
cumm_ips 7836036132843.913 timestep count 9.0
episode_p 45776432.43120361 w_t 870670681427.1014 reward 2.916510343551635 gamma 0.99 t 2 v_target 9.585549222131336
cumm_ips 1963095524560.4666 timestep count 9.0
episode_p 14223718.147090815 w_t 218121724951.16296 reward 2.920472860336303 gamma 0.99 t 3 v_target 9.585734009794283
cumm_ips 2.4059432465055018e+17 timestep count 9.0
episode_p 6169143.93639413 w_t 2.673270273895002e+16 reward 2.925758600234985 gamma 0.99 t 4 v_target 9.58573401044286
cumm_ips 2.377338644852594e+21 timestep count 9.0
episode_p 1507430783.8379612 w_t 2.641487383169549e+20 reward 2.914356231689453 gamma 0.99 t 5 v_target 9.585734010458676
cumm_ips 7.244313115403152e+20 timestep count 9.0
episode_p 1519812991.8283892 w_t 8.049236794892391e+19 reward 2.910939216613769 gamma 0.99 t 6 v_target 9.585734010510423
cumm_ips 4.466484846552509e+25 timestep count 9.0
episode_p 855726107297434.5 w_t 4.962760940613899e+24 reward 2.902001142501831 gamma 0.99 t 7 v_target 9.58573401097682
cumm_ips 5.8304359730186284e+32 timestep count 9.0
episode_p 189317414096918.97 w_t 6.47826219224292e+31 reward 2.901033878326416 gamma 0.99 t 8 v_target 9.58573401097682
cumm_ips 1.2079629429002098e+33 timestep count 9.0
episode_p 4.2566922684931155e+17 w_t 1.3421810476668997e+32 reward 2.8892805576324463 gamma 0.99 t 9 v_target 9.585734010976829
cumm_ips 2.0367400388900602e+37 timestep count 9.0
episode_p 9.346790091630546e+21 w_t 2.2630444876556225e+36 reward 2.877197504043579 gamma 0.99 t 10 v_target 9.58573401097684
cumm_ips 1.7342289655544342e+40 timestep count 9.0
episode_p 5.337144272736485e+25 w_t 1.92692107283826e+39 reward 2.864840984344482 gamma 0.99 t 11 v_target 9.58573401097691
cumm_ips 3.558234390228094e+40 timestep count 9.0
episode_p 2.4553697531815067e+25 w_t 3.953593766920104e+39 reward 2.8657515048980713 gamma 0.99 t 12 v_target 9.585734010976926
cumm_ips 2.7690568374387527e+44 timestep count 9.0
episode_p 1.4400965872859918e+26 w_t 3.076729819376392e+43 reward 2.849908590316772 gamma 0.99 t 13 v_target 9.585734010976926
cumm_ips 1.0188884136798612e+44 timestep count 9.0
episode_p 3.184339593046544e+26 w_t 1.132098237422068e+43 reward 2.864427328109741 gamma 0.99 t 14 v_target 9.585734010976926
cumm_ips 5.079109213010439e+47 timestep count 9.0
episode_p 2.0185371094092948e+30 w_t 5.64345468112271e+46 reward 2.882045269012451 gamma 0.99 t 15 v_target 9.585734010976926
cumm_ips 9.973615542662502e+46 timestep count 9.0
episode_p 8.113431364817752e+31 w_t 1.1081795047402781e+46 reward 2.88168716430664 gamma 0.99 t 16 v_target 9.585734010976944
cumm_ips 8.788924198556167e+51 timestep count 9.0
episode_p 4.662735403700911e+37 w_t 9.765471331729074e+50 reward 2.892638921737671 gamma 0.99 t 17 v_target 9.585734010977061
cumm_ips 2.5113676733933145e+51 timestep count 9.0
episode_p 1.570661161721349e+37 w_t 2.7904085259925717e+50 reward 2.899038314819336 gamma 0.99 t 18 v_target 9.585734010977198
cumm_ips 1.754255168578099e+57 timestep count 9.0
episode_p 4.227832085650053e+37 w_t 1.9491724095312212e+56 reward 2.899766683578491 gamma 0.99 t 19 v_target 9.585734010977198
cumm_ips 7.968121809024278e+60 timestep count 9.0
episode_p 3.1721174152117914e+37 w_t 8.853468676693641e+59 reward 2.8906691074371342 gamma 0.99 t 20 v_target 9.585734010977198
cumm_ips 2.3265677109506663e+69 timestep count 9.0
episode_p 3.858944420693612e+37 w_t 2.5850752343896293e+68 reward 2.879977226257324 gamma 0.99 t 21 v_target 9.585734010977198
cumm_ips 3.958864382045674e+78 timestep count 9.0
episode_p 1.1238869018284184e+40 w_t 4.3987382022729713e+77 reward 2.864849805831909 gamma 0.99 t 22 v_target 9.585734010977198
cumm_ips 4.664291173049579e+84 timestep count 9.0
episode_p 1.0947887531317158e+44 w_t 5.182545747832866e+83 reward 2.848375082015991 gamma 0.99 t 23 v_target 9.585734010977198
cumm_ips 1.3288531487849032e+94 timestep count 9.0
episode_p 2.1020298877355823e+43 w_t 1.4765034986498925e+93 reward 2.849180221557617 gamma 0.99 t 24 v_target 9.585734010977198
cumm_ips 2.881890477248469e+98 timestep count 9.0
episode_p 6.501886868012352e+42 w_t 3.2021005302760767e+97 reward 2.860762357711792 gamma 0.99 t 25 v_target 9.585734010977198
cumm_ips 3.623839328524335e+97 timestep count 9.0
episode_p 5.582186058899778e+45 w_t 4.026488142804817e+96 reward 2.849238872528076 gamma 0.99 t 26 v_target 9.585734010977198
cumm_ips 9.924114462230434e+101 timestep count 9.0
episode_p 3.052666642852503e+46 w_t 1.1026793846922706e+101 reward 2.835889339447021 gamma 0.99 t 27 v_target 9.585734010977198
cumm_ips 1.1134983953410481e+101 timestep count 9.0
episode_p 1.5953338460255271e+46 w_t 1.2372204392678312e+100 reward 2.849604606628418 gamma 0.99 t 28 v_target 9.585734010977198
cumm_ips 3.7565678852220665e+111 timestep count 9.0
episode_p 1.2442496382557476e+48 w_t 4.173964316913407e+110 reward 2.838562250137329 gamma 0.99 t 29 v_target 9.585734010977198
cumm_ips 1.0502064536917993e+111 timestep count 9.0
episode_p 3.574771502950882e+48 w_t 1.1668960596575548e+110 reward 2.846707820892334 gamma 0.99 t 30 v_target 9.585734010977198
```

These values come close to overflowing, and for longer episodes in the data they do overflow. Are there general guidelines for using OPE with CQL/MARWIL?
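The growth in the log above is what a running product of per-step ratios looks like when many ratios exceed 1. This is easy to hit with continuous actions, because both action_prob and new_prob are densities. One common mitigation is to accumulate log-ratios and normalize in log space, so that only bounded, self-normalized weights are ever exponentiated. This is sketched here as an assumption, not as what RLlib's estimator does. `normalized_cumulative_weights` is a hypothetical helper, and it assumes N equal-length episodes stored as (N, T) arrays. Other standard options are clipping the weights or using a model-based estimator such as RLlib's `DirectMethod`.

```python
import numpy as np


def normalized_cumulative_weights(log_new, log_old):
    """Per-decision self-normalized importance weights, computed in log space.

    log_new, log_old: arrays of shape (N, T) with log-probabilities (log-densities
    for continuous actions) of the logged actions under the target and behavior
    policies. Returns w[i, t] = p[i, t] / mean_i(p[i, t]), where p[i, t] is the
    product of the ratios up to and including step t, without forming p itself.
    """
    log_p = np.cumsum(
        np.asarray(log_new, dtype=np.float64) - np.asarray(log_old, dtype=np.float64),
        axis=1,
    )
    n = log_p.shape[0]
    # log(mean_i p[i, t]) via log-sum-exp: shift by the per-step max before exp.
    m = log_p.max(axis=0)
    log_mean = m + np.log(np.exp(log_p - m).sum(axis=0)) - np.log(n)
    # Each weight is at most N, so this exp cannot overflow.
    return np.exp(log_p - log_mean)


# Per-step log-ratios of 1000 would overflow a plain product (exp(1000) is inf).
w = normalized_cumulative_weights([[1000.0, 1000.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]])
```

A per-decision WIS estimate would then weight each discounted reward by `w[i, t]` and average over episodes.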

ArturNiederfahrenhorst commented 11 months ago

I can't find an obvious cause yet.

  1. Could you provide a full reproduction script?
  2. Could you stick as closely as possible to python/ray/rllib/algorithms/cql/tests/test_cql.py, but use your dataset?

Just by intuition: if new_prob is zero or close to zero, CQL would basically never take the action at hand. Maybe the actions in your batch are squashed? I see that the JSON file is named test_training_data2_unsquashed, but the actions I see in it look very much squashed.
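A quick way to test the squashing hypothesis is to compare the range of the logged actions with both [-1, 1] and the `action_space` bounds from the reproduction script. The following is a minimal sketch. It assumes each line of the JSON file is one uncompressed `SampleBatch` with an `"actions"` field, as in the samples above. `action_range_report` is a hypothetical helper. Since the lower bound of `action_space` is 0 in both dimensions, negative values can only be normalized actions. For example, the first `prev_actions` entry in the training sample is -0.025.

```python
import json

import numpy as np


def action_range_report(json_lines, low, high):
    """Summarize where logged actions fall relative to [-1, 1] and [low, high].

    `json_lines` is an iterable of JSON strings, one SampleBatch per line, each
    with an "actions" list of shape (batch_size, action_dim) (assumed layout).
    """
    actions = np.concatenate(
        [
            np.asarray(json.loads(line)["actions"], dtype=np.float64)
            for line in json_lines
            if line.strip()  # skip blank lines, e.g. a trailing newline
        ]
    )
    return {
        "min": actions.min(axis=0),
        "max": actions.max(axis=0),
        "all_in_unit_box": bool(np.all(np.abs(actions) <= 1.0)),
        "all_in_env_bounds": bool(np.all((actions >= low) & (actions <= high))),
    }


# Usage on the real file (bounds from the reproduction script):
# with open("test_training_data2_unsquashed.json") as f:
#     print(action_range_report(f, low=np.array([0, 0]), high=np.array([10, 200])))
```

If `all_in_unit_box` is True and `all_in_env_bounds` is False, the file most likely holds normalized actions despite its `_unsquashed` name.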