google-deepmind / open_x_embodiment


How to format data compatible with the RT-1 TensorFlow code #22

Closed jrabary closed 7 months ago

jrabary commented 7 months ago

Hi, I'm trying to use the RT-1 TensorFlow code base for training. Following the Colab notebook for data loading, I ended up with the following code:

import tensorflow as tf
from tf_agents.trajectories import trajectory

SEQUENCE_LENGTH = 3  # timesteps per sequence (matches the time dim in the error below)

def to_trajectory(step):
    # Map an Open X-Embodiment episode to a tf_agents Trajectory.
    # The dataset carries no rewards, so use zero reward and unit discount.
    return trajectory.from_episode(
        observation=step["observation"],
        action=step["action"],
        policy_info=(),
        reward=tf.zeros((SEQUENCE_LENGTH,)),
        discount=tf.ones((SEQUENCE_LENGTH,)),
    )

# Same as the data loading in the Open X-Embodiment Colab.
trajectory_dataset = trajectory_transform.transform_episodic_rlds_dataset(
    episodic_dataset)

trajectory_iter = iter(trajectory_dataset)
t = to_trajectory(next(trajectory_iter))

But when I use t to train the RT-1 sequence agent, I get a tensor spec mismatch:

Received a mix of batched and unbatched Tensors, or Tensors are not compatible with Specs.  num_outer_dims: 2.
Saw tensor_shapes:
   Trajectory(
{'action': TensorSpecStruct(
{'gripper_closedness_action': TensorShape([1, 3]),
 'rotation_delta': TensorShape([1, 3, 3]),
 'terminate_episode': TensorShape([1, 3, 2]),
 'world_vector': TensorShape([1, 3, 3])}),
 'discount': TensorShape([1, 3]),
 'next_step_type': TensorShape([1, 3]),
 'observation': TensorSpecStruct(
{'image': TensorShape([1, 3, 256, 320, 3]),
 'natural_language_embedding': TensorShape([1, 3, 512])}),
 'policy_info': (),
 'reward': TensorShape([1, 3]),
 'step_type': TensorShape([1, 3])})
And spec_shapes:
   Trajectory(
{'action': TensorSpecStruct(
{'gripper_closedness_action': TensorShape([1]),
 'rotation_delta': TensorShape([3]),
 'terminate_episode': TensorShape([2]),
 'world_vector': TensorShape([3])}),
 'discount': TensorShape([]),
 'next_step_type': TensorShape([]),
 'observation': TensorSpecStruct(
{'image': TensorShape([256, 320, 3]),
 'natural_language_embedding': TensorShape([512])}),
 'policy_info': (),
 'reward': TensorShape([]),
 'step_type': TensorShape([])})
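
For reference, printing the produced shapes next to the spec the agent validates against makes the mismatch easier to spot (a minimal debugging sketch; agent stands for the RT-1 SequenceAgent instance, which is not shown above):

# Hypothetical debugging aid: compare actual shapes against the agent's spec.
print(tf.nest.map_structure(lambda x: x.shape, t))  # shapes actually produced
print(agent.training_data_spec)                     # shapes the agent expects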

Any ideas?

jrabary commented 7 months ago

It was a shape problem.
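
For anyone hitting the same error: with num_outer_dims: 2 the agent expects every tensor to be shaped [batch, time] + spec shape, and in the dumps above the only field that breaks this is gripper_closedness_action, received as [1, 3] where its [1] spec calls for [1, 3, 1]. A minimal sketch of the kind of fix that suggests, assuming the inner axis is simply missing on that action (fix_gripper_shape is a hypothetical helper, applied before to_trajectory):

def fix_gripper_shape(step):
    # Hypothetical helper: restore the missing inner axis so the gripper
    # action is [..., 1], matching its TensorShape([1]) spec once the two
    # outer (batch, time) dimensions are added.
    step["action"]["gripper_closedness_action"] = tf.expand_dims(
        step["action"]["gripper_closedness_action"], axis=-1)
    return step

trajectory_dataset = trajectory_dataset.map(fix_gripper_shape)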