google-research / robotics_transformer

Apache License 2.0

Value Error occurred while calling the restored 'action' function from saved model #22

Closed lipzh5 closed 10 months ago

lipzh5 commented 10 months ago

Hi there!

I was trying to use the RT-1 saved model but encountered a 'Value Error' as shown below. Did I do something wrong with the arguments (i.e., time_step and policy_state)? I am new to TensorFlow and any suggestion will be appreciated!

Many thanks in advance.


ValueError Traceback (most recent call last)
in <cell line: 3>()
      1 time_step = create_time_step()
      2 init_state = policy.get_initial_state(1)
----> 3 policy_state = policy.action(time_step, init_state)

1 frames
/usr/local/lib/python3.10/dist-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
    151 except Exception as e:
    152 filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153 raise e.with_traceback(filtered_tb) from None
    154 finally:
    155 del filtered_tb

/usr/local/lib/python3.10/dist-packages/tensorflow/python/saved_model/function_deserialization.py in restored_function_body(*args, **kwargs)
    299 "Option {}:\n {}\n Keyword arguments: {}".format(
    300 index + 1, _pretty_format_positional(positional), keyword))
--> 301 raise ValueError(
    302 "Could not find matching concrete function to call loaded from the "
    303 f"SavedModel. Got:\n {_pretty_format_positional(args)}\n Keyword "

ValueError: Could not find matching concrete function to call loaded from the SavedModel. Got:
  Positional arguments (3 total):

  Option 1:
    Positional arguments (3 total):

  Option 2:
    Positional arguments (3 total):

The code I ran is shown below:

import os

import tensorflow as tf
from tf_agents.policies import py_tf_eager_policy

model_path = os.path.join(os.getcwd(), '../trained_checkpoints/rt1main')
print('model path', model_path)
policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    model_path=model_path,
    load_specs_from_pbtxt=True,
    use_tf_function=True)
def create_time_step(seed=123,t=0):
    import numpy as np
    from tf_agents.trajectories.time_step import StepType, TimeStep
    HEIGHT, WIDTH = 256, 320
    np.random.seed(seed)
    observations = {
        'orientation_start': tf.constant(0.0, shape=(1, 4), dtype=tf.dtypes.float32),
        'natural_language_instruction': tf.constant('', shape=(1,), dtype=tf.dtypes.string),
        'rotation_delta_to_go': tf.constant(0.0, shape=(1, 3), dtype=tf.dtypes.float32),
        'natural_language_embedding': tf.constant(0.0, shape=(1, 512), dtype=tf.dtypes.float32),
        'vector_to_go': tf.constant(0.0, shape=(1, 3), dtype=tf.dtypes.float32),
        'height_to_bottom': tf.constant(0.0, shape=(1, 1), dtype=tf.dtypes.float32),
        'src_rotation': tf.constant(0.0, shape=(1, 4), dtype=tf.dtypes.float32),
        'image': tf.constant(0, shape=(1, 256, 320, 3), dtype=tf.dtypes.uint8),
        'gripper_closed': tf.constant(0.0, shape=(1, 1), dtype=tf.dtypes.float32),
        'base_pose_tool_reached': tf.constant(0.0, shape=(1, 7), dtype=tf.dtypes.float32),
        'orientation_box': tf.constant(0.0, shape=(1, 2, 3), dtype=tf.dtypes.float32),
        'robot_orientation_position_box': tf.constant(0.0, shape=(1, 3, 3), dtype=tf.dtypes.float32),
        'workspace_bounds': tf.constant(0.0, shape=(1, 3, 3), dtype=tf.dtypes.float32),
        'gripper_closedness_commanded': tf.constant(0.0, shape=(1, 1), dtype=tf.dtypes.float32),
    }
    time_step = TimeStep(
        observation=observations,
        reward=tf.constant(
            0.0, shape=(1, ), dtype=tf.dtypes.float32,
        ),
        discount=tf.constant(
            0.0, shape=(1, ), dtype=tf.dtypes.float32,
        ),
        step_type=tf.constant(
            t, shape=(1, ), dtype=tf.dtypes.int32,
        ),
    )
    return time_step
time_step = create_time_step()
init_state = policy.get_initial_state(1)
policy_state = policy.action(time_step, init_state)

lipzh5 commented 10 months ago

Oh, silly me! Something was wrong with the 'time_step' argument. When I passed a different TimeStep instance to the restored 'action' function, the ValueError disappeared. Here is the code snippet for constructing a valid TimeStep instance from the saved policy specs:

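import os

import numpy as np
import tensorflow as tf
from tf_agents.policies import policy_saver
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

# Inspect the spec pbtxt file saved alongside the model
# (model_path is the same directory used to load the policy).
spec_path = os.path.join(model_path, policy_saver.POLICY_SPECS_PBTXT)
policy_specs = policy_saver.specs_from_collect_data_spec(
    tensor_spec.from_pbtxt_file(spec_path))
time_step_spec = policy_specs['time_step_spec']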
def gen_obs(time_step_spec=time_step_spec):
    obs = {}
    obs_spec = time_step_spec.observation # dict
    for spec_name, spec in obs_spec.items():
        # Note: use the dict key spec_name here, not spec.name.
        if spec.dtype in (np.dtype(np.float32), np.dtype(np.uint8), np.dtype(np.int32)):
            obs[spec_name] = tf.constant(1, shape=(1,)+spec.shape, dtype=spec.dtype, name=spec_name)
        else:
            obs[spec_name] = tf.constant('', shape=(1,)+spec.shape, dtype=tf.string, name=spec_name)
    return obs
obs = gen_obs()
time_step = ts.restart(obs, batch_size=1)
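
For anyone else who hits the same "Could not find matching concrete function" error: it can help to compare the hand-built observation dict with the spec, key by key, before calling 'action'. Below is a minimal sketch of such a check, not part of tf_agents. The helper name diff_against_spec, the Arr stand-in type, and the example shapes are all hypothetical and only illustrate the idea. It assumes each value exposes .shape and .dtype, that tuple(x.shape) yields plain integers, and that observations carry one leading batch dimension while spec shapes do not.

```python
from collections import namedtuple

# Hypothetical stand-in for a tensor or a spec: anything with .shape and .dtype.
Arr = namedtuple("Arr", ["shape", "dtype"])


def diff_against_spec(obs, obs_spec, batch_dims=1):
    """Return a list of readable mismatches between obs and obs_spec."""
    problems = []
    # Keys that the spec requires but the observation lacks.
    for key in sorted(set(obs_spec) - set(obs)):
        problems.append(f"missing key: {key}")
    # Keys in the observation that the spec does not know about.
    for key in sorted(set(obs) - set(obs_spec)):
        problems.append(f"unexpected key: {key}")
    # For shared keys, compare shape (without batch dims) and dtype.
    for key in sorted(set(obs) & set(obs_spec)):
        got_shape = tuple(obs[key].shape)[batch_dims:]
        want_shape = tuple(obs_spec[key].shape)
        if got_shape != want_shape:
            problems.append(f"{key}: shape {got_shape} != spec {want_shape}")
        if obs[key].dtype != obs_spec[key].dtype:
            problems.append(
                f"{key}: dtype {obs[key].dtype} != spec {obs_spec[key].dtype}")
    return problems


# Illustrative values only; these are NOT the real RT-1 specs.
spec = {
    "image": Arr((256, 320, 3), "uint8"),
    "gripper_closed": Arr((1,), "float32"),
}
obs = {
    "image": Arr((1, 256, 320, 3), "uint8"),
    "gripper_closed": Arr((1, 1), "int32"),
    "extra": Arr((1,), "float32"),
}
for line in diff_against_spec(obs, spec):
    print(line)
```

With real objects, one would presumably pass the observation dict and time_step_spec.observation instead of the stand-ins. An empty result only means keys, shapes, and dtypes agree; it does not guarantee that the restored function will accept the input.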