Further confirming this, it looks like the provided checkpoint from figshare also has a 12-dimensional action space:
flight_policy = tf.saved_model.load(flight_policy_path)
# Wrap policy to work with non-batched observations at test time.
flight_policy = TestPolicyWrapper(flight_policy)
zero_obs = {k: tf.zeros(v.shape, dtype=v.dtype) for k, v in env.observation_spec().items()}
act = flight_policy(zero_obs)
act.shape
(12,)
Here's a more self-contained version without the rollout_and_render wrapper:
env = flight_imitation(wpg_pattern_path,
                       ref_flight_path,
                       terminal_com_dist=float('inf'))
env = wrappers.SinglePrecisionWrapper(env)
env = wrappers.CanonicalSpecWrapper(env, clip=True)
flight_policy = tf.saved_model.load(flight_policy_path)
# Wrap policy to work with non-batched observations at test time.
flight_policy = TestPolicyWrapper(flight_policy)
print("env.observation_spec():")
print(env.observation_spec())
print()
print("env.action_spec():")
print(env.action_spec())
print()
timestep = env.reset()
print("timestep.observation from env.reset():")
print({k: v.shape for k, v in timestep.observation.items()})
print()
action = flight_policy(timestep.observation)
print("flight_policy(timestep.observation) action output:")
print(action.shape)
print()
timestep = env.step(action) # throws error
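For anyone trying to reproduce this, a small shape check before calling env.step makes a mismatch explicit instead of surfacing deep inside the step. This is only a minimal sketch: shape_mismatch is a hypothetical helper, not part of flybody or dm_control. It assumes only that both the action and the spec expose a .shape attribute that converts to a tuple. The shapes in the stand-ins are arbitrary and are not taken from the actual error.

```python
from types import SimpleNamespace


def shape_mismatch(action, spec):
    """Return a message if action.shape differs from spec.shape, else None.

    Hypothetical helper: relies only on both arguments having a `.shape`.
    """
    got = tuple(action.shape)
    want = tuple(spec.shape)
    if got == want:
        return None
    return f"action shape {got} does not match action_spec shape {want}"


# Stand-ins for illustration only, with arbitrary shapes. In the snippet
# above these would be `action` and `env.action_spec()`.
fake_action = SimpleNamespace(shape=(3,))
fake_spec = SimpleNamespace(shape=(2,))
print(shape_mismatch(fake_action, fake_spec))
```

In the snippet above, the call would be shape_mismatch(action, env.action_spec()), placed right before env.step(action).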
Hey guys,
I'm trying to run the docs/fly-env-examples.ipynb notebook and having issues with the flight imitation environment.
Everything up to here works fine and I can render the camera:
Running the next cell throws an error though:
Raises:
Inspecting the env, it looks like it should be 12-dimensional:
FWIW, the walk_imitation env works fine, but the vision_guided_flight env crashes with the same error.