TuragaLab / flybody

MuJoCo fruit fly body model and reinforcement learning tasks
Apache License 2.0

Mismatched action spec size in `flight_imitation` and `vision_guided_flight` envs #3

Closed · talmo closed this 5 months ago

talmo commented 5 months ago

Hey guys,

I'm trying to run the docs/fly-env-examples.ipynb notebook and I'm running into issues with the flight imitation environment.

Everything up to here works fine and I can render from the camera:

env = flight_imitation(wpg_pattern_path,
                       ref_flight_path,
                       terminal_com_dist=float('inf'))
env = wrappers.SinglePrecisionWrapper(env)
env = wrappers.CanonicalSpecWrapper(env, clip=True)

_ = env.reset()
pixels = env.physics.render(camera_id=1, **render_kwargs)
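
(Not from the notebook, just how I'm checking the frame; PIL is an arbitrary display choice here:)

# Quick look at the rendered frame; PIL is only used here for display.
import PIL.Image
PIL.Image.fromarray(pixels)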

Running the next cell throws an error, though:

random_policy = get_random_policy(env.action_spec())

frames = rollout_and_render(env, random_policy, run_until_termination=True,
                            camera_ids=1, **render_kwargs)
display_video(frames)

Raises:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[5], line 3
      1 random_policy = get_random_policy(env.action_spec())
----> 3 frames = rollout_and_render(env, random_policy, run_until_termination=True,
      4                             camera_ids=1, **render_kwargs)
      5 display_video(frames)

File ~/flybody/flybody/utils.py:33, in rollout_and_render(env, policy, n_steps, run_until_termination, camera_ids, **render_kwargs)
     31     frames.append(frame)
     32     action = policy(timestep.observation)
---> 33     timestep = env.step(action)
     34 return frames

File ~/conda/envs/flybody/lib/python3.10/site-packages/acme/wrappers/canonical_spec.py:52, in CanonicalSpecWrapper.step(self, action)
     50 def step(self, action: types.NestedArray) -> dm_env.TimeStep:
     51   scaled_action = _scale_nested_action(action, self._action_spec, self._clip)
---> 52   return self._environment.step(scaled_action)

File ~/conda/envs/flybody/lib/python3.10/site-packages/acme/wrappers/single_precision.py:37, in SinglePrecisionWrapper.step(self, action)
     36 def step(self, action) -> dm_env.TimeStep:
---> 37   return self._convert_timestep(self._environment.step(action))

File ~/conda/envs/flybody/lib/python3.10/site-packages/dm_control/composer/environment.py:416, in Environment.step(self, action)
    413   self._reset_next_step = False
    414   return self.reset()
--> 416 self._hooks.before_step(self._physics_proxy, action, self._random_state)
    417 self._observation_updater.prepare_for_next_control_step()
    419 try:

File ~/conda/envs/flybody/lib/python3.10/site-packages/dm_control/composer/environment.py:137, in _EnvironmentHooks.before_step(self, physics, action, random_state)
    134 if self._episode_step_count % _STEPS_LOGGING_INTERVAL == 0:
    135   logging.info('The current episode has been running for %d steps.',
    136                self._episode_step_count)
--> 137 self._task.before_step(physics, action, random_state)
    138 for entity_hook in self._before_step.entity_hooks:
    139   entity_hook(physics, random_state)

File ~/flybody/flybody/tasks/flight_imitation.py:164, in FlightImitationWBPG.before_step(self, physics, action, random_state)
    160 self._ghost.set_pose(physics, ghost_qpos[:3], ghost_qpos[3:])
    161 self._ghost.set_velocity(physics, self._ref_qvel[step, :3],
    162                          self._ref_qvel[step, 3:])
--> 164 super().before_step(physics, action, random_state)

File ~/flybody/flybody/tasks/base.py:201, in FruitFlyTask.before_step(self, physics, action, random_state)
    199 if self._action_corruptor is not None:
    200     action = self._action_corruptor(action, random_state)
--> 201 self._walker.apply_action(physics, action, random_state)

File ~/flybody/flybody/fruitfly/fruitfly.py:502, in FruitFly.apply_action(***failed resolving arguments***)
    500     return
    501 # Update previous action.
--> 502 self._prev_action[:] = action
    503 # Apply MuJoCo actions.
    504 ctrl = np.zeros(physics.model.nu)

ValueError: could not broadcast input array from shape (12,) into shape (11,)

Inspecting the env, it looks like the action space should be 12-dimensional:

env.action_spec()
BoundedArray(shape=(12,), dtype=dtype('float32'), name='head_abduct\thead_twist\thead\twing_yaw_left\twing_roll_left\twing_pitch_left\twing_yaw_right\twing_roll_right\twing_pitch_right\tabdomen_abduct\tabdomen\tuser_0', minimum=[-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.], maximum=[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.])
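
Poking at the internals a bit (attribute names taken from the traceback above, so this is just a probe of implementation details), the walker's previous-action buffer seems to be the 11-dimensional side of the mismatch:

# Probe based on the traceback; assumes the acme wrappers forward attribute
# access to the underlying composer Environment.
_ = env.reset()
print(env.task._walker._prev_action.shape)  # -> (11,), vs. the (12,) spec above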

Further confirming this, it looks like the provided checkpoint from figshare also has a 12-dimensional action space:

flight_policy = tf.saved_model.load(flight_policy_path)
# Wrap policy to work with non-batched observations at test time.
flight_policy = TestPolicyWrapper(flight_policy)

zero_obs = {k: tf.zeros(v.shape, dtype=v.dtype) for k, v in env.observation_spec().items()}
act = flight_policy(zero_obs)
act.shape
(12,)
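
So the policy output and the advertised spec agree with each other, which points at the mismatch being on the environment side:

# Both report 12 dimensions; the (11,) buffer in the traceback is the odd one out.
assert act.shape == env.action_spec().shape  # (12,) == (12,)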

Here's a more self-contained version without the `rollout_and_render` helper:

env = flight_imitation(wpg_pattern_path,
                       ref_flight_path,
                       terminal_com_dist=float('inf'))
env = wrappers.SinglePrecisionWrapper(env)
env = wrappers.CanonicalSpecWrapper(env, clip=True)

flight_policy = tf.saved_model.load(flight_policy_path)
# Wrap policy to work with non-batched observations at test time.
flight_policy = TestPolicyWrapper(flight_policy)

print("env.observation_spec():")
print(env.observation_spec())
print()

print("env.action_spec():")
print(env.action_spec())
print()

timestep = env.reset()
print("timestep.observation from env.reset():")
print({k: v.shape for k, v in timestep.observation.items()})
print()

action = flight_policy(timestep.observation)
print("flight_policy(timestep.observation) action output:")
print(action.shape)
print()

timestep = env.step(action)  # throws error
env.observation_spec():
OrderedDict([('walker/accelerometer', Array(shape=(3,), dtype=dtype('float32'), name='walker/accelerometer')), ('walker/actuator_activation', Array(shape=(0,), dtype=dtype('float32'), name='walker/actuator_activation')), ('walker/gyro', Array(shape=(3,), dtype=dtype('float32'), name='walker/gyro')), ('walker/joints_pos', Array(shape=(25,), dtype=dtype('float32'), name='walker/joints_pos')), ('walker/joints_vel', Array(shape=(25,), dtype=dtype('float32'), name='walker/joints_vel')), ('walker/velocimeter', Array(shape=(3,), dtype=dtype('float32'), name='walker/velocimeter')), ('walker/world_zaxis', Array(shape=(3,), dtype=dtype('float32'), name='walker/world_zaxis')), ('walker/ref_displacement', Array(shape=(6, 3), dtype=dtype('float32'), name='walker/ref_displacement')), ('walker/ref_root_quat', Array(shape=(6, 4), dtype=dtype('float32'), name='walker/ref_root_quat'))])

env.action_spec():
BoundedArray(shape=(12,), dtype=dtype('float32'), name='head_abduct\thead_twist\thead\twing_yaw_left\twing_roll_left\twing_pitch_left\twing_yaw_right\twing_roll_right\twing_pitch_right\tabdomen_abduct\tabdomen\tuser_0', minimum=[-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.], maximum=[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.])

timestep.observation from env.reset():
{'walker/accelerometer': (3,), 'walker/actuator_activation': (0,), 'walker/gyro': (3,), 'walker/joints_pos': (25,), 'walker/joints_vel': (25,), 'walker/velocimeter': (3,), 'walker/world_zaxis': (3,), 'walker/ref_displacement': (6, 3), 'walker/ref_root_quat': (6, 4)}

flight_policy(timestep.observation) action output:
(12,)

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[22], line 31
     28 print(action.shape)
     29 print()
---> 31 timestep = env.step(action)  # throws error

[... same call stack as in the first traceback above: CanonicalSpecWrapper.step → SinglePrecisionWrapper.step → Environment.step → _EnvironmentHooks.before_step → FlightImitationWBPG.before_step → FruitFlyTask.before_step → FruitFly.apply_action ...]

ValueError: could not broadcast input array from shape (12,) into shape (11,)

FWIW, the `walk_imitation` env works fine, but the `vision_guided_flight` env crashes with the same error.

talmo commented 5 months ago

FYI: This is indeed fixed in 7e38acc and I can now run that entire notebook :)
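
For anyone else hitting this: with the fix, the pattern that failed above runs cleanly, e.g. (same env construction and helpers as in the snippets above):

# Post-fix sanity check: the spec and the walker's action buffer now agree,
# so a short random-policy rollout steps without the broadcast error.
timestep = env.reset()
random_policy = get_random_policy(env.action_spec())
for _ in range(100):
    timestep = env.step(random_policy(timestep.observation))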

vaxenburg commented 5 months ago

Thanks again for catching this @talmo. Right, this should be fixed now.