adityab / CrossQ

Official code release for "CrossQ: Batch Normalization in Deep Reinforcement Learning for Greater Sample Efficiency and Simplicity"
http://aditya.bhatts.org/CrossQ

Some tasks from deepmind/* not working #6

Open JankowskiChristopher opened 5 months ago

JankowskiChristopher commented 5 months ago

Hello, I am trying to benchmark your code on more tasks from deepmind/*, but several of them fail. There seems to be a bug in the prepare_obs function in sbx/common/policies.py. The stack traces are below:

  1. Task deepmind/quadruped-run

Traceback (most recent call last):
  File "/home/src/crossq/train.py", line 264, in <module>
    model.learn(total_timesteps=total_timesteps, progress_bar=True, callback=callback_list)
  File "/home/src/crossq/sbx/sac/sac.py", line 187, in learn
    return super().learn(
           ^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 312, in learn
    rollout = self.collect_rollouts(
              ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 552, in collect_rollouts
    if callback.on_step() is False:
       ^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/callbacks.py", line 104, in on_step
    return self._on_step()
           ^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/callbacks.py", line 208, in _on_step
    continue_training = callback.on_step() and continue_training
                        ^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/callbacks.py", line 104, in on_step
    return self._on_step()
           ^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/callbacks.py", line 208, in _on_step
    continue_training = callback.on_step() and continue_training
                        ^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/callbacks.py", line 104, in on_step
    return self._on_step()
           ^^^^^^^^^^^^^^^
  File "/home/src/crossq/sbx/sac/actor_critic_evaluation_callback.py", line 355, in _on_step
    episode_rewards, episode_lengths = evaluate_policy(
                                       ^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/evaluation.py", line 88, in evaluate_policy
    actions, states = model.predict(
                      ^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/base_class.py", line 555, in predict
    return self.policy.predict(observation, state, episode_start, deterministic)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/src/crossq/sbx/common/policies.py", line 62, in predict
    observation, vectorized_env = self.prepare_obs(observation)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/src/crossq/sbx/common/policies.py", line 95, in prepare_obs
    observation = np.concatenate(
                  ^^^^^^^^^^^^^^^
ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 2 dimension(s) and the array at index 3 has 1 dimension(s)
  2. Task deepmind/humanoid-walk

Traceback (most recent call last):
  File "/home/src/crossq/train.py", line 264, in <module>
    model.learn(total_timesteps=total_timesteps, progress_bar=True, callback=callback_list)
  File "/home/src/crossq/sbx/sac/sac.py", line 187, in learn
    return super().learn(
           ^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 312, in learn
    rollout = self.collect_rollouts(
              ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 552, in collect_rollouts
    if callback.on_step() is False:
       ^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/callbacks.py", line 104, in on_step
    return self._on_step()
           ^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/callbacks.py", line 208, in _on_step
    continue_training = callback.on_step() and continue_training
                        ^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/callbacks.py", line 104, in on_step
    return self._on_step()
           ^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/callbacks.py", line 208, in _on_step
    continue_training = callback.on_step() and continue_training
                        ^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/callbacks.py", line 104, in on_step
    return self._on_step()
           ^^^^^^^^^^^^^^^
  File "/home/src/crossq/sbx/sac/actor_critic_evaluation_callback.py", line 355, in _on_step
    episode_rewards, episode_lengths = evaluate_policy(
                                       ^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/evaluation.py", line 88, in evaluate_policy
    actions, states = model.predict(
                      ^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/base_class.py", line 555, in predict
    return self.policy.predict(observation, state, episode_start, deterministic)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/src/crossq/sbx/common/policies.py", line 62, in predict
    observation, vectorized_env = self.prepare_obs(observation)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/src/crossq/sbx/common/policies.py", line 95, in prepare_obs
    observation = np.concatenate(
                  ^^^^^^^^^^^^^^^
ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 2 dimension(s) and the array at index 2 has 1 dimension(s)
  3. Task deepmind/humanoid-run
Traceback (most recent call last):
  File "/home/src/crossq/train.py", line 264, in <module>
    model.learn(total_timesteps=total_timesteps, progress_bar=True, callback=callback_list)
  File "/home/src/crossq/sbx/sac/sac.py", line 187, in learn
    return super().learn(
           ^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 312, in learn
    rollout = self.collect_rollouts(
              ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 552, in collect_rollouts
    if callback.on_step() is False:
       ^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/callbacks.py", line 104, in on_step
    return self._on_step()
           ^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/callbacks.py", line 208, in _on_step
    continue_training = callback.on_step() and continue_training
                        ^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/callbacks.py", line 104, in on_step
    return self._on_step()
           ^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/callbacks.py", line 208, in _on_step
    continue_training = callback.on_step() and continue_training
                        ^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/callbacks.py", line 104, in on_step
    return self._on_step()
           ^^^^^^^^^^^^^^^
  File "/home/src/crossq/sbx/sac/actor_critic_evaluation_callback.py", line 355, in _on_step
    episode_rewards, episode_lengths = evaluate_policy(
                                       ^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/evaluation.py", line 88, in evaluate_policy
    actions, states = model.predict(
                      ^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/base_class.py", line 555, in predict
    return self.policy.predict(observation, state, episode_start, deterministic)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/src/crossq/sbx/common/policies.py", line 62, in predict
    observation, vectorized_env = self.prepare_obs(observation)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/src/crossq/sbx/common/policies.py", line 95, in prepare_obs
    observation = np.concatenate(
                  ^^^^^^^^^^^^^^^
ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 2 dimension(s) and the array at index 2 has 1 dimension(s)
  4. Task deepmind/walker-run
Traceback (most recent call last):
  File "/home/src/crossq/train.py", line 264, in <module>
    model.learn(total_timesteps=total_timesteps, progress_bar=True, callback=callback_list)
  File "/home/src/crossq/sbx/sac/sac.py", line 187, in learn
    return super().learn(
           ^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 312, in learn
    rollout = self.collect_rollouts(
              ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 552, in collect_rollouts
    if callback.on_step() is False:
       ^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/callbacks.py", line 104, in on_step
    return self._on_step()
           ^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/callbacks.py", line 208, in _on_step
    continue_training = callback.on_step() and continue_training
                        ^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/callbacks.py", line 104, in on_step
    return self._on_step()
           ^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/callbacks.py", line 208, in _on_step
    continue_training = callback.on_step() and continue_training
                        ^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/callbacks.py", line 104, in on_step
    return self._on_step()
           ^^^^^^^^^^^^^^^
  File "/home/src/crossq/sbx/sac/actor_critic_evaluation_callback.py", line 355, in _on_step
    episode_rewards, episode_lengths = evaluate_policy(
                                       ^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/evaluation.py", line 88, in evaluate_policy
    actions, states = model.predict(
                      ^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/crossq/lib/python3.11/site-packages/stable_baselines3/common/base_class.py", line 555, in predict
    return self.policy.predict(observation, state, episode_start, deterministic)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/src/crossq/sbx/common/policies.py", line 62, in predict
    observation, vectorized_env = self.prepare_obs(observation)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/src/crossq/sbx/common/policies.py", line 95, in prepare_obs
    observation = np.concatenate(
                  ^^^^^^^^^^^^^^^
numpy.exceptions.AxisError: axis 1 is out of bounds for array of dimension 1
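All four traces end in the same np.concatenate call, and both kinds of message can be reproduced with NumPy alone. This sketch assumes (without having checked the sbx source) that prepare_obs concatenates the dict entries along axis 1, as the AxisError message suggests, and that scalar observations such as a height value arrive with only a batch axis, shape (n_envs,). The keys are modeled on what we assume are the dm_control walker observations and are illustrative only.

```python
import numpy as np

# Hypothetical batched dict observation for a single env (n_envs = 1).
# Keys/sizes are illustrative; the scalar entry keeps only the batch axis.
obs = {
    "orientations": np.zeros((1, 14)),  # (n_envs, 14): 2-D
    "height": np.zeros((1,)),           # (n_envs,):    1-D, one scalar per env
    "velocity": np.zeros((1, 9)),       # (n_envs, 9):  2-D
}


def try_concat(arrays, axis):
    """Return the name of the exception np.concatenate raises, or None on success."""
    try:
        np.concatenate(arrays, axis=axis)
        return None
    except ValueError as exc:  # NumPy's AxisError is a subclass of ValueError
        return type(exc).__name__


# First array is 2-D, a later one is 1-D: NumPy reports mismatched
# dimensions (the message seen for quadruped-run and humanoid-*).
mixed_error = try_concat([obs["orientations"], obs["height"], obs["velocity"]], axis=1)

# In sorted key order "height" comes first; it is 1-D, so axis 1 does not
# exist and NumPy raises AxisError (the message seen for walker-run).
axis_error = try_concat([obs[k] for k in sorted(obs)], axis=1)

print(mixed_error, axis_error)
```

NumPy validates the axis against the first array before comparing the other arrays' dimensions, which is why the error type depends on whether the 1-D entry comes first.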

I believe this is caused by the different dict observation formats returned by shimmy. I once had a similar problem in another project and fixed it using this function from the TD-MPC2 GitHub repository:

import numpy as np
def _obs_to_array(self, obs):
    # Flatten every entry of the observation dict and join them into one 1-D vector
    return np.concatenate([v.flatten() for v in obs.values()])

Maybe you could try this as well; it might work better.
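The TD-MPC2 helper turns the whole dict into one 1-D vector. That suits a single unbatched observation, but if prepare_obs receives batched values of shape (n_envs, ...) (an assumption), flattening everything also merges the batch axis, giving shape (24,) instead of (1, 24) for one env. A batch-aware variant could reshape each entry per environment before joining. `obs_dict_to_array` is a hypothetical name, not a function from sbx or TD-MPC2, and visiting keys in sorted order is an assumption that would need to match the order used on the training side.

```python
import numpy as np


def obs_dict_to_array(obs, n_envs):
    """Hypothetical helper: flatten each dict entry per env, then join along axis 1.

    reshape(n_envs, -1) turns a scalar entry of shape (n_envs,) into (n_envs, 1)
    and a vector entry into (n_envs, k), so every piece is 2-D before concatenating.
    Keys are visited in sorted order (assumed to match the training side).
    """
    return np.concatenate(
        [np.asarray(obs[k]).reshape(n_envs, -1) for k in sorted(obs)], axis=1
    )


# Illustrative batched observation for one env (keys/sizes are assumptions).
obs = {
    "orientations": np.arange(14.0).reshape(1, 14),
    "height": np.array([1.5]),
    "velocity": np.zeros((1, 9)),
}

flat = obs_dict_to_array(obs, n_envs=1)
print(flat.shape)  # (1, 24)
```

Keeping the leading batch axis means the result has the same (n_envs, obs_dim) layout regardless of how many envs are stacked, while the -1 lets each entry contribute however many columns it has.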

JankowskiChristopher commented 5 months ago

deepmind/finger-turn_hard also gives a similar error.

JankowskiChristopher commented 4 months ago

@adityab @danielpalen I tried using this concatenation

np.concatenate([v.flatten() for v in obs.values()])

inside self.prepare_obs(observation). It no longer raises errors (tested on cheetah_run), but the rewards are much lower. Maybe this is related to the MultiInputPolicy, and the code needs to be changed in another place as well to make it work.