ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLlib] KeyError: 'obs' due to empty SampleBatch #36056

Status: Open. Opened by radillus 1 year ago.

radillus commented 1 year ago

What happened + What you expected to happen

While using a custom model, I encountered KeyError: 'obs'. It appears to be caused by an empty SampleBatch.

2023-06-04 01:28:55,511 INFO worker.py:1625 -- Started a local Ray instance.
(RolloutWorker pid=44531) 2023-06-04 01:28:58,256       WARNING env.py:155 -- Your env doesn't have a .spec.max_episode_steps attribute. Your horizon will default to infinity, and your environment will not be reset.
2023-06-04 01:28:59,045 WARNING util.py:67 -- Install gputil for GPU system monitoring.
2023-06-04 01:29:13,231 ERROR actor_manager.py:507 -- Ray error, taking actor 1 out of service. The actor died unexpectedly before finishing this task.
2023-06-04 01:29:13,231 ERROR actor_manager.py:507 -- Ray error, taking actor 2 out of service. The actor died unexpectedly before finishing this task.
Traceback (most recent call last):
  File "/home/radillus_wsl/mambaforge/envs/mini/lib/python3.10/site-packages/ray/rllib/policy/torch_policy_v2.py", line 1218, in _worker
    self.loss(model, self.dist_class, sample_batch)
  File "/home/radillus_wsl/mambaforge/envs/mini/lib/python3.10/site-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py", line 87, in loss
    logits, state = model(train_batch)
  File "/home/radillus_wsl/mambaforge/envs/mini/lib/python3.10/site-packages/ray/rllib/models/modelv2.py", line 248, in __call__
    input_dict["obs"], self.obs_space, self.framework
  File "/home/radillus_wsl/mambaforge/envs/mini/lib/python3.10/site-packages/ray/rllib/policy/sample_batch.py", line 866, in __getitem__
    value = dict.__getitem__(self, key)
KeyError: 'obs'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/radillus_wsl/myprojects/metalen/dev_test/scalingtest.py", line 119, in <module>
    print(algo.train())  # 3. train it,
  File "/home/radillus_wsl/mambaforge/envs/mini/lib/python3.10/site-packages/ray/tune/trainable/trainable.py", line 384, in train
    raise skipped from exception_cause(skipped)
  File "/home/radillus_wsl/mambaforge/envs/mini/lib/python3.10/site-packages/ray/tune/trainable/trainable.py", line 381, in train
    result = self.step()
  File "/home/radillus_wsl/mambaforge/envs/mini/lib/python3.10/site-packages/ray/rllib/algorithms/algorithm.py", line 792, in step
    results, train_iter_ctx = self._run_one_training_iteration()
  File "/home/radillus_wsl/mambaforge/envs/mini/lib/python3.10/site-packages/ray/rllib/algorithms/algorithm.py", line 2811, in _run_one_training_iteration
    results = self.training_step()
  File "/home/radillus_wsl/mambaforge/envs/mini/lib/python3.10/site-packages/ray/rllib/algorithms/ppo/ppo.py", line 432, in training_step
    train_results = multi_gpu_train_one_step(self, train_batch)
  File "/home/radillus_wsl/mambaforge/envs/mini/lib/python3.10/site-packages/ray/rllib/execution/train_ops.py", line 163, in multi_gpu_train_one_step
    results = policy.learn_on_loaded_batch(
  File "/home/radillus_wsl/mambaforge/envs/mini/lib/python3.10/site-packages/ray/rllib/policy/torch_policy_v2.py", line 779, in learn_on_loaded_batch
    return self.learn_on_batch(batch)
  File "/home/radillus_wsl/mambaforge/envs/mini/lib/python3.10/site-packages/ray/rllib/utils/threading.py", line 24, in wrapper
    return func(self, *a, **k)
  File "/home/radillus_wsl/mambaforge/envs/mini/lib/python3.10/site-packages/ray/rllib/policy/torch_policy_v2.py", line 665, in learn_on_batch
    grads, fetches = self.compute_gradients(postprocessed_batch)
  File "/home/radillus_wsl/mambaforge/envs/mini/lib/python3.10/site-packages/ray/rllib/utils/threading.py", line 24, in wrapper
    return func(self, *a, **k)
  File "/home/radillus_wsl/mambaforge/envs/mini/lib/python3.10/site-packages/ray/rllib/policy/torch_policy_v2.py", line 869, in compute_gradients
    tower_outputs = self._multi_gpu_parallel_grad_calc([postprocessed_batch])
  File "/home/radillus_wsl/mambaforge/envs/mini/lib/python3.10/site-packages/ray/rllib/policy/torch_policy_v2.py", line 1303, in _multi_gpu_parallel_grad_calc
    raise last_result[0] from last_result[1]
ValueError: obs
 tracebackTraceback (most recent call last):
  File "/home/radillus_wsl/mambaforge/envs/mini/lib/python3.10/site-packages/ray/rllib/policy/torch_policy_v2.py", line 1218, in _worker
    self.loss(model, self.dist_class, sample_batch)
  File "/home/radillus_wsl/mambaforge/envs/mini/lib/python3.10/site-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py", line 87, in loss
    logits, state = model(train_batch)
  File "/home/radillus_wsl/mambaforge/envs/mini/lib/python3.10/site-packages/ray/rllib/models/modelv2.py", line 248, in __call__
    input_dict["obs"], self.obs_space, self.framework
  File "/home/radillus_wsl/mambaforge/envs/mini/lib/python3.10/site-packages/ray/rllib/policy/sample_batch.py", line 866, in __getitem__
    value = dict.__getitem__(self, key)
KeyError: 'obs'

In tower 0 on device cpu

When I added print("input_dict", input_dict) to the ModelV2.__call__ method in ray/rllib/models/modelv2.py, it printed input_dict SampleBatch(0: []) instead of the normal input_dict SampleBatch(1: ['obs']).
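The last frames of the traceback show SampleBatch.__getitem__ delegating to dict.__getitem__, so an empty batch fails exactly the way an empty dict would. The sketch below reproduces that symptom without Ray. FakeBatch is a hypothetical stand-in for SampleBatch, and require_obs is a hypothetical debugging helper (not part of RLlib) that could sit next to the print statement described above to produce a clearer message than a bare KeyError. Note that the KeyError is raised in ModelV2.__call__ before the custom forward() runs, so a check inside forward() alone would not catch it. This only makes the symptom easier to read; it does not explain why the batch is empty.

```python
class FakeBatch(dict):
    """Hypothetical stand-in for RLlib's SampleBatch (a dict subclass, per the traceback)."""


def require_obs(input_dict):
    """Hypothetical debugging guard.

    Returns input_dict["obs"], or raises a ValueError that lists the keys
    actually present instead of a bare KeyError: 'obs'.
    """
    if "obs" not in input_dict:
        raise ValueError(
            f"model received a batch without 'obs' (keys: {sorted(input_dict)}); "
            "the batch may be empty"
        )
    return input_dict["obs"]


empty = FakeBatch()  # corresponds to the SampleBatch(0: []) printed above
try:
    empty["obs"]  # plain dict lookup, as in SampleBatch.__getitem__
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'obs'

normal = FakeBatch(obs=[[0.0, 0.0]])  # corresponds to SampleBatch(1: ['obs'])
print(require_obs(normal))  # [[0.0, 0.0]]
```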

Versions / Dependencies

Python 3.10.11, Ray 2.4.0, Ubuntu 22.04 (WSL2)

Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
aiosignal                 1.3.1                    pypi_0    pypi
attrs                     23.1.0                   pypi_0    pypi
bzip2                     1.0.8                h7f98852_4    conda-forge
ca-certificates           2023.5.7             hbcca054_0    conda-forge
certifi                   2023.5.7                 pypi_0    pypi
charset-normalizer        3.1.0                    pypi_0    pypi
click                     8.1.3                    pypi_0    pypi
cloudpickle               2.2.1                    pypi_0    pypi
cmake                     3.26.3                   pypi_0    pypi
distlib                   0.3.6                    pypi_0    pypi
dm-tree                   0.1.8                    pypi_0    pypi
filelock                  3.12.0                   pypi_0    pypi
frozenlist                1.3.3                    pypi_0    pypi
grpcio                    1.51.3                   pypi_0    pypi
gymnasium                 0.26.3                   pypi_0    pypi
gymnasium-notices         0.0.1                    pypi_0    pypi
idna                      3.4                      pypi_0    pypi
imageio                   2.28.1                   pypi_0    pypi
jinja2                    3.1.2                    pypi_0    pypi
joblib                    1.2.0                    pypi_0    pypi
jsonschema                4.17.3                   pypi_0    pypi
lazy-loader               0.2                      pypi_0    pypi
ld_impl_linux-64          2.40                 h41732ed_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-ng                 12.2.0              h65d4601_19    conda-forge
libgomp                   12.2.0              h65d4601_19    conda-forge
libnsl                    2.0.0                h7f98852_0    conda-forge
libsqlite                 3.42.0               h2797004_0    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libzlib                   1.2.13               h166bdaf_4    conda-forge
lit                       16.0.3                   pypi_0    pypi
lz4                       4.3.2                    pypi_0    pypi
markdown-it-py            2.2.0                    pypi_0    pypi
markupsafe                2.1.2                    pypi_0    pypi
mdurl                     0.1.2                    pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
msgpack                   1.0.5                    pypi_0    pypi
ncurses                   6.3                  h27087fc_1    conda-forge
networkx                  3.1                      pypi_0    pypi
numpy                     1.24.3                   pypi_0    pypi
nvidia-cublas-cu11        11.10.3.66               pypi_0    pypi
nvidia-cuda-cupti-cu11    11.7.101                 pypi_0    pypi
nvidia-cuda-nvrtc-cu11    11.7.99                  pypi_0    pypi
nvidia-cuda-runtime-cu11  11.7.99                  pypi_0    pypi
nvidia-cudnn-cu11         8.5.0.96                 pypi_0    pypi
nvidia-cufft-cu11         10.9.0.58                pypi_0    pypi
nvidia-curand-cu11        10.2.10.91               pypi_0    pypi
nvidia-cusolver-cu11      11.4.0.1                 pypi_0    pypi
nvidia-cusparse-cu11      11.7.4.91                pypi_0    pypi
nvidia-nccl-cu11          2.14.3                   pypi_0    pypi
nvidia-nvtx-cu11          11.7.91                  pypi_0    pypi
openssl                   3.1.1                hd590300_1    conda-forge
packaging                 23.1                     pypi_0    pypi
pandas                    2.0.1                    pypi_0    pypi
pillow                    9.5.0                    pypi_0    pypi
pip                       23.1.2             pyhd8ed1ab_0    conda-forge
platformdirs              3.5.1                    pypi_0    pypi
protobuf                  3.20.3                   pypi_0    pypi
pygments                  2.15.1                   pypi_0    pypi
pyrsistent                0.19.3                   pypi_0    pypi
python                    3.10.11         he550d4f_0_cpython    conda-forge
python-dateutil           2.8.2                    pypi_0    pypi
pytz                      2023.3                   pypi_0    pypi
pywavelets                1.4.1                    pypi_0    pypi
pyyaml                    6.0                      pypi_0    pypi
ray                       2.4.0                    pypi_0    pypi
readline                  8.2                  h8228510_1    conda-forge
requests                  2.30.0                   pypi_0    pypi
rich                      13.3.5                   pypi_0    pypi
scikit-image              0.20.0                   pypi_0    pypi
scikit-learn              1.2.2                    pypi_0    pypi
scipy                     1.10.1                   pypi_0    pypi
setuptools                67.7.2             pyhd8ed1ab_0    conda-forge
six                       1.16.0                   pypi_0    pypi
sympy                     1.12                     pypi_0    pypi
tabulate                  0.9.0                    pypi_0    pypi
tensorboardx              2.6                      pypi_0    pypi
threadpoolctl             3.1.0                    pypi_0    pypi
tifffile                  2023.4.12                pypi_0    pypi
tk                        8.6.12               h27826a3_0    conda-forge
torch                     2.0.1                    pypi_0    pypi
triton                    2.0.0                    pypi_0    pypi
typer                     0.9.0                    pypi_0    pypi
typing-extensions         4.5.0                    pypi_0    pypi
tzdata                    2023.3                   pypi_0    pypi
urllib3                   2.0.2                    pypi_0    pypi
virtualenv                20.21.0                  pypi_0    pypi
wheel                     0.40.0             pyhd8ed1ab_0    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge

Reproduction script

import gymnasium as gym
from gymnasium.spaces import Box
from scipy.ndimage import zoom
import torch.nn as nn
import torch
import numpy as np
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models import ModelCatalog
from ray.rllib.algorithms.ppo import PPOConfig

class MyEnv(gym.Env):
    # A simple env for testing: the goal is to modify the matrix 'now_mat' so that it approaches the target matrix
    def __init__(self, env_config) -> None:
        self.origin_size = env_config['origin_size']
        self.to_size = env_config['to_size']
        self.observation_space = Box(low=-1, high=1, shape=(self.to_size, self.to_size, 2))
        self.action_space = Box(low=-1, high=1, shape=(3,))
        self.target_mat = None
        self.now_mat = None
        self.now_scale_mat = None
        self.step_count = 0
        self.max_step = (2 * env_config['origin_size'])**2

    def _get_obs(self):
        a = np.stack([self.target_mat, self.now_scale_mat], axis=2)
        return a

    def _get_info(self):
        return {}

    def reset(self, *, seed=None, options=None):
        self.step_count = 0
        self.target_mat = np.random.rand(self.origin_size, self.origin_size)*2-1
        self.target_mat = zoom(self.target_mat, self.to_size/self.origin_size)
        self.target_mat = np.clip(self.target_mat, -1, 1)
        self.now_mat = np.zeros((self.origin_size, self.origin_size))
        self.now_mat = np.clip(self.now_mat, -1, 1)
        self.now_scale_mat = zoom(self.now_mat, self.to_size/self.origin_size)
        self.now_scale_mat = np.clip(self.now_scale_mat, -1, 1)
        return self._get_obs(), self._get_info()

    def _to_idx(self, value):
        return min(int((value+1)/2*self.origin_size),self.origin_size-1)

    def step(self, action):
        x_idx = self._to_idx(action[0])
        y_idx = self._to_idx(action[1])
        self.now_mat[x_idx, y_idx] = action[2]
        self.now_mat = np.clip(self.now_mat, -1, 1)
        now_scale_mat = zoom(self.now_mat, self.to_size/self.origin_size)
        now_scale_mat = np.clip(now_scale_mat, -1, 1)
        now_loss = np.abs(self.target_mat - self.now_scale_mat)
        reward = np.mean(now_loss - np.abs(self.target_mat - now_scale_mat))
        self.now_scale_mat = now_scale_mat
        self.step_count += 1
        return self._get_obs(), reward, self.step_count > self.max_step, np.mean(now_loss) < 0.1, self._get_info()

class SimpleConv(TorchModelV2, nn.Module):
    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config,
        name: str,
    ):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)
        now_channel, now_size = 2, obs_space.shape[1]
        self.convs = nn.ModuleList()
        while (now_size-5)//2+1 > 1:
            self.convs.append(nn.Conv2d(now_channel, now_channel*2, 5, 2))
            self.convs.append(nn.ReLU())
            now_channel *= 2
            now_size = (now_size-5)//2+1
        self.fc = nn.Sequential(nn.Flatten(), nn.Linear(now_channel*now_size*now_size, 128), nn.ReLU(), nn.Linear(128, num_outputs))
    def forward(self, input_dict, state, seq_lens):
        x = input_dict['obs'].float().permute(0, 3, 1, 2)
        mat_loss = torch.mean(torch.abs(x[:,0,...]-x[:,1,...]), dim=(1,2)) # 0~2
        self._value = 1 - mat_loss # -1~1
        for layer in self.convs:
            x = layer(x)
        x = self.fc(x)
        return x, state
    def value_function(self):
        return self._value

ModelCatalog.register_custom_model("simpleconv", SimpleConv)

model_config_dict = {
    "custom_model": "simpleconv",
}

env_config = {
    'origin_size': 100,
    'to_size': 500,
}

algo = PPOConfig()
algo = algo.training(
    model=model_config_dict,
)
algo = algo.environment(
    env = MyEnv,
    env_config=env_config,
)
algo = algo.resources(num_gpus=0)
algo = algo.build()
for _ in range(2):
    print(algo.train())
algo.evaluate()
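SimpleConv keeps stacking 5x5, stride-2 convolutions (no padding) for as long as the next layer's output would still be larger than 1 pixel, using the output-size formula (size - 5) // 2 + 1. The standalone sketch below replays that loop with a hypothetical helper, conv_stack_shape (no Torch or Ray dependency), for the issue's to_size of 500, to show how many conv layers are built and what input size reaches the first nn.Linear.

```python
def conv_stack_shape(size: int, channels: int = 2, kernel: int = 5, stride: int = 2):
    """Replay SimpleConv's layer-building loop; return (num_convs, channels, size)."""
    num_convs = 0
    # Same stopping rule as SimpleConv: add a conv only if its output would be > 1 pixel.
    while (size - kernel) // stride + 1 > 1:
        channels *= 2  # each conv doubles the channel count
        size = (size - kernel) // stride + 1
        num_convs += 1
    return num_convs, channels, size


convs, ch, side = conv_stack_shape(500)  # obs_space.shape[1] == to_size == 500
print(convs, ch, side)                           # 6 128 4
print("flattened features:", ch * side * side)   # flattened features: 2048
```

With to_size = 500, the spatial size shrinks 500 -> 248 -> 122 -> 59 -> 28 -> 12 -> 4 over six convolutions, so the first nn.Linear receives 128 x 4 x 4 = 2048 features.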

Issue Severity

High: It blocks me from completing my task.

Haeyeon-Choi commented 1 year ago

Did you solve the problem? I get the same error.

radillus commented 1 year ago

> Did you solve the problem? I get the same error.

No, I could not find a solution. You might want to try the new RLModule API, since the ModelV2 API will be superseded.

BTW, I tried many other RL libraries, but they all had some problems with custom environments and policies.

Haeyeon-Choi commented 1 year ago

> No, I could not find a solution. You might want to try the new RLModule API, since the ModelV2 API will be superseded.
>
> BTW, I tried many other RL libraries, but they all had some problems with custom environments and policies.

Thanks for your reply! Hope it will be fixed soon.

Haeyeon-Choi commented 1 year ago

Hi! I found that my issue was caused by some termination conditions I had included in the run file (rather than in the environment file). I also rewrote the run file based on action_masking.py, and it now trains normally. I hope this helps if the problem still persists.
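The comment above traces the error to termination conditions. For reference, Gymnasium's Env.step (0.26 and later) returns (observation, reward, terminated, truncated, info): terminated signals that the episode ended on its own (for example, the goal was reached), and truncated signals a cut-off such as a step limit. The sketch below uses a hypothetical helper, episode_flags, to compute both flags for the reproduction script's two conditions (the max_step limit and a mean-loss threshold of 0.1) in that order. It illustrates the API convention only; it is not the specific change the commenter made. It may be worth checking which condition MyEnv.step places in each slot.

```python
from typing import Tuple


def episode_flags(step_count: int, max_step: int, mean_loss: float,
                  tol: float = 0.1) -> Tuple[bool, bool]:
    """Hypothetical helper: return (terminated, truncated) in Gymnasium's order."""
    # Goal reached: the episode ends naturally, so it is "terminated".
    terminated = bool(mean_loss < tol)
    # Step budget used up: the episode is cut off, so it is "truncated".
    truncated = step_count > max_step
    return terminated, truncated


# Inside step(), the flags would then be returned as:
#     return obs, reward, terminated, truncated, info
max_step = (2 * 100) ** 2  # same formula as MyEnv with origin_size = 100
print(episode_flags(10, max_step, mean_loss=0.05))           # (True, False)
print(episode_flags(max_step + 1, max_step, mean_loss=0.5))  # (False, True)
```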