ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLLib] Problem with SAC custom model -> input does not get flattened correctly by pre-processing #35547

Open uceeatz opened 1 year ago

uceeatz commented 1 year ago

What happened + What you expected to happen

I am trying to use the SAC algorithm with a custom model and environment to do action masking. Initialisation appears to start fine, but when the policy initialises its loss with a dummy batch, the observation fed to the model does not seem to be flattened correctly: "obs" has shape 6242, which is only the "obs" portion of the dict observation, rather than "obs" + "action_mask", which flattened would have length 12482.

(RolloutWorker pid=3235253) 2023-05-19 13:33:44,467     ERROR worker.py:844 -- Exception raised in creation task: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=3235253, ip=128.40.41.48, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x7f871aa828e0>)
(RolloutWorker pid=3235253)   File "/home/uceeatz/miniconda3/envs/ilp/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 738, in __init__
(RolloutWorker pid=3235253)     self._update_policy_map(policy_dict=self.policy_dict)
(RolloutWorker pid=3235253)   File "/home/uceeatz/miniconda3/envs/ilp/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1985, in _update_policy_map
(RolloutWorker pid=3235253)     self._build_policy_map(
(RolloutWorker pid=3235253)   File "/home/uceeatz/miniconda3/envs/ilp/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 2097, in _build_policy_map
(RolloutWorker pid=3235253)     new_policy = create_policy_for_framework(
(RolloutWorker pid=3235253)   File "/home/uceeatz/miniconda3/envs/ilp/lib/python3.9/site-packages/ray/rllib/utils/policy.py", line 142, in create_policy_for_framework
(RolloutWorker pid=3235253)     return policy_class(observation_space, action_space, merged_config)
(RolloutWorker pid=3235253)   File "/home/uceeatz/miniconda3/envs/ilp/lib/python3.9/site-packages/ray/rllib/policy/policy_template.py", line 327, in __init__
(RolloutWorker pid=3235253)     self._initialize_loss_from_dummy_batch(
(RolloutWorker pid=3235253)   File "/home/uceeatz/miniconda3/envs/ilp/lib/python3.9/site-packages/ray/rllib/policy/policy.py", line 1405, in _initialize_loss_from_dummy_batch
(RolloutWorker pid=3235253)     actions, state_outs, extra_outs = self.compute_actions_from_input_dict(
(RolloutWorker pid=3235253)   File "/home/uceeatz/miniconda3/envs/ilp/lib/python3.9/site-packages/ray/rllib/policy/torch_policy.py", line 325, in compute_actions_from_input_dict
(RolloutWorker pid=3235253)     return self._compute_action_helper(
(RolloutWorker pid=3235253)   File "/home/uceeatz/miniconda3/envs/ilp/lib/python3.9/site-packages/ray/rllib/utils/threading.py", line 24, in wrapper
(RolloutWorker pid=3235253)     return func(self, *a, **k)
(RolloutWorker pid=3235253)   File "/home/uceeatz/miniconda3/envs/ilp/lib/python3.9/site-packages/ray/rllib/policy/torch_policy.py", line 981, in _compute_action_helper
(RolloutWorker pid=3235253)     dist_inputs, dist_class, state_out = self.action_distribution_fn(
(RolloutWorker pid=3235253)   File "/home/uceeatz/miniconda3/envs/ilp/lib/python3.9/site-packages/ray/rllib/algorithms/sac/sac_torch_policy.py", line 165, in action_distribution_fn
(RolloutWorker pid=3235253)     action_dist_inputs, _ = model.get_action_model_outputs(model_out)
(RolloutWorker pid=3235253)   File "/home/uceeatz/miniconda3/envs/ilp/lib/python3.9/site-packages/ray/rllib/algorithms/sac/sac_torch_model.py", line 313, in get_action_model_outputs
(RolloutWorker pid=3235253)     return self.action_model(model_out, state_in, seq_lens)
(RolloutWorker pid=3235253)   File "/home/uceeatz/miniconda3/envs/ilp/lib/python3.9/site-packages/ray/rllib/models/modelv2.py", line 247, in __call__
(RolloutWorker pid=3235253)     restored["obs"] = restore_original_dimensions(
(RolloutWorker pid=3235253)   File "/home/uceeatz/miniconda3/envs/ilp/lib/python3.9/site-packages/ray/rllib/models/modelv2.py", line 411, in restore_original_dimensions
(RolloutWorker pid=3235253)     return _unpack_obs(obs, original_space, tensorlib=tensorlib)
(RolloutWorker pid=3235253)   File "/home/uceeatz/miniconda3/envs/ilp/lib/python3.9/site-packages/ray/rllib/models/modelv2.py", line 445, in _unpack_obs
(RolloutWorker pid=3235253)     raise ValueError(
(RolloutWorker pid=3235253) ValueError: Expected flattened obs shape of [..., 12482], got torch.Size([32, 6242])

Because the obs is passed through a model in action_distribution_fn in sac_torch_policy.py beforehand, I am not sure what is actually being fed into the custom model. I have attached a zip with the files needed to reproduce the error below.
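For readers following along, the shape mismatch in the ValueError can be illustrated with a small pure-Python sketch. This is not RLlib's implementation: `flatten_dict_obs` and `unpack_flat_obs` are hypothetical helpers, and the mask length of 6240 is inferred by subtracting the two sizes quoted above rather than stated anywhere in the report. It only shows why unpacking fails when a vector holding just the "obs" component arrives where the full flattened dict is expected.

```python
OBS_SIZE = 6242    # length of the "obs" component, as reported in the issue
FLAT_SIZE = 12482  # expected flattened length, from the error message
MASK_SIZE = FLAT_SIZE - OBS_SIZE  # inferred by subtraction (assumption)


def flatten_dict_obs(obs_dict):
    """Concatenate the dict's components, in sorted key order, into one list."""
    flat = []
    for key in sorted(obs_dict):
        flat.extend(obs_dict[key])
    return flat


def unpack_flat_obs(flat, sizes):
    """Split a flat vector back into named components, checking its length."""
    expected = sum(sizes.values())
    if len(flat) != expected:
        raise ValueError(
            f"Expected flattened obs of length {expected}, got {len(flat)}"
        )
    restored, offset = {}, 0
    for key in sorted(sizes):
        restored[key] = flat[offset:offset + sizes[key]]
        offset += sizes[key]
    return restored


sizes = {"action_mask": MASK_SIZE, "obs": OBS_SIZE}

# The full flattened dict unpacks without complaint.
full = flatten_dict_obs(
    {"obs": [0.0] * OBS_SIZE, "action_mask": [1.0] * MASK_SIZE}
)
restored = unpack_flat_obs(full, sizes)

# A vector holding only the "obs" part triggers the same kind of error.
try:
    unpack_flat_obs([0.0] * OBS_SIZE, sizes)
    message = None
except ValueError as err:
    message = str(err)
```

In this toy version the second call fails because the input has already lost its "action_mask" part, which matches the symptom in the traceback: the tensor reaching the second model call has 6242 columns instead of 12482.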

Versions / Dependencies

ray=2.4.0 networkx=2.8.4 gymnasium=0.26.3 torch=2.0.0

Reproduction script

python train_sac_offline.py (files in the attached train_sac_offline.zip)

Issue Severity

High: It blocks me from completing my task.

Rohan138 commented 1 year ago

Could you provide more context on what your policy and action mask look like? Action masking is intended for masking actions; technically, it should have nothing to do with observations.
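As a point of reference for what "masking actions" usually means for a discrete action space, here is a minimal pure-Python sketch. It assumes a 0/1 mask with one entry per action, and `masked_softmax` is a hypothetical helper, not an RLlib function. How the mask reaches the model is deliberately left out of the sketch.

```python
import math


def masked_softmax(logits, action_mask):
    """Return action probabilities with masked-out actions forced to zero."""
    if len(logits) != len(action_mask):
        raise ValueError("logits and action_mask must have the same length")
    # Invalid actions get a logit of -inf, which exp() maps to exactly 0.0.
    masked = [
        logit if allowed else -math.inf
        for logit, allowed in zip(logits, action_mask)
    ]
    valid = [x for x in masked if x != -math.inf]
    if not valid:
        raise ValueError("action_mask must allow at least one action")
    peak = max(valid)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]


# Actions 1 and 3 are disallowed, so they end up with zero probability.
probs = masked_softmax([1.0, 2.0, 3.0, 4.0], [1, 0, 1, 0])
```

The mask only changes the action distribution here; the observation features used to compute the logits are untouched.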