ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLlib] Tensor size mismatch error when training LSTM on CartPole environment with PPO and IMPALA in RLlib #47653

Open klae01 opened 3 weeks ago

klae01 commented 3 weeks ago

What happened + What you expected to happen

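bug: When training an IMPALA agent with an LSTM model on the CartPole environment using RLlib, training crashes with a RuntimeError raised during the loss computation (VTraceLoss): "The size of tensor a (159) must match the size of tensor b (139) at non-singleton dimension 0". With the same model settings, PPO also crashes, but with a different error raised in the model's forward pass: "cannot reshape tensor of 0 elements into shape [0, -1]".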

expected behavior: The training should proceed without errors, allowing the IMPALA agent with an LSTM model to learn the CartPole environment successfully.
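For readers unfamiliar with the IMPALA error text, the following is a minimal pure-Python sketch of the general broadcasting rule behind "must match ... at non-singleton dimension 0". It is an illustration only: the helper name `broadcast_mismatch` is hypothetical and is not part of RLlib or PyTorch, and the sketch does not explain why the two tensors ended up with different lengths.

```python
def broadcast_mismatch(shape_a, shape_b):
    """Return the index (in the broadcast result) of the first dimension,
    scanning from the trailing end, at which the two shapes are
    incompatible, or None if the shapes can be broadcast together."""
    ndim = max(len(shape_a), len(shape_b))
    for i in range(1, ndim + 1):
        # Align shapes on the right; a missing leading dimension acts like size 1.
        a = shape_a[-i] if i <= len(shape_a) else 1
        b = shape_b[-i] if i <= len(shape_b) else 1
        # Two sizes are compatible if they are equal or if either one is 1.
        if a != b and a != 1 and b != 1:
            return ndim - i
    return None


# Sizes taken from the error message in this report.
print(broadcast_mismatch((159,), (139,)))  # 0
print(broadcast_mismatch((159,), (159,)))  # None
```

Because neither 159 nor 139 is 1, an elementwise product of the two tensors cannot be broadcast, which is consistent with the failure being reported at dimension 0.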

outputs:

2024-09-13 16:27:03,811 INFO worker.py:1616 -- Calling ray.init() again after it has already been called.
+---------------------------------------------------------------+
| Configuration for experiment     IMPALA_2024-09-13_16-27-03   |
+---------------------------------------------------------------+
| Search algorithm                 BasicVariantGenerator        |
| Scheduler                        FIFOScheduler                |
| Number of trials                 1                            |
+---------------------------------------------------------------+

View detailed results here: /root/ray_results/IMPALA_2024-09-13_16-27-03
To visualize your results with TensorBoard, run: `tensorboard --logdir /tmp/ray/session_2024-09-13_16-22-47_988205_1242/artifacts/2024-09-13_16-27-03/IMPALA_2024-09-13_16-27-03/driver_artifacts`

Trial status: 1 PENDING
Current time: 2024-09-13 16:27:03. Total running time: 0s
Logical resource usage: 3.0/12 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:L4)
+----------------------------------------------------+
| Trial name                                status   |
+----------------------------------------------------+
| IMPALA_CustomCartPoleEnv-v0_02b18_00000   PENDING  |
+----------------------------------------------------+
(RolloutWorker pid=4547) 2024-09-13 16:27:18,888    WARNING deprecation.py:50 -- DeprecationWarning: `ray.rllib.models.torch.recurrent_net.RecurrentNetwork` has been deprecated. This will raise an error in the future!

Trial IMPALA_CustomCartPoleEnv-v0_02b18_00000 started with configuration:
+-------------------------------------------------------------------------------+
| Trial IMPALA_CustomCartPoleEnv-v0_02b18_00000 config                          |
+-------------------------------------------------------------------------------+
| env                                                      CustomCartPoleEnv-v0 |
| framework                                                               torch |
| model/lstm_cell_size                                                      256 |
| model/max_seq_len                                                         256 |
| model/use_lstm                                                           True |
| num_envs_per_env_runner                                                     7 |
| num_workers                                                                 2 |
| remote_worker_envs                                                       True |
| rollout_fragment_length                                                   256 |
| train_batch_size                                                         2560 |
+-------------------------------------------------------------------------------+
(IMPALA pid=4461) Trainable.setup took 16.618 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.
(IMPALA pid=4461) Install gputil for GPU system monitoring.
(IMPALA pid=4461) 2024-09-13 16:27:23,700   WARNING deprecation.py:50 -- DeprecationWarning: `ray.rllib.models.torch.recurrent_net.RecurrentNetwork` has been deprecated. This will raise an error in the future!

Trial status: 1 RUNNING
Current time: 2024-09-13 16:27:33. Total running time: 30s
Logical resource usage: 3.0/12 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:L4)
+----------------------------------------------------+
| Trial name                                status   |
+----------------------------------------------------+
| IMPALA_CustomCartPoleEnv-v0_02b18_00000   RUNNING  |
+----------------------------------------------------+
(IMPALA pid=4461) /usr/local/lib/python3.10/dist-packages/ray/rllib/utils/metrics/window_stat.py:55: RuntimeWarning: Mean of empty slice
(IMPALA pid=4461)   return float(np.nanmean(self.items[: self.count]))
(IMPALA pid=4461) /usr/local/lib/python3.10/dist-packages/numpy/lib/nanfunctions.py:1879: RuntimeWarning: Degrees of freedom <= 0 for slice.
(IMPALA pid=4461)   var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,

Trial IMPALA_CustomCartPoleEnv-v0_02b18_00000 finished iteration 1 at 2024-09-13 16:27:41. Total running time: 37s
+------------------------------------------------------------------+
| Trial IMPALA_CustomCartPoleEnv-v0_02b18_00000 result             |
+------------------------------------------------------------------+
| env_runners/episode_len_mean                             23.8715 |
| env_runners/episode_return_mean                          23.8715 |
| num_env_steps_sampled_lifetime                                 0 |
+------------------------------------------------------------------+

Trial IMPALA_CustomCartPoleEnv-v0_02b18_00000 finished iteration 2 at 2024-09-13 16:27:55. Total running time: 51s
+-----------------------------------------------------------------+
| Trial IMPALA_CustomCartPoleEnv-v0_02b18_00000 result            |
+-----------------------------------------------------------------+
| env_runners/episode_len_mean                             30.178 |
| env_runners/episode_return_mean                          30.178 |
| num_env_steps_sampled_lifetime                            14336 |
+-----------------------------------------------------------------+

Trial status: 1 RUNNING
Current time: 2024-09-13 16:28:03. Total running time: 1min 0s
Logical resource usage: 3.0/12 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:L4)
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                                status       iter     total time (s)      ts     num_healthy_workers     ...async_sample_reqs     ...e_worker_restarts     ...ent_steps_sampled |
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| IMPALA_CustomCartPoleEnv-v0_02b18_00000   RUNNING         2            27.5407   14336                       2                        4                        0                    14336 |
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

Trial IMPALA_CustomCartPoleEnv-v0_02b18_00000 finished iteration 3 at 2024-09-13 16:28:09. Total running time: 1min 5s
+------------------------------------------------------------------+
| Trial IMPALA_CustomCartPoleEnv-v0_02b18_00000 result             |
+------------------------------------------------------------------+
| env_runners/episode_len_mean                             66.3864 |
| env_runners/episode_return_mean                          66.3864 |
| num_env_steps_sampled_lifetime                             28672 |
+------------------------------------------------------------------+
2024-09-13 16:28:12,681 ERROR tune_controller.py:1331 -- Trial task failed for trial IMPALA_CustomCartPoleEnv-v0_02b18_00000
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 2661, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 871, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::IMPALA.train() (pid=4461, ip=172.28.0.12, actor_id=f8e166a52a944dd387f50c9c01000000, repr=IMPALA)
  File "/usr/local/lib/python3.10/dist-packages/ray/tune/trainable/trainable.py", line 331, in train
    raise skipped from exception_cause(skipped)
  File "/usr/local/lib/python3.10/dist-packages/ray/tune/trainable/trainable.py", line 328, in train
    result = self.step()
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/algorithms/algorithm.py", line 951, in step
    train_results, train_iter_ctx = self._run_one_training_iteration()
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/algorithms/algorithm.py", line 3600, in _run_one_training_iteration
    training_step_results = self.training_step()
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/algorithms/impala/impala.py", line 653, in training_step
    return self._training_step_old_api_stack()
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/algorithms/impala/impala.py", line 1020, in _training_step_old_api_stack
    raise RuntimeError("The learner thread died while training!")
RuntimeError: The learner thread died while training!
2024-09-13 16:28:12,704 INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/root/ray_results/IMPALA_2024-09-13_16-27-03' in 0.0144s.
(IMPALA pid=4461) Exception in thread Thread-1:
(IMPALA pid=4461) Traceback (most recent call last):
(IMPALA pid=4461)   File "/usr/local/lib/python3.10/dist-packages/ray/rllib/policy/torch_policy_v2.py", line 1348, in _worker
(IMPALA pid=4461)     self.loss(model, self.dist_class, sample_batch)
(IMPALA pid=4461)   File "/usr/local/lib/python3.10/dist-packages/ray/rllib/algorithms/impala/impala_torch_policy.py", line 333, in loss
(IMPALA pid=4461)     loss = VTraceLoss(
(IMPALA pid=4461)   File "/usr/local/lib/python3.10/dist-packages/ray/rllib/algorithms/impala/impala_torch_policy.py", line 115, in __init__
(IMPALA pid=4461)     actions_logp * self.vtrace_returns.pg_advantages.to(device) * valid_mask
(IMPALA pid=4461) RuntimeError: The size of tensor a (159) must match the size of tensor b (139) at non-singleton dimension 0
(IMPALA pid=4461) 
(IMPALA pid=4461) The above exception was the direct cause of the following exception:
(IMPALA pid=4461) 
(IMPALA pid=4461) Traceback (most recent call last):
(IMPALA pid=4461)   File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
(IMPALA pid=4461)     self.run()
(IMPALA pid=4461)   File "/usr/local/lib/python3.10/dist-packages/ray/rllib/execution/learner_thread.py", line 76, in run
(IMPALA pid=4461)     self.step()
(IMPALA pid=4461)   File "/usr/local/lib/python3.10/dist-packages/ray/rllib/execution/multi_gpu_learner_thread.py", line 168, in step
(IMPALA pid=4461)     default_policy_results = policy.learn_on_loaded_batch(
(IMPALA pid=4461)   File "/usr/local/lib/python3.10/dist-packages/ray/rllib/policy/torch_policy_v2.py", line 866, in learn_on_loaded_batch
(IMPALA pid=4461)     tower_outputs = self._multi_gpu_parallel_grad_calc(device_batches)
(IMPALA pid=4461)   File "/usr/local/lib/python3.10/dist-packages/ray/rllib/policy/torch_policy_v2.py", line 1433, in _multi_gpu_parallel_grad_calc
(IMPALA pid=4461)     raise last_result[0] from last_result[1]
(IMPALA pid=4461) ValueError: The size of tensor a (159) must match the size of tensor b (139) at non-singleton dimension 0
(IMPALA pid=4461)  tracebackTraceback (most recent call last):
(IMPALA pid=4461)   File "/usr/local/lib/python3.10/dist-packages/ray/rllib/policy/torch_policy_v2.py", line 1348, in _worker
(IMPALA pid=4461)     self.loss(model, self.dist_class, sample_batch)
(IMPALA pid=4461)   File "/usr/local/lib/python3.10/dist-packages/ray/rllib/algorithms/impala/impala_torch_policy.py", line 333, in loss
(IMPALA pid=4461)     loss = VTraceLoss(
(IMPALA pid=4461)   File "/usr/local/lib/python3.10/dist-packages/ray/rllib/algorithms/impala/impala_torch_policy.py", line 115, in __init__
(IMPALA pid=4461)     actions_logp * self.vtrace_returns.pg_advantages.to(device) * valid_mask
(IMPALA pid=4461) RuntimeError: The size of tensor a (159) must match the size of tensor b (139) at non-singleton dimension 0
(IMPALA pid=4461) 
(IMPALA pid=4461) In tower 0 on device cuda:0

Trial IMPALA_CustomCartPoleEnv-v0_02b18_00000 errored after 3 iterations at 2024-09-13 16:28:12. Total running time: 1min 8s
Error file: /tmp/ray/session_2024-09-13_16-22-47_988205_1242/artifacts/2024-09-13_16-27-03/IMPALA_2024-09-13_16-27-03/driver_artifacts/IMPALA_CustomCartPoleEnv-v0_02b18_00000_0_2024-09-13_16-27-03/error.txt

Trial status: 1 ERROR
Current time: 2024-09-13 16:28:12. Total running time: 1min 8s
Logical resource usage: 3.0/12 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:L4)
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                                status       iter     total time (s)      ts     num_healthy_workers     ...async_sample_reqs     ...e_worker_restarts     ...ent_steps_sampled |
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| IMPALA_CustomCartPoleEnv-v0_02b18_00000   ERROR           3            41.2119   28672                       2                        4                        0                    28672 |
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

Number of errored trials: 1
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                                  # failures   error file                                                                                                                                                                                         |
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| IMPALA_CustomCartPoleEnv-v0_02b18_00000              1   /tmp/ray/session_2024-09-13_16-22-47_988205_1242/artifacts/2024-09-13_16-27-03/IMPALA_2024-09-13_16-27-03/driver_artifacts/IMPALA_CustomCartPoleEnv-v0_02b18_00000_0_2024-09-13_16-27-03/error.txt |
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

---------------------------------------------------------------------------
TuneError                                 Traceback (most recent call last)
<ipython-input-4-a590513b3dc5> in <cell line: 9>()
     24         "train_batch_size": 2560,
     25     }
---> 26     tune.run(
     27         "IMPALA",
     28         config=config,

/usr/local/lib/python3.10/dist-packages/ray/tune/tune.py in run(run_or_experiment, name, metric, mode, stop, time_budget_s, config, resources_per_trial, num_samples, storage_path, storage_filesystem, search_alg, scheduler, checkpoint_config, verbose, progress_reporter, log_to_file, trial_name_creator, trial_dirname_creator, sync_config, export_formats, max_failures, fail_fast, restore, resume, resume_config, reuse_actors, raise_on_failed_trial, callbacks, max_concurrent_trials, keep_checkpoints_num, checkpoint_score_attr, checkpoint_freq, checkpoint_at_end, chdir_to_trial_dir, local_dir, _remote, _remote_string_queue, _entrypoint)
   1033     if incomplete_trials:
   1034         if raise_on_failed_trial and not experiment_interrupted_event.is_set():
-> 1035             raise TuneError("Trials did not complete", incomplete_trials)
   1036         else:
   1037             logger.error("Trials did not complete: %s", incomplete_trials)

TuneError: ('Trials did not complete', [IMPALA_CustomCartPoleEnv-v0_02b18_00000])
2024-09-13 16:31:36,729 INFO worker.py:1616 -- Calling ray.init() again after it has already been called.
+------------------------------------------------------------+
| Configuration for experiment     PPO_2024-09-13_16-31-36   |
+------------------------------------------------------------+
| Search algorithm                 BasicVariantGenerator     |
| Scheduler                        FIFOScheduler             |
| Number of trials                 1                         |
+------------------------------------------------------------+

View detailed results here: /root/ray_results/PPO_2024-09-13_16-31-36
To visualize your results with TensorBoard, run: `tensorboard --logdir /tmp/ray/session_2024-09-13_16-22-47_988205_1242/artifacts/2024-09-13_16-31-36/PPO_2024-09-13_16-31-36/driver_artifacts`

Trial status: 1 PENDING
Current time: 2024-09-13 16:31:36. Total running time: 0s
Logical resource usage: 3.0/12 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:L4)
+-------------------------------------------------+
| Trial name                             status   |
+-------------------------------------------------+
| PPO_CustomCartPoleEnv-v0_a55d8_00000   PENDING  |
+-------------------------------------------------+
(RolloutWorker pid=6905) 2024-09-13 16:31:51,615    WARNING deprecation.py:50 -- DeprecationWarning: `ray.rllib.models.torch.recurrent_net.RecurrentNetwork` has been deprecated. This will raise an error in the future!

Trial PPO_CustomCartPoleEnv-v0_a55d8_00000 started with configuration:
+----------------------------------------------------------------------------+
| Trial PPO_CustomCartPoleEnv-v0_a55d8_00000 config                          |
+----------------------------------------------------------------------------+
| env                                                   CustomCartPoleEnv-v0 |
| framework                                                            torch |
| model/lstm_cell_size                                                   256 |
| model/max_seq_len                                                      256 |
| model/use_lstm                                                        True |
| num_envs_per_env_runner                                                  7 |
| num_workers                                                              2 |
| remote_worker_envs                                                    True |
| rollout_fragment_length                                                256 |
| train_batch_size                                                      3560 |
+----------------------------------------------------------------------------+
(PPO pid=6820) Install gputil for GPU system monitoring.

Trial status: 1 RUNNING
Current time: 2024-09-13 16:32:06. Total running time: 30s
Logical resource usage: 3.0/12 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:L4)
+-------------------------------------------------+
| Trial name                             status   |
+-------------------------------------------------+
| PPO_CustomCartPoleEnv-v0_a55d8_00000   RUNNING  |
+-------------------------------------------------+
(PPO pid=6820) 2024-09-13 16:32:10,338  WARNING deprecation.py:50 -- DeprecationWarning: `ray.rllib.execution.train_ops.multi_gpu_train_one_step` has been deprecated. This will raise an error in the future!
(PPO pid=6820) 2024-09-13 16:31:53,317  WARNING deprecation.py:50 -- DeprecationWarning: `ray.rllib.models.torch.recurrent_net.RecurrentNetwork` has been deprecated. This will raise an error in the future!
Trial status: 1 RUNNING
Current time: 2024-09-13 16:32:37. Total running time: 1min 0s
Logical resource usage: 3.0/12 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:L4)
+-------------------------------------------------+
| Trial name                             status   |
+-------------------------------------------------+
| PPO_CustomCartPoleEnv-v0_a55d8_00000   RUNNING  |
+-------------------------------------------------+

Trial PPO_CustomCartPoleEnv-v0_a55d8_00000 finished iteration 1 at 2024-09-13 16:32:40. Total running time: 1min 3s
+---------------------------------------------------------------+
| Trial PPO_CustomCartPoleEnv-v0_a55d8_00000 result             |
+---------------------------------------------------------------+
| env_runners/episode_len_mean                          21.3706 |
| env_runners/episode_return_mean                       21.3706 |
| num_env_steps_sampled_lifetime                           3584 |
+---------------------------------------------------------------+

Trial status: 1 RUNNING
Current time: 2024-09-13 16:33:07. Total running time: 1min 30s
Logical resource usage: 3.0/12 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:L4)
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                             status       iter     total time (s)     ts     num_healthy_workers     ...async_sample_reqs     ...e_worker_restarts     ...ent_steps_sampled |
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| PPO_CustomCartPoleEnv-v0_a55d8_00000   RUNNING         1            47.1346   3584                       2                        0                        0                     3584 |
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

Trial PPO_CustomCartPoleEnv-v0_a55d8_00000 finished iteration 2 at 2024-09-13 16:33:24. Total running time: 1min 48s
+-------------------------------------------------------------+
| Trial PPO_CustomCartPoleEnv-v0_a55d8_00000 result           |
+-------------------------------------------------------------+
| env_runners/episode_len_mean                          38.51 |
| env_runners/episode_return_mean                       38.51 |
| num_env_steps_sampled_lifetime                         7168 |
+-------------------------------------------------------------+
2024-09-13 16:33:29,096 ERROR tune_controller.py:1331 -- Trial task failed for trial PPO_CustomCartPoleEnv-v0_a55d8_00000
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 2661, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 871, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(ValueError): ray::PPO.train() (pid=6820, ip=172.28.0.12, actor_id=60766abbbaf5181ce01349c601000000, repr=PPO)
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py", line 84, in loss
    logits, state = model(train_batch)
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/models/modelv2.py", line 256, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/models/torch/recurrent_net.py", line 219, in forward
    wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/models/torch/fcnet.py", line 144, in forward
    self._last_flat_in = obs.reshape(obs.shape[0], -1)
RuntimeError: cannot reshape tensor of 0 elements into shape [0, -1] because the unspecified dimension size -1 can be any value and is ambiguous

The above exception was the direct cause of the following exception:

ray::PPO.train() (pid=6820, ip=172.28.0.12, actor_id=60766abbbaf5181ce01349c601000000, repr=PPO)
  File "/usr/local/lib/python3.10/dist-packages/ray/tune/trainable/trainable.py", line 331, in train
    raise skipped from exception_cause(skipped)
  File "/usr/local/lib/python3.10/dist-packages/ray/tune/trainable/trainable.py", line 328, in train
    result = self.step()
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/algorithms/algorithm.py", line 951, in step
    train_results, train_iter_ctx = self._run_one_training_iteration()
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/algorithms/algorithm.py", line 3600, in _run_one_training_iteration
    training_step_results = self.training_step()
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/algorithms/ppo/ppo.py", line 434, in training_step
    return self._training_step_old_and_hybrid_api_stacks()
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/algorithms/ppo/ppo.py", line 576, in _training_step_old_and_hybrid_api_stacks
    train_results = multi_gpu_train_one_step(self, train_batch)
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/execution/train_ops.py", line 176, in multi_gpu_train_one_step
    results = policy.learn_on_loaded_batch(
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/policy/torch_policy_v2.py", line 838, in learn_on_loaded_batch
    return self.learn_on_batch(batch)
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/utils/threading.py", line 24, in wrapper
    return func(self, *a, **k)
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/policy/torch_policy_v2.py", line 715, in learn_on_batch
    grads, fetches = self.compute_gradients(postprocessed_batch)
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/utils/threading.py", line 24, in wrapper
    return func(self, *a, **k)
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/policy/torch_policy_v2.py", line 933, in compute_gradients
    tower_outputs = self._multi_gpu_parallel_grad_calc([postprocessed_batch])
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/policy/torch_policy_v2.py", line 1433, in _multi_gpu_parallel_grad_calc
    raise last_result[0] from last_result[1]
ValueError: cannot reshape tensor of 0 elements into shape [0, -1] because the unspecified dimension size -1 can be any value and is ambiguous
 tracebackTraceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/policy/torch_policy_v2.py", line 1348, in _worker
    self.loss(model, self.dist_class, sample_batch)
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py", line 84, in loss
    logits, state = model(train_batch)
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/models/modelv2.py", line 256, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/models/torch/recurrent_net.py", line 219, in forward
    wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/models/torch/fcnet.py", line 144, in forward
    self._last_flat_in = obs.reshape(obs.shape[0], -1)
RuntimeError: cannot reshape tensor of 0 elements into shape [0, -1] because the unspecified dimension size -1 can be any value and is ambiguous

In tower 0 on device cpu
2024-09-13 16:33:29,113 INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/root/ray_results/PPO_2024-09-13_16-31-36' in 0.0115s.

Trial PPO_CustomCartPoleEnv-v0_a55d8_00000 errored after 2 iterations at 2024-09-13 16:33:29. Total running time: 1min 52s
Error file: /tmp/ray/session_2024-09-13_16-22-47_988205_1242/artifacts/2024-09-13_16-31-36/PPO_2024-09-13_16-31-36/driver_artifacts/PPO_CustomCartPoleEnv-v0_a55d8_00000_0_2024-09-13_16-31-36/error.txt

Trial status: 1 ERROR
Current time: 2024-09-13 16:33:29. Total running time: 1min 52s
Logical resource usage: 3.0/12 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:L4)
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                             status       iter     total time (s)     ts     num_healthy_workers     ...async_sample_reqs     ...e_worker_restarts     ...ent_steps_sampled |
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| PPO_CustomCartPoleEnv-v0_a55d8_00000   ERROR           2            91.3006   7168                       2                        0                        0                     7168 |
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

Number of errored trials: 1
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                               # failures   error file                                                                                                                                                                                   |
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| PPO_CustomCartPoleEnv-v0_a55d8_00000              1   /tmp/ray/session_2024-09-13_16-22-47_988205_1242/artifacts/2024-09-13_16-31-36/PPO_2024-09-13_16-31-36/driver_artifacts/PPO_CustomCartPoleEnv-v0_a55d8_00000_0_2024-09-13_16-31-36/error.txt |
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

---------------------------------------------------------------------------
TuneError                                 Traceback (most recent call last)
<ipython-input-7-59c18aacbe50> in <cell line: 9>()
     24         "train_batch_size": 3560,
     25     }
---> 26     tune.run(
     27         "PPO",
     28         config=config,

/usr/local/lib/python3.10/dist-packages/ray/tune/tune.py in run(run_or_experiment, name, metric, mode, stop, time_budget_s, config, resources_per_trial, num_samples, storage_path, storage_filesystem, search_alg, scheduler, checkpoint_config, verbose, progress_reporter, log_to_file, trial_name_creator, trial_dirname_creator, sync_config, export_formats, max_failures, fail_fast, restore, resume, resume_config, reuse_actors, raise_on_failed_trial, callbacks, max_concurrent_trials, keep_checkpoints_num, checkpoint_score_attr, checkpoint_freq, checkpoint_at_end, chdir_to_trial_dir, local_dir, _remote, _remote_string_queue, _entrypoint)
   1033     if incomplete_trials:
   1034         if raise_on_failed_trial and not experiment_interrupted_event.is_set():
-> 1035             raise TuneError("Trials did not complete", incomplete_trials)
   1036         else:
   1037             logger.error("Trials did not complete: %s", incomplete_trials)

TuneError: ('Trials did not complete', [PPO_CustomCartPoleEnv-v0_a55d8_00000])
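
The `TuneError` above only reports that the trial did not complete; the underlying exception is written to the `error.txt` listed in the "error file" column of the table. A minimal sketch for printing the end of that file, assuming it is readable from the notebook. The helper name `tail_error_file` and the placeholder path are hypothetical, not part of Ray:

```python
from pathlib import Path


def tail_error_file(error_file: str, max_lines: int = 40) -> str:
    """Return the last `max_lines` lines of a trial's error.txt as one string."""
    path = Path(error_file)
    if not path.is_file():
        return f"(no error file found at {path})"
    lines = path.read_text(encoding="utf-8", errors="replace").splitlines()
    return "\n".join(lines[-max_lines:])


if __name__ == "__main__":
    # Hypothetical placeholder: substitute the path from the "error file" column.
    print(tail_error_file("/path/to/trial_dir/error.txt"))
```

The innermost traceback is usually at the bottom of the file, which is why only the tail is printed.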

Versions / Dependencies

absl-py==1.4.0 accelerate==0.34.2 aiohappyeyeballs==2.4.0 aiohttp==3.10.5 aiohttp-cors==0.7.0 aiosignal==1.3.1 alabaster==0.7.16 albucore==0.0.14 albumentations==1.4.14 altair==4.2.2 annotated-types==0.7.0 anyio==3.7.1 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 array_record==0.5.1 arviz==0.18.0 asn1crypto==1.5.1 astropy==6.1.3 astropy-iers-data==0.2024.9.2.0.33.23 astunparse==1.6.3 async-timeout==4.0.3 atpublic==4.1.0 attrs==24.2.0 audioread==3.0.1 autograd==1.7.0 babel==2.16.0 backcall==0.2.0 beautifulsoup4==4.12.3 bidict==0.23.1 bigframes==1.15.0 bigquery-magics==0.2.0 bleach==6.1.0 blinker==1.4 blis==0.7.11 blosc2==2.0.0 bokeh==3.4.3 bqplot==0.12.43 branca==0.7.2 build==1.2.2 CacheControl==0.14.0 cachetools==5.5.0 catalogue==2.0.10 certifi==2024.8.30 cffi==1.17.1 chardet==5.2.0 charset-normalizer==3.3.2 chex==0.1.86 clarabel==0.9.0 click==8.1.7 click-plugins==1.1.1 cligj==0.7.2 cloudpathlib==0.19.0 cloudpickle==2.2.1 cmake==3.30.3 cmdstanpy==1.2.4 colorcet==3.1.0 colorful==0.5.6 colorlover==0.3.0 colour==0.1.5 community==1.0.0b1 confection==0.1.5 cons==0.4.6 contextlib2==21.6.0 contourpy==1.3.0 cryptography==43.0.1 cuda-python==12.2.1 cudf-cu12 @ https://pypi.nvidia.com/cudf-cu12/cudf_cu12-24.4.1-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=57366e7ef09dc63e0b389aff20df6c37d91e2790065861ee31a4720149f5b694 cufflinks==0.17.3 cupy-cuda12x==12.2.0 cvxopt==1.3.2 cvxpy==1.5.3 cycler==0.12.1 cymem==2.0.8 Cython==3.0.11 dask==2024.7.1 datascience==0.17.6 db-dtypes==1.3.0 dbus-python==1.2.18 debugpy==1.6.6 decorator==4.4.2 defusedxml==0.7.1 distlib==0.3.8 distributed==2024.7.1 distro==1.7.0 dlib==19.24.2 dm-tree==0.1.8 docstring_parser==0.16 docutils==0.18.1 dopamine_rl==4.0.9 duckdb==0.10.3 earthengine-api==1.0.0 easydict==1.13 ecos==2.0.14 editdistance==0.8.1 eerepr==0.0.4 einops==0.8.0 en-core-web-sm @ 
https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl#sha256=86cc141f63942d4b2c5fcee06630fd6f904788d2f0ab005cce45aadb8fb73889 entrypoints==0.4 et-xmlfile==1.1.0 etils==1.9.4 etuples==0.3.9 eval_type_backport==0.2.0 exceptiongroup==1.2.2 Farama-Notifications==0.0.4 fastai==2.7.17 fastcore==1.7.5 fastdownload==0.0.7 fastjsonschema==2.20.0 fastprogress==1.0.3 fastrlock==0.8.2 filelock==3.16.0 fiona==1.10.0 firebase-admin==6.5.0 Flask==2.2.5 flatbuffers==24.3.25 flax==0.8.4 folium==0.17.0 fonttools==4.53.1 frozendict==2.4.4 frozenlist==1.4.1 fsspec==2024.6.1 future==1.0.0 gast==0.6.0 gcsfs==2024.6.1 GDAL==3.6.4 gdown==5.1.0 geemap==0.34.1 gensim==4.3.3 geocoder==1.38.1 geographiclib==2.0 geopandas==0.14.4 geopy==2.4.1 gin-config==0.5.0 glob2==0.7 google==2.0.3 google-ai-generativelanguage==0.6.6 google-api-core==2.19.2 google-api-python-client==2.137.0 google-auth==2.27.0 google-auth-httplib2==0.2.0 google-auth-oauthlib==1.2.1 google-cloud-aiplatform==1.65.0 google-cloud-bigquery==3.25.0 google-cloud-bigquery-connection==1.15.5 google-cloud-bigquery-storage==2.26.0 google-cloud-bigtable==2.26.0 google-cloud-core==2.4.1 google-cloud-datastore==2.19.0 google-cloud-firestore==2.16.1 google-cloud-functions==1.16.5 google-cloud-iam==2.15.2 google-cloud-language==2.13.4 google-cloud-pubsub==2.23.1 google-cloud-resource-manager==1.12.5 google-cloud-storage==2.8.0 google-cloud-translate==3.15.5 google-colab @ file:///colabtools/dist/google_colab-1.0.0.tar.gz#sha256=9eb5f50246c4e1b93fdc0d5618acd563b5cbf1593986541fe1c6697138e0fa26 google-crc32c==1.6.0 google-generativeai==0.7.2 google-pasta==0.2.0 google-resumable-media==2.7.2 googleapis-common-protos==1.65.0 googledrivedownloader==0.4 graphviz==0.20.3 greenlet==3.1.0 grpc-google-iam-v1==0.13.1 grpcio==1.64.1 grpcio-status==1.48.2 gspread==6.0.2 gspread-dataframe==3.3.1 gym==0.25.2 gym-notices==0.0.8 gymnasium==0.28.1 h5netcdf==1.3.0 h5py==3.11.0 holidays==0.56 
holoviews==1.18.3 html5lib==1.1 httpimport==1.3.1 httplib2==0.22.0 huggingface-hub==0.24.6 humanize==4.10.0 hyperopt==0.2.7 ibis-framework==8.0.0 idna==3.8 imageio==2.34.2 imageio-ffmpeg==0.5.1 imagesize==1.4.1 imbalanced-learn==0.12.3 imgaug==0.4.0 immutabledict==4.2.0 importlib_metadata==8.4.0 importlib_resources==6.4.5 imutils==0.5.4 inflect==7.3.1 iniconfig==2.0.0 intel-cmplr-lib-ur==2024.2.1 intel-openmp==2024.2.1 ipyevents==2.0.2 ipyfilechooser==0.6.0 ipykernel==5.5.6 ipyleaflet==0.18.2 ipyparallel==8.8.0 ipython==7.34.0 ipython-genutils==0.2.0 ipython-sql==0.5.0 ipytree==0.2.2 ipywidgets==7.7.1 itsdangerous==2.2.0 jax==0.4.26 jax-jumpy==1.0.0 jaxlib @ https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.26+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl#sha256=813cf1fe3e7ca4dbf5327d6e7b4fc8521e92d8bba073ee645ae0d5d036a25750 jeepney==0.7.1 jellyfish==1.1.0 jieba==0.42.1 Jinja2==3.1.4 joblib==1.4.2 jsonpickle==3.3.0 jsonschema==4.23.0 jsonschema-specifications==2023.12.1 jupyter-client==6.1.12 jupyter-console==6.1.0 jupyter-server==1.24.0 jupyter_core==5.7.2 jupyterlab_pygments==0.3.0 jupyterlab_widgets==3.0.13 kaggle==1.6.17 kagglehub==0.2.9 keras==3.4.1 keyring==23.5.0 kiwisolver==1.4.7 langcodes==3.4.0 language_data==1.2.0 launchpadlib==1.10.16 lazr.restfulclient==0.14.4 lazr.uri==1.0.6 lazy_loader==0.4 libclang==18.1.1 librosa==0.10.2.post1 lightgbm==4.4.0 linkify-it-py==2.0.3 llvmlite==0.43.0 locket==1.0.0 logical-unification==0.4.6 lxml==4.9.4 lz4==4.3.3 malloy==2024.1091 marisa-trie==1.2.0 Markdown==3.7 markdown-it-py==3.0.0 MarkupSafe==2.1.5 matplotlib==3.7.1 matplotlib-inline==0.1.7 matplotlib-venn==0.11.10 mdit-py-plugins==0.4.2 mdurl==0.1.2 memray==1.14.0 miniKanren==1.0.3 missingno==0.5.2 mistune==0.8.4 mizani==0.9.3 mkl==2024.2.1 ml-dtypes==0.4.0 mlxtend==0.23.1 more-itertools==10.3.0 moviepy==1.0.3 mpmath==1.3.0 msgpack==1.0.8 multidict==6.1.0 multipledispatch==1.0.0 multitasking==0.0.11 murmurhash==1.0.10 music21==9.1.0 namex==0.0.8 
natsort==8.4.0 nbclassic==1.1.0 nbclient==0.10.0 nbconvert==6.5.4 nbformat==5.10.4 nest-asyncio==1.6.0 networkx==3.3 nibabel==5.0.1 nltk==3.8.1 notebook==6.5.5 notebook_shim==0.2.4 numba==0.60.0 numexpr==2.10.1 numpy==1.26.4 nvidia-nccl-cu12==2.22.3 nvtx==0.2.10 oauth2client==4.1.3 oauthlib==3.2.2 opencensus==0.11.4 opencensus-context==0.1.3 opencv-contrib-python==4.10.0.84 opencv-python==4.10.0.84 opencv-python-headless==4.10.0.84 openpyxl==3.1.5 opt-einsum==3.3.0 optax==0.2.2 optree==0.12.1 orbax-checkpoint==0.6.1 osqp==0.6.7.post0 packaging==24.1 pandas==2.1.4 pandas-datareader==0.10.0 pandas-gbq==0.23.1 pandas-stubs==2.1.4.231227 pandocfilters==1.5.1 panel==1.4.5 param==2.1.1 parso==0.8.4 parsy==2.1 partd==1.4.2 pathlib==1.0.1 patsy==0.5.6 peewee==3.17.6 pexpect==4.9.0 pickleshare==0.7.5 Pillow==9.4.0 pip-tools==7.4.1 platformdirs==4.3.2 plotly==5.15.0 plotnine==0.12.4 pluggy==1.5.0 polars==0.20.2 pooch==1.8.2 portpicker==1.5.2 prefetch_generator==1.0.3 preshed==3.0.9 prettytable==3.11.0 proglog==0.1.10 progressbar2==4.2.0 prometheus_client==0.20.0 promise==2.3 prompt_toolkit==3.0.47 prophet==1.1.5 proto-plus==1.24.0 protobuf==3.20.3 psutil==5.9.5 psycopg2==2.9.9 ptyprocess==0.7.0 py-cpuinfo==9.0.0 py-spy==0.3.14 py4j==0.10.9.7 pyarrow==14.0.2 pyarrow-hotfix==0.6 pyasn1==0.6.0 pyasn1_modules==0.4.0 pycocotools==2.0.8 pycparser==2.22 pydantic==2.9.1 pydantic_core==2.23.3 pydata-google-auth==1.8.2 pydot==1.4.2 pydot-ng==2.0.0 pydotplus==2.0.2 PyDrive==1.3.1 PyDrive2==1.6.3 pyerfa==2.0.1.4 pygame==2.6.0 Pygments==2.16.1 PyGObject==3.42.1 PyJWT==2.9.0 pymc==5.10.4 pymystem3==0.2.0 pynvjitlink-cu12==0.3.0 PyOpenGL==3.1.7 pyOpenSSL==24.2.1 pyparsing==3.1.4 pyperclip==1.9.0 pyproj==3.6.1 pyproject_hooks==1.1.0 pyshp==2.3.1 PySocks==1.7.1 pytensor==2.18.6 pytest==7.4.4 python-apt==2.4.0 python-box==7.2.0 python-dateutil==2.8.2 python-louvain==0.16 python-slugify==8.0.4 python-utils==3.8.2 pytz==2024.1 pyviz_comms==3.0.3 PyYAML==6.0.2 pyzmq==24.0.1 qdldl==0.1.7.post4 
ratelim==0.1.6 ray==2.35.0 referencing==0.35.1 regex==2024.5.15 requests==2.32.3 requests-oauthlib==1.3.1 requirements-parser==0.9.0 rich==13.8.1 rmm-cu12==24.4.0 rpds-py==0.20.0 rpy2==3.4.2 rsa==4.9 safetensors==0.4.5 scikit-image==0.23.2 scikit-learn==1.3.2 scipy==1.13.1 scooby==0.10.0 scs==3.2.7 seaborn==0.13.1 SecretStorage==3.3.1 Send2Trash==1.8.3 sentencepiece==0.1.99 shapely==2.0.6 shellingham==1.5.4 simple-parsing==0.1.6 six==1.16.0 sklearn-pandas==2.2.0 smart-open==7.0.4 sniffio==1.3.1 snowballstemmer==2.2.0 snowflake-connector-python==3.12.1 sortedcontainers==2.4.0 soundfile==0.12.1 soupsieve==2.6 soxr==0.5.0.post1 spacy==3.7.6 spacy-legacy==3.0.12 spacy-loggers==1.0.5 Sphinx==5.0.2 sphinxcontrib-applehelp==2.0.0 sphinxcontrib-devhelp==2.0.0 sphinxcontrib-htmlhelp==2.1.0 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==2.0.0 sphinxcontrib-serializinghtml==2.0.0 SQLAlchemy==2.0.34 sqlglot==20.11.0 sqlparse==0.5.1 srsly==2.4.8 stanio==0.5.1 statsmodels==0.14.2 StrEnum==0.4.15 sympy==1.13.2 tables==3.8.0 tabulate==0.9.0 tbb==2021.13.1 tblib==3.0.0 tenacity==9.0.0 tensorboard==2.17.0 tensorboard-data-server==0.7.2 tensorboardX==2.6.2.2 tensorflow==2.17.0 tensorflow-datasets==4.9.6 tensorflow-hub==0.16.1 tensorflow-io-gcs-filesystem==0.37.1 tensorflow-metadata==1.15.0 tensorflow-probability==0.24.0 tensorstore==0.1.65 termcolor==2.4.0 terminado==0.18.1 text-unidecode==1.3 textblob==0.17.1 textual==0.79.1 tf-slim==1.1.0 tf_keras==2.17.0 thinc==8.2.5 threadpoolctl==3.5.0 tifffile==2024.8.30 tinycss2==1.3.0 tokenizers==0.19.1 toml==0.10.2 tomli==2.0.1 tomlkit==0.13.2 toolz==0.12.1 torch @ https://download.pytorch.org/whl/cu121_full/torch-2.4.0%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=6f3aabcec8b7728943d22bec2d8017b1bd2d69cd903eefb7dd3a373e4f779c40 torchaudio @ https://download.pytorch.org/whl/cu121_full/torchaudio-2.4.0%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=8bc4c22a701f4567a15cc98ff48c392147583b9ec4157d147025f297bf305acc torchsummary==1.5.1 torchvision 
@ https://download.pytorch.org/whl/cu121_full/torchvision-0.19.0%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=53c7ee4a98c8464ec964a6ab3804f5925b19bac698ef728f148bffebe27a9081 tornado==6.3.3 tqdm==4.66.5 traitlets==5.7.1 traittypes==0.2.1 transformers==4.44.2 tweepy==4.14.0 typeguard==4.3.0 typer==0.12.5 types-pytz==2024.1.0.20240417 types-setuptools==74.1.0.20240907 typing_extensions==4.12.2 tzdata==2024.1 tzlocal==5.2 uc-micro-py==1.0.3 uritemplate==4.1.1 urllib3==2.0.7 vega-datasets==0.9.0 virtualenv==20.26.4 wadllib==1.3.6 wasabi==1.1.3 wcwidth==0.2.13 weasel==0.4.1 webcolors==24.8.0 webencodings==0.5.1 websocket-client==1.8.0 Werkzeug==3.0.4 widgetsnbextension==3.6.9 wordcloud==1.9.3 wrapt==1.16.0 xarray==2024.6.0 xarray-einstats==0.7.0 xgboost==2.1.1 xlrd==2.0.1 xyzservices==2024.9.0 yarl==1.11.1 yellowbrick==1.5 yfinance==0.2.43 zict==3.0.0 zipp==3.20.1
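
Most of the list above is the default Colab environment. For this report the relevant entries are probably `ray==2.35.0`, `torch` 2.4.0+cu121, `gymnasium==0.28.1`, and `numpy==1.26.4` (Python 3.10, per the traceback paths). A small sketch that prints only those versions using the standard library. The choice of package names in `RELEVANT` is an assumption about what matters here:

```python
from importlib.metadata import PackageNotFoundError, version

# Assumption: these are the packages most likely to matter for this issue.
RELEVANT = ("ray", "torch", "gymnasium", "numpy")


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions(RELEVANT).items():
        print(f"{name}=={ver}" if ver else f"{name}: not installed")
```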

Reproduction script

import gymnasium as gym
from gymnasium.envs.classic_control.cartpole import CartPoleEnv
import numpy as np

import ray
from ray import tune
from ray.tune.registry import register_env

if __name__ == "__main__":
    ray.init(ignore_reinit_error=True)
    register_env("CustomCartPoleEnv-v0", CartPoleEnv)
    config = {
        "env": "CustomCartPoleEnv-v0",
        "framework": "torch",
        "model": {
            "use_lstm": True,
            "lstm_cell_size": 256,
            "max_seq_len": 256,
        },
        "num_workers": 2,
        "num_envs_per_env_runner": 7,
        "remote_worker_envs": True,
        "rollout_fragment_length": 256,
        "train_batch_size": 2560,
    }
    tune.run(
        "IMPALA",
        config=config,
        stop={"training_iteration": 50},
    )

# Second script: the same configuration trained with PPO (train_batch_size 3560)
import gymnasium as gym
from gymnasium.envs.classic_control.cartpole import CartPoleEnv
import numpy as np

import ray
from ray import tune
from ray.tune.registry import register_env

if __name__ == "__main__":
    ray.init(ignore_reinit_error=True)
    register_env("CustomCartPoleEnv-v0", CartPoleEnv)
    config = {
        "env": "CustomCartPoleEnv-v0",
        "framework": "torch",
        "model": {
            "use_lstm": True,
            "lstm_cell_size": 256,
            "max_seq_len": 256,
        },
        "num_workers": 2,
        "num_envs_per_env_runner": 7,
        "remote_worker_envs": True,
        "rollout_fragment_length": 256,
        "train_batch_size": 3560,
    }
    tune.run(
        "PPO",
        config=config,
        stop={"training_iteration": 50},
    )
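
Both scripts set `max_seq_len` to 256. As background on why recurrent training is sensitive to tensor shapes, here is a generic sketch of how variable-length episodes are commonly split into sequences of at most `max_seq_len` steps and then padded. This is an illustration only: it is not RLlib's implementation and not a diagnosis of the error reported here, and the function names and episode lengths are made up:

```python
def split_episode_lengths(episode_lengths, max_seq_len):
    """Split each episode into chunks of at most `max_seq_len` steps.

    Returns the list of chunk lengths in order.
    """
    seq_lens = []
    for length in episode_lengths:
        full_chunks, remainder = divmod(length, max_seq_len)
        seq_lens.extend([max_seq_len] * full_chunks)
        if remainder:
            seq_lens.append(remainder)
    return seq_lens


def padded_size(seq_lens, max_seq_len):
    """Number of rows once every chunk is padded to `max_seq_len`."""
    return len(seq_lens) * max_seq_len


if __name__ == "__main__":
    episodes = [30, 256, 300]  # made-up episode lengths
    seq_lens = split_episode_lengths(episodes, max_seq_len=256)
    print(seq_lens)                               # [30, 256, 256, 44]
    print(sum(seq_lens), padded_size(seq_lens, 256))  # 586 1024
```

In this example the unpadded batch has 586 rows while the padded one has 1024. If one tensor in a loss computation is padded and another is not, their sizes would disagree along that dimension, which is the general shape of mismatch a size error like the one in this report describes.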

Issue Severity

High: It blocks me from completing my task.

simonsays1980 commented 2 weeks ago

@klae01 thanks for raising this issue. Could you try our new API stack implementation with the following code (feel free to change parameters where needed):

from ray.rllib.algorithms.impala import IMPALAConfig
from ray.rllib.utils.metrics import (
    ENV_RUNNER_RESULTS,
    EPISODE_RETURN_MEAN,
    NUM_ENV_STEPS_SAMPLED_LIFETIME,
)
from ray.rllib.utils.test_utils import add_rllib_example_script_args
from ray import tune

parser = add_rllib_example_script_args()
# Use `parser` to add your own custom command line options to this script
# and (if needed) use their values to set up `config` below.
args = parser.parse_args()

config = (
    IMPALAConfig()
    # Enable new API stack and use EnvRunner.
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .resources(
        num_gpus=0,
    )
    .environment("CartPole-v1")
    .training(
        train_batch_size_per_learner=500,
        grad_clip=40.0,
        grad_clip_by="global_norm",
        lr=0.0005 * ((args.num_gpus or 1) ** 0.5),
        vf_loss_coeff=0.05,
        entropy_coeff=0.0,
    )
    .learners(
        num_gpus_per_learner=0,
    )
    .rl_module(
        model_config_dict={
            "vf_share_layers": True,
            "uses_new_env_runners": True,
            "use_lstm": True,
            "lstm_cell_size": 256,
            "max_seq_len": 256,
        },
    )
)

stop = {
    f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 450.0,
    NUM_ENV_STEPS_SAMPLED_LIFETIME: 2000000,
}

if __name__ == "__main__":
    tune.run(
        "IMPALA",
        config=config,
        stop={"training_iteration": 50},
    )