ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[Rllib] InvalidArgumentError: cannot compute ConcatV2 as input #1(zero-based) was expected to be a double tensor but is a float tensor #36364

Open xuanquang1999 opened 1 year ago

xuanquang1999 commented 1 year ago

What happened + What you expected to happen

The problem occurs when training a Soft Actor-Critic (SAC) model with TensorFlow 2 on Gymnasium's Hopper environment:

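from ray.rllib.algorithms.sac import SACConfig

# SAC on Gymnasium's Hopper-v4, using the TF2 (eager) framework
config = (
    SACConfig()
    .environment(env="Hopper-v4")
    .framework("tf2")
)
algo = config.build()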

Instead of building the algorithm, the script raised the following error:

Traceback (most recent call last):
  File "/Users/macbookpro/Documents/autonomous/generative/markov/test-rllib/github_report.py", line 9, in <module>
    algo = config.build()
  File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm_config.py", line 1071, in build
    return algo_class(
  File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/algorithms/sac/sac.py", line 354, in __init__
    super().__init__(*args, **kwargs)
  File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py", line 475, in __init__
    super().__init__(
  File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/tune/trainable/trainable.py", line 170, in __init__
    self.setup(copy.deepcopy(self.config))
  File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py", line 601, in setup
    self.workers = WorkerSet(
  File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/evaluation/worker_set.py", line 172, in __init__
    self._setup(
  File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/evaluation/worker_set.py", line 262, in _setup
    self._local_worker = self._make_worker(
  File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/evaluation/worker_set.py", line 967, in _make_worker
    worker = cls(
  File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 738, in __init__
    self._update_policy_map(policy_dict=self.policy_dict)
  File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1985, in _update_policy_map
    self._build_policy_map(
  File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 2097, in _build_policy_map
    new_policy = create_policy_for_framework(
  File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/utils/policy.py", line 139, in create_policy_for_framework
    return policy_class(observation_space, action_space, merged_config)
  File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/policy/eager_tf_policy.py", line 470, in __init__
    self._initialize_loss_from_dummy_batch(
  File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/policy/policy.py", line 1487, in _initialize_loss_from_dummy_batch
    self._loss(self, self.model, self.dist_class, train_batch)
  File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/algorithms/sac/sac_tf_policy.py", line 333, in sac_actor_critic_loss
    q_t, _ = model.get_q_values(
  File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/algorithms/sac/sac_tf_model.py", line 212, in get_q_values
    return self._get_q_value(model_out, actions, self.q_net)
  File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/algorithms/sac/sac_tf_model.py", line 245, in _get_q_value
    input_dict = {"obs": tf.concat([model_out, actions], axis=-1)}
  File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 7262, in raise_from_not_ok_status
    raise core._status_to_exception(e) from None  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute ConcatV2 as input #1(zero-based) was expected to be a double tensor but is a float tensor [Op:ConcatV2] name: concat

The error occurs because TensorFlow tries to concatenate two tensors of different dtypes (input_dict = {"obs": tf.concat([model_out, actions], axis=-1)}). The model_out tensor is float64 (tracing back further, its dtype appears to be inferred from the observation space), while the actions tensor is always float32, because the following code always casts it to float32:

q_t, _ = model.get_q_values(
    model_out_t, tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32)
)
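The mismatch can be reproduced outside RLlib with two small eager tensors. This is a minimal sketch assuming TensorFlow 2.x and NumPy are installed; the shapes (11 observation features, 3 action dimensions) are illustrative, and model_out here is only a stand-in for the tensor RLlib actually builds. It also shows the contrast with NumPy, which silently promotes mixed float inputs instead of raising.

```python
import numpy as np
import tensorflow as tf

# Stand-ins for the two inputs of the failing tf.concat call (shapes illustrative).
model_out = tf.constant(np.zeros((1, 11), dtype=np.float64))  # float64, like the observations
actions = tf.zeros((1, 3), dtype=tf.float32)  # RLlib casts actions to float32

try:
    tf.concat([model_out, actions], axis=-1)
    tf_error = None
except (tf.errors.InvalidArgumentError, TypeError) as err:
    # TF requires all concat inputs to share one dtype; it does not upcast.
    tf_error = type(err).__name__

# NumPy, by contrast, silently promotes the result to float64.
np_result = np.concatenate([model_out.numpy(), actions.numpy()], axis=-1)
print(tf_error, np_result.dtype)
```

Eager mode raises InvalidArgumentError (as in the traceback), while graph mode reports the same mismatch as a TypeError, which is why both are caught.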

I tried declaring an observation_space with float32 dtype, but the error persisted:

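import gymnasium as gym  # the issue uses Gymnasium, imported under the name gym
import numpy as np
from ray.rllib.algorithms.sac import SACConfig

config = (
    SACConfig()
    .environment(
        env="Hopper-v4",
        # Hopper-v4 has 11-dimensional observations; force float32 here
        observation_space=gym.spaces.Box(low=-np.inf, high=np.inf, shape=(11,), dtype=np.float32))
    .framework("tf2")
)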

For now, my workarounds are:

  1. Modify the library code to cast the model_out tensor to float32. This works, but it feels hacky, and I am not sure it won't break the model when training on other environments.
  2. Use PyTorch instead of TF2. This is not ideal because I need to run inference with the model in a web frontend, and there is currently no library for running PyTorch models on the web with little effort. Converting PyTorch to TF.js is complicated (PyTorch -> ONNX -> TF Python -> TF.js), and most tools for ONNX -> TF Python conversion are no longer maintained.
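Workaround 1 amounts to making both inputs share one dtype before the concat. A minimal sketch of that idea as a standalone helper follows; concat_obs_and_actions is a hypothetical name, not RLlib's code, and where exactly it would replace the tf.concat call inside RLlib's _get_q_value is left as an assumption.

```python
import tensorflow as tf


def concat_obs_and_actions(model_out, actions, dtype=tf.float32):
    """Hypothetical helper: cast both inputs to one dtype, then concatenate.

    Defaulting to float32 matches the dtype RLlib already forces on the
    actions, so the Q-network input is float32 whether the observations
    arrive as float32 or float64.
    """
    model_out = tf.cast(model_out, dtype)  # no-op if already `dtype`
    actions = tf.cast(actions, dtype)
    return tf.concat([model_out, actions], axis=-1)


# Mixed inputs like those in the traceback (shapes illustrative).
obs_features = tf.ones((2, 11), dtype=tf.float64)
acts = tf.zeros((2, 3), dtype=tf.float32)
q_input = concat_obs_and_actions(obs_features, acts)
print(q_input.dtype, q_input.shape)
```

Casting down to float32 rather than up to float64 keeps the Q-network input in the same precision as the actions, which also seems the more natural choice when the model is later exported for TF.js.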

Versions / Dependencies

Ray: 2.5.0
OS: macOS Monterey (12.6)
Python: 3.9.16

Reproduction script

from ray.rllib.algorithms.sac import SACConfig
config = (
    SACConfig()
    .environment(env="Hopper-v4")
    .framework("tf2")
)
algo = config.build()

Issue Severity

Low: It annoys or frustrates me.

Rohan138 commented 1 year ago

What version of gymnasium are you using? Gymnasium recently changed the discrete space's dtype from float to np.float, which can result in downstream issues.
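To answer this (and to round out the Versions / Dependencies section), the installed package versions can be collected with the standard library's importlib.metadata. A minimal sketch; the default package list is an assumption about which distributions matter for this bug.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("ray", "gymnasium", "tensorflow", "numpy")):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```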