google-research / robotics_transformer


How to use loaded checkpoint? #18

Closed destroy314 closed 1 year ago

destroy314 commented 1 year ago

I'm trying to pass a time_step into the policy for inference, but it raises the following error:

InvalidArgumentError       (note: full exception trace is shown but execution is paused at: _run_module_as_main)
Graph execution error:

Detected at node 'transformer_network_1/Reshape_7' defined at (most recent call last):
Node: 'transformer_network_1/Reshape_7'
Input to reshape is a tensor with 0 values, but the requested shape has 1
     [[{{node transformer_network_1/Reshape_7}}]] [Op:__inference_restored_function_body_123572]
  File "/home/robot/miniconda3/envs/rt1/lib/python3.10/site-packages/tensorflow/python/eager/execute.py", line 53, in quick_execute

The code I am using is attached below, and I would like to know how to use the provided model. The structure of the TimeStep object follows the time_step_spec of the loaded policy.

import numpy as np
import tensorflow as tf
from tf_agents.policies import py_tf_eager_policy
from tf_agents.trajectories.time_step import StepType, TimeStep

HEIGHT = 256
WIDTH = 320

policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    model_path="src/rt1_ros/src/robotics_transformer/trained_checkpoints/rt1main",
    load_specs_from_pbtxt=True,
    use_tf_function=True,
)
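
# (Optional check, not part of the original snippet: printing the loaded specs
# shows the observation keys, shapes, and dtypes the policy expects, which is
# where the TimeStep structure below comes from.)
print(policy.time_step_spec)
print(policy.action_spec)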

def create_time_step(seed=0, t=StepType.FIRST):  # defaults so the call below works without arguments
    np.random.seed(seed)
    observations = {
        "image": tf.constant(
            0,
            shape=[1, HEIGHT, WIDTH, 3],
            dtype=tf.dtypes.uint8,
        ),
        "natural_language_embedding": tf.constant(
            0.0,
            shape=(
                1,
                512,
            ),
            dtype=tf.dtypes.float32,
        ),
        "natural_language_instruction": tf.constant(
            "",
            shape=(1,),
            dtype=tf.dtypes.string,
        ),
        "workspace_bounds": tf.constant(
            0.0,
            shape=(1, 3, 3),
            dtype=tf.dtypes.float32,
        ),
        "base_pose_tool_reached": tf.constant(
            0.0,
            shape=(
                1,
                7,
            ),
            dtype=tf.dtypes.float32,
        ),
        "gripper_closed": tf.constant(
            0.0,
            shape=(
                1,
                1,
            ),
            dtype=tf.dtypes.float32,
        ),
        "gripper_closedness_commanded": tf.constant(
            0.0,
            shape=(
                1,
                1,
            ),
            dtype=tf.dtypes.float32,
        ),
        "height_to_bottom": tf.constant(
            0.0,
            shape=(
                1,
                1,
            ),
            dtype=tf.dtypes.float32,
        ),
        "orientation_box": tf.constant(
            0.0,
            shape=(1, 2, 3),
            dtype=tf.dtypes.float32,
        ),
        "orientation_start": tf.constant(
            0.0,
            shape=(
                1,
                4,
            ),
            dtype=tf.dtypes.float32,
        ),
        "robot_orientation_positions_box": tf.constant(
            0.0,
            shape=(1, 3, 3),
            dtype=tf.dtypes.float32,
        ),
        "rotation_delta_to_go": tf.constant(
            0.0,
            shape=(
                1,
                3,
            ),
            dtype=tf.dtypes.float32,
        ),
        "src_rotation": tf.constant(
            0.0,
            shape=(
                1,
                4,
            ),
            dtype=tf.dtypes.float32,
        ),
        "vector_to_go": tf.constant(
            0.0,
            shape=(
                1,
                3,
            ),
            dtype=tf.dtypes.float32,
        ),
    }
    time_step = TimeStep(
        observation=observations,
        reward=tf.constant(
            0.0,
            shape=(1,),
            dtype=tf.dtypes.float32,
        ),
        discount=tf.constant(
            0.0,
            shape=(1,),
            dtype=tf.dtypes.float32,
        ),
        step_type=tf.constant(
            t,
            shape=(1,),
            dtype=tf.dtypes.int32,
        ),
    )
    return time_step

time_step = create_time_step()
policy_state = policy.get_initial_state()

action_step = policy.action(time_step, policy_state)
destroy314 commented 1 year ago

After loading the model with policy = tf.saved_model.load("path") and initializing the state with policy_state = policy.get_initial_state(1), the previous code works for me. I'm using TensorFlow 2.13 and tf-agents 0.17.
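
A minimal sketch of that fix, reusing the create_time_step helper and checkpoint path from the snippet above (carrying the returned state into the next call follows the usual tf-agents pattern and is an assumption, not something stated in the comment):

import tensorflow as tf

policy = tf.saved_model.load(
    "src/rt1_ros/src/robotics_transformer/trained_checkpoints/rt1main"
)
# The batch size must be passed explicitly; an empty initial state is likely
# what produced the Reshape error in the original report.
policy_state = policy.get_initial_state(1)

time_step = create_time_step()
action_step = policy.action(time_step, policy_state)
policy_state = action_step.state  # carry the recurrent state into the next call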