Farama-Foundation / Minari

A standard format for offline reinforcement learning datasets, with popular reference datasets and related utilities
https://minari.farama.org

[Bug Report] Error when attempting to create a dataset for MuJoCo Hopper-v4 and v5 #246

Open jamartinh opened 2 weeks ago

jamartinh commented 2 weeks ago

Describe the bug

Error when attempting to create a dataset for MuJoCo Hopper-v4 and v5.

If I do not use the minari.DataCollector wrapper, the code works fine and trains successfully.

Code example

```python
import gymnasium as gym
import minari

env = gym.make("Hopper-v5", disable_env_checker=True)
env = gym.wrappers.RecordEpisodeStatistics(env)
env = minari.DataCollector(
    env,
    record_infos=True,
    observation_space=env.observation_space,
    action_space=env.action_space,
)

obs, _ = env.reset()
action = env.action_space.sample()
# Raises the ValueError below: the step info has keys the reset info lacks
next_obs, rewards, terminations, truncations, infos = env.step(action)
```

```
File /data1/deploy/Minari/minari/data_collector/data_collector.py:155, in DataCollector.step(self, action)
    153 if not self._record_infos:
    154     step_data["info"] = {}
--> 155 self._buffer = self._buffer.add_step_data(step_data)
    157 if step_data["termination"] or step_data["truncation"]:
    158     self._storage.update_episodes([self._buffer])

File /data1/deploy/Minari/minari/data_collector/episode_buffer.py:60, in EpisodeBuffer.add_step_data(self, step_data)
     58     infos = jtu.tree_map(lambda x: [x], step_data["info"])
     59 else:
---> 60     infos = jtu.tree_map(_append, step_data["info"], self.infos)
     62 self.rewards.append(step_data["reward"])
     63 self.terminations.append(step_data["termination"])

File /data1/conda/envs/python3.12/lib/python3.12/site-packages/jax/_src/tree_util.py:342, in tree_map(f, tree, is_leaf, *rest)
    340 """Alias of :func:`jax.tree.map`."""
    341 leaves, treedef = tree_flatten(tree, is_leaf)
--> 342 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    343 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

ValueError: Dict key mismatch; expected keys: ['reward_ctrl', 'reward_forward', 'reward_survive', 'x_position', 'x_velocity', 'z_distance_from_origin']; dict: {'x_position': np.float64(0.0012007819486893048), 'z_distance_from_origin': np.float64(-0.00042011979112821507)}.
```

System Info

Gymnasium installed from the latest main branch of the repo (version 1.0.0).

Additional context

JupyterLab


jamartinh commented 2 weeks ago

It seems that the data_collector is not prepared for infos that have different keys in env.reset() and env.step().
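A minimal pure-Python sketch of why this fails, assuming (as the traceback suggests) that the episode buffer keeps infos as one list per key and appends each new info key by key, which only works when both dicts have the same structure. `append_info` is a hypothetical helper, not Minari code; the reset values come from the error message, and the step values are placeholders.

```python
from typing import Any, Dict, List


def append_info(buffer: Dict[str, List[Any]], info: Dict[str, Any]) -> Dict[str, List[Any]]:
    """Append one info dict to a dict-of-lists buffer.

    Like a tree_map over two dicts, this only works when the new info has
    exactly the same keys as the buffer; otherwise it raises ValueError.
    """
    if not buffer:
        # First info of the episode (from reset): start one list per key.
        return {key: [value] for key, value in info.items()}
    if set(info) != set(buffer):
        raise ValueError(
            f"Dict key mismatch; buffer keys: {sorted(buffer)}; new info keys: {sorted(info)}"
        )
    return {key: values + [info[key]] for key, values in buffer.items()}


# Reset info as reported in the error message: only 2 keys.
reset_info = {"x_position": 0.0012007819486893048, "z_distance_from_origin": -0.00042011979112821507}
# Step info keys as reported in the error message, with placeholder values.
step_info = dict.fromkeys(
    ["reward_ctrl", "reward_forward", "reward_survive",
     "x_position", "x_velocity", "z_distance_from_origin"],
    0.0,
)

buffer = append_info({}, reset_info)
try:
    append_info(buffer, step_info)
except ValueError as err:
    print(err)  # the step info has keys the reset info did not have
```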

younik commented 2 weeks ago

> it seems that the data_collector is not prepared for "infos" with different keys for env.reset() and env.step()

Yes, this is the case; we don't support it at the moment (see https://github.com/Farama-Foundation/Minari/issues/191#issuecomment-1991898560).

Either disable info recording, or define a StepDataCallback subclass that always returns the same info keys.
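A minimal sketch of the second option, assuming the fixed key set is the one Hopper-v5 reports from step() in the error message. `pad_info` and `HOPPER_INFO_KEYS` are hypothetical names; in practice this logic would be applied to the info inside a StepDataCallback subclass, and the exact callback signature should be checked against the Minari documentation for the installed version.

```python
import math
from typing import Any, Dict, Iterable

# Info keys that Hopper-v5 reports from step(), per the error message.
HOPPER_INFO_KEYS = (
    "reward_ctrl",
    "reward_forward",
    "reward_survive",
    "x_position",
    "x_velocity",
    "z_distance_from_origin",
)


def pad_info(info: Dict[str, Any], keys: Iterable[str], fill: Any = math.nan) -> Dict[str, Any]:
    """Return a new dict with exactly `keys`.

    Keys missing from `info` get `fill`; keys not in `keys` are dropped,
    so the reset info and every step info share one structure.
    """
    return {key: info.get(key, fill) for key in keys}


reset_info = {"x_position": 0.0012007819486893048, "z_distance_from_origin": -0.00042011979112821507}
padded = pad_info(reset_info, HOPPER_INFO_KEYS)
print(padded)
```

NaN is used as the fill value so padded float entries stay distinguishable from real zeros. Keys outside the list, such as the "episode" entry that RecordEpisodeStatistics adds at the final step, are silently dropped by this helper.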

jamartinh commented 2 weeks ago

OK, you are right: after checking again, env.reset() and env.step() do return different info keys.

jamartinh commented 2 weeks ago

I vote for addressing this, since arbitrary wrappers can add info entries at arbitrary steps.

For instance, RecordEpisodeStatistics adds info entries at the final step, and there is also the "final_observation" entry in info.

Perhaps infos could simply be saved to storage as pickled binary strings, which Minari then unpickles on load?

In the end, infos are a list of dicts, which looks like the typical schemaless, unstructured data of document databases (e.g. JSON).
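The pickling idea can be sketched with the standard library alone. `encode_infos` and `decode_infos` are hypothetical helpers, not Minari API, and the info dicts below use placeholder values.

```python
import pickle
from typing import Any, Dict, List


def encode_infos(infos: List[Dict[str, Any]]) -> List[bytes]:
    """Serialize each step's info dict, whatever its keys, to a binary blob."""
    return [pickle.dumps(info, protocol=pickle.HIGHEST_PROTOCOL) for info in infos]


def decode_infos(blobs: List[bytes]) -> List[Dict[str, Any]]:
    """Restore the original info dicts from their pickled blobs."""
    return [pickle.loads(blob) for blob in blobs]


# One episode with heterogeneous infos (placeholder values).
episode_infos = [
    {"x_position": 0.0, "z_distance_from_origin": 0.0},             # reset
    {"x_position": 0.1, "x_velocity": 1.0, "reward_survive": 1.0},  # step
    {"x_position": 0.2, "episode": {"r": 1.0, "l": 2}},             # final step
]
blobs = encode_infos(episode_infos)
restored = decode_infos(blobs)
```

The trade-off is that each blob is opaque to the storage backend: values cannot be read as typed columns or queried without unpickling, the format is Python-specific, and unpickling data from an untrusted source can execute arbitrary code.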

younik commented 2 weeks ago

> Perhaps infos could simply be saved to storage as pickled binary strings, which Minari then unpickles on load?
>
> In the end, infos are a list of dicts, which looks like the typical schemaless, unstructured data of document databases (e.g. JSON).

We are also using a tabular structure, with PyArrow.

As I see it, the way to add this feature is to pad the data; we already do something similar inside the PyArrow storage. However, I am working on higher-priority items at the moment, so I don't have the bandwidth to work on this in the near future. A PR is welcome, but it is not a straightforward one. Otherwise, I suggest using a StepData callback.
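One possible reading of "pad the data" is sketched below: an episode's list of heterogeneous info dicts becomes equal-length columns, plus a mask recording which entries were actually present. This is an illustrative assumption, not how Minari's PyArrow storage pads internally; `pad_to_columns` is a hypothetical name and the values are placeholders.

```python
from typing import Any, Dict, List, Tuple


def pad_to_columns(
    infos: List[Dict[str, Any]], fill: Any = None
) -> Tuple[Dict[str, List[Any]], Dict[str, List[bool]]]:
    """Turn a list of per-step info dicts into equal-length columns.

    Every key seen at any step becomes a column (in first-seen order);
    steps missing that key get `fill`, and `present` marks real entries.
    """
    keys = list(dict.fromkeys(key for info in infos for key in info))
    columns = {key: [info.get(key, fill) for info in infos] for key in keys}
    present = {key: [key in info for info in infos] for key in keys}
    return columns, present


episode_infos = [
    {"x_position": 0.0, "z_distance_from_origin": 0.0},  # reset (placeholders)
    {"x_position": 0.1, "x_velocity": 1.0},              # step (placeholders)
]
columns, present = pad_to_columns(episode_infos)
print(columns)
print(present)
```

Using None as the fill value suits a tabular backend: PyArrow, for example, converts None to a null when building an array, so a missing key becomes a null entry rather than a fake number.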

jamartinh commented 2 weeks ago

OK, I will think it through and propose a solution, possibly as a PR. I consider usability a priority as well.

jamartinh commented 1 week ago

I will need help identifying the pieces of code and documentation that need to be updated.

The new definition is that infos is a list of heterogeneous dicts, one list per episode.

It is currently working for both the Arrow and HDF5 backends. Next, I have to make it backward compatible with the current version, which accepts a homogeneous infos dict instead of a list of dicts.

jamartinh commented 1 week ago

Infos can be saved either as a dictionary of np.arrays or as a list of arbitrary dictionaries, selected with the optional infos_format parameter. The default is the dictionary format (infos_format=None or "dict"); for example, to store infos as a list:

```python
from minari import DataCollector
import gymnasium as gym

env = gym.make('CartPole-v1')
# infos_format="list" stores each episode's infos as a list of per-step dicts
env = DataCollector(env, record_infos=True, infos_format="list")
```