rlworkgroup / garage

A toolkit for reproducible reinforcement learning research.
MIT License

PEARL error on custom environment #2294

Closed: tianyma closed this issue 3 years ago

tianyma commented 3 years ago

Hi, I am running PEARL on my custom environment, but it fails with the error below. Can you help me? I am currently using the master branch.

Each action has shape (1,) but must match the action_space Discrete(4)
  File "/lustre/S/matianyun/garage/src/garage/_dtypes.py", line 1055, in check_timestep_batch
    f'Each {field[:-1]} has shape {value[0].shape} '
  File "/lustre/S/matianyun/garage/src/garage/_dtypes.py", line 589, in __init__
    ignored_fields={'next_observations', 'episode_infos'})
  File "/lustre/S/matianyun/garage/src/garage/sampler/default_worker.py", line 174, in collect_episode
    lengths=np.asarray(lengths, dtype='i'))
  File "/lustre/S/matianyun/garage/src/garage/torch/algos/pearl.py", line 778, in rollout
    return self.collect_episode()
  File "/lustre/S/matianyun/garage/src/garage/sampler/local_sampler.py", line 160, in obtain_samples
    batch = worker.rollout()
  File "/lustre/S/matianyun/garage/src/garage/trainer.py", line 221, in obtain_episodes
    env_update=env_update)
  File "/lustre/S/matianyun/garage/src/garage/trainer.py", line 254, in obtain_samples
    eps = self.obtain_episodes(itr, batch_size, agent_update, env_update)
  File "/lustre/S/matianyun/garage/src/garage/torch/algos/pearl.py", line 440, in _obtain_samples
    self._env[self._task_idx])
  File "/lustre/S/matianyun/garage/src/garage/torch/algos/pearl.py", line 284, in train
    self._num_initial_steps, np.inf)
  File "/lustre/S/matianyun/garage/src/garage/trainer.py", line 396, in train
    average_return = self._algo.train(self)
  File "/lustre/S/matianyun/JiHuang/agent/garage/src/garage/examples/torch/jihuang_2d_nav_pearl.py", line 165, in pearl_jihuang_2d_nav
    trainer.train(n_epochs=num_epochs, batch_size=batch_size)
  File "/lustre/S/matianyun/garage/src/garage/experiment/experiment.py", line 369, in __call__
    result = self.function(ctxt, **kwargs)
  File "/lustre/S/matianyun/JiHuang/agent/garage/src/garage/examples/torch/jihuang_2d_nav_pearl.py", line 168, in <module>
    pearl_jihuang_2d_nav()

krzentner commented 3 years ago

Yeah, this looks like flattened and unflattened discrete actions are getting mixed up. What policies are you using inside ContextConditionedPolicy? The policy probably needs to output a one-hot vector instead of a discrete action index (or we need to put a fix somewhere in the core datatypes).
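
To illustrate the flattened versus unflattened distinction, here is a minimal NumPy sketch. The helper names `flatten_discrete` and `unflatten_discrete` are hypothetical stand-ins written for this example, not garage's API, and the sketch assumes a space with 4 discrete actions as in the error message.

```python
import numpy as np


def flatten_discrete(action_index, n):
    """Encode a discrete action index as a one-hot vector of length n.

    Hypothetical helper for illustration only.
    """
    one_hot = np.zeros(n, dtype=np.float32)
    one_hot[action_index] = 1.0
    return one_hot


def unflatten_discrete(one_hot):
    """Recover the discrete action index from a one-hot vector."""
    return int(np.argmax(one_hot))


# An action stored with shape (1,), the shape reported in the error message.
raw_action = np.array([2])

# Flattened (one-hot) form: shape (4,) for a 4-action discrete space.
flat_action = flatten_discrete(int(raw_action[0]), n=4)
print(flat_action.shape)

# Unflattened form: a plain integer index.
print(unflatten_discrete(flat_action))
```

The point is only that the two representations have different shapes, so code that expects one and receives the other fails a shape check.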

tianyma commented 3 years ago

Thank you for your reply. I found that I have to unflatten the action; after that I get the discrete action index.
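
One way the fix described above might look is a thin wrapper that unflattens the action before it reaches the environment. This is a sketch under assumptions: the thread does not say where the unflattening was applied, and `ToyDiscreteEnv` and `UnflattenActionWrapper` are hypothetical classes written for this example, not part of garage.

```python
import numpy as np


class ToyDiscreteEnv:
    """Hypothetical stand-in for a custom environment with 4 discrete actions."""

    n_actions = 4

    def step(self, action):
        # The environment only accepts a plain integer action index.
        if not (isinstance(action, int) and 0 <= action < self.n_actions):
            raise ValueError(
                f'expected an int in [0, {self.n_actions}), got {action!r}')
        # A real environment would return an observation, reward, and so on;
        # this toy version just echoes the index it received.
        return action


class UnflattenActionWrapper:
    """Hypothetical wrapper that converts one-hot actions to integer indices."""

    def __init__(self, env):
        self._env = env

    def step(self, flat_action):
        flat_action = np.asarray(flat_action)
        if flat_action.shape != (self._env.n_actions,):
            raise ValueError(
                f'expected shape ({self._env.n_actions},), '
                f'got {flat_action.shape}')
        # Unflatten: the position of the largest entry is the action index.
        return self._env.step(int(np.argmax(flat_action)))


env = UnflattenActionWrapper(ToyDiscreteEnv())
received = env.step(np.array([0.0, 0.0, 0.0, 1.0]))
print(received)
```

Keeping the conversion in one place means the policy can keep emitting one-hot vectors while the environment keeps receiving integer indices.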