google-research / robotics_transformer

Apache License 2.0

What is the correct way to restore the checkpoints? #11

Closed: ka2hyeon closed this issue 1 year ago

ka2hyeon commented 1 year ago

When I run

import tensorflow as tf

tf.saved_model.load('./robotics_transformer/trained_checkpoints/rt1main')

I get the following error:

IndexError: Read less bytes than requested

All my attempts to restore the checkpoint you provided have failed. For example, the following code did not work for me either:

from tf_agents.utils.common import Checkpointer

# `agent` is the SequenceAgent built as in sequence_agent_test.py.
checkpointer = Checkpointer(
    ckpt_dir='./robotics_transformer/trained_checkpoints/rt1main',
    agent=agent,
)
checkpointer.initialize_or_restore()

What is the correct way to restore the checkpoints?

yaolug commented 1 year ago

The second way should be correct. How is agent defined?

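Could you also try:

import tensorflow as tf
from tf_agents.utils.common import Checkpointer

checkpointer = Checkpointer(
    ckpt_dir='./robotics_transformer/trained_checkpoints/rt1main',
    agent=agent,
    # tf.train.Checkpoint only accepts trackable objects, so wrap the step in a tf.Variable.
    global_step=tf.Variable(424760, dtype=tf.int64),
)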
ka2hyeon commented 1 year ago

Thank you for the reply. However, specifying the global_step argument didn't work either.

The agent is an instance of SequenceAgent created in ./sequence_agent_test.py. I suspect this error comes from a mismatch between the parameters of the RT-1 model declared in your test code and those of the checkpointed model. I need the parameters used to train the checkpoint, but they cannot be inferred from your code alone.

For example, which values of the following parameters were used for the checkpointed model: num_layers (initialized to 1), layer_size (initialized to 4096), num_heads (initialized to 8), and feed_forward_size (initialized to 512)? Some parameters can be inferred from your paper (e.g., time_sequence_length=6, vocab_size=256), but I cannot determine the ones the paper does not mention.
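One way to test the mismatch hypothesis directly is to compare the variable shapes stored in the checkpoint with those of the locally built agent. Below is a minimal sketch; shape_mismatches is a hypothetical helper, not part of TF-Agents or this repo. It compares shape multisets rather than names, because object-based checkpoint keys and tf.Variable names follow different conventions, and the substrings it uses to skip the object-graph entry and optimizer slots are assumptions about TF's checkpoint key format.

```python
from collections import Counter
from typing import Iterable, Sequence, Tuple

# Assumed markers for non-weight entries in object-based TF checkpoints.
SKIP_MARKERS = ("_CHECKPOINTABLE_OBJECT_GRAPH", ".OPTIMIZER_SLOT")


def shape_mismatches(
    ckpt_vars: Iterable[Tuple[str, Sequence[int]]],
    model_shapes: Iterable[Sequence[int]],
    skip_markers: Sequence[str] = SKIP_MARKERS,
) -> Tuple[Counter, Counter]:
    """Compare the multisets of non-scalar variable shapes.

    Returns (only_in_checkpoint, only_in_model); both are empty when the
    two architectures line up shape-for-shape.
    """
    ckpt = Counter(
        tuple(shape)
        for name, shape in ckpt_vars
        # Scalars (step counters, bookkeeping) and skipped keys are ignored.
        if shape and not any(marker in name for marker in skip_markers)
    )
    model = Counter(tuple(shape) for shape in model_shapes if shape)
    # Counter subtraction keeps only positive counts on each side.
    return ckpt - model, model - ckpt
```

In practice the inputs would presumably be tf.train.list_variables('./robotics_transformer/trained_checkpoints/rt1main') and [v.shape.as_list() for v in agent.variables], once the agent's networks have been built. Shapes that appear on only one side hint at hyperparameters such as layer_size or feed_forward_size; a pure count difference hints at num_layers.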

yaolug commented 1 year ago

See configs/transformer_mixin.gin for parameters.

Also, for the SavedModel, the following could work:

from tf_agents.policies import py_tf_eager_policy

policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    model_path='./robotics_transformer/trained_checkpoints/rt1main',
    load_specs_from_pbtxt=True,
    use_tf_function=True,
)
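A loaded policy is then stepped with get_initial_state and action, threading the returned state between calls. The sketch below shows that loop under stated assumptions: load_rt1_policy, zero_time_step and rollout are hypothetical helper names, the loading call mirrors the snippet above, and the all-zeros observation is only a placeholder built with the standard TF-Agents spec utilities, not the authors' official inference example.

```python
def load_rt1_policy(model_path: str):
    """Load the exported SavedModel as a TF-Agents py policy (same call as above)."""
    # Imported lazily so rollout() can be reused without TF installed.
    from tf_agents.policies import py_tf_eager_policy

    return py_tf_eager_policy.SavedModelPyTFEagerPolicy(
        model_path=model_path,
        load_specs_from_pbtxt=True,
        use_tf_function=True,
    )


def zero_time_step(policy):
    """Build a placeholder time step whose observation is all zeros, matching the spec."""
    import numpy as np
    from tf_agents.specs import tensor_spec
    from tf_agents.trajectories import time_step as ts

    observation = tensor_spec.zero_spec_nest(
        tensor_spec.from_spec(policy.time_step_spec.observation))
    return ts.transition(observation, reward=np.zeros((), dtype=np.float32))


def rollout(policy, time_steps):
    """Run the policy over a sequence of time steps, threading its recurrent state."""
    state = policy.get_initial_state(batch_size=1)
    actions = []
    for time_step in time_steps:
        # PolicyStep is a namedtuple of (action, state, info).
        action, state, _ = policy.action(time_step, state)
        actions.append(action)
    return actions, state
```

For example, with policy = load_rt1_policy('./robotics_transformer/trained_checkpoints/rt1main'), calling rollout(policy, [zero_time_step(policy)]) should exercise the whole inference path; real use would substitute actual observations for the zero placeholder.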
ka2hyeon commented 1 year ago

All my problems are solved! Thank you for the kind help. tf.saved_model.load('...') didn't work, but py_tf_eager_policy.SavedModelPyTFEagerPolicy(...) did. I had also missed transformer_mixin.gin; with it, I can now restore all the correct parameters.
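To cross-check the values in transformer_mixin.gin against the test code, the bindings can also be read programmatically. This is a minimal sketch assuming simple single-line `Configurable.param = value` bindings; read_gin_bindings is a hypothetical helper, not part of gin-config, and it skips includes and multi-line values. The official route (importing the repo's modules, then gin.parse_config_file) is more robust.

```python
import ast
import re
from typing import Any, Dict

# Matches "some.module.Configurable.param = value", optionally scoped as "scope/...".
_BINDING = re.compile(r"^\s*([\w./]+)\.(\w+)\s*=\s*(.+?)\s*$")


def _coerce(value: str) -> Any:
    """Turn literals like 8, 0.1, True or 'x' into Python values; keep the rest raw."""
    try:
        return ast.literal_eval(value)
    except (ValueError, SyntaxError):
        return value  # e.g. @references and %MACROS stay as strings


def read_gin_bindings(text: str, configurable: str) -> Dict[str, Any]:
    """Collect single-line bindings of `configurable` from gin-style config text."""
    params: Dict[str, Any] = {}
    for raw_line in text.splitlines():
        line = raw_line.split("#", 1)[0]  # drop comments
        match = _BINDING.match(line)
        if not match:
            continue
        target, param, value = match.groups()
        # Compare only the last dotted or scoped component of the target.
        if re.split(r"[./]", target)[-1] == configurable:
            params[param] = _coerce(value)
    return params
```

Usage would be read_gin_bindings(open('configs/transformer_mixin.gin').read(), name), with name set to the configurable's class name exactly as it appears in that file.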

AliBuildsAI commented 1 year ago

@ka2hyeon could you please share your environment (the output of pip list)? I tried both methods with both TF1 and TF2 and got errors.

ka2hyeon commented 1 year ago

@AliBuildsAI In my environment, I am using the following TensorFlow-related packages:

tensorboard 2.8.0
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.1
tensorflow 2.8.2
tensorflow-addons 0.17.1
tensorflow-datasets 4.6.0
tensorflow-estimator 2.8.0
tensorflow-hub 0.12.0
tensorflow-io-gcs-filesystem 0.26.0
tensorflow-metadata 1.9.0
tensorflow-model-optimization 0.7.2
tensorflow-probability 0.16.0
tensorflow-text 2.8.2
tf-agents 0.12.0
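To compare another environment against this one, a small script can report which installed versions differ. This is a minimal sketch using importlib.metadata (Python 3.8+); KNOWN_GOOD is a subset copied from ka2hyeon's list and version_mismatches is a hypothetical helper. Matching these versions is not confirmed to fix the errors; it only rules out one variable.

```python
from importlib import metadata
from typing import Callable, Dict

# Subset of the versions listed above for the working environment.
KNOWN_GOOD = {
    "tensorflow": "2.8.2",
    "tensorflow-addons": "0.17.1",
    "tensorflow-probability": "0.16.0",
    "tensorflow-text": "2.8.2",
    "tf-agents": "0.12.0",
}


def version_mismatches(
    expected: Dict[str, str],
    lookup: Callable[[str], str] = metadata.version,
) -> Dict[str, str]:
    """Return {package: installed version or 'missing'} for packages that differ."""
    mismatches = {}
    for package, wanted in expected.items():
        try:
            installed = lookup(package)
        except metadata.PackageNotFoundError:
            installed = "missing"
        if installed != wanted:
            mismatches[package] = installed
    return mismatches


if __name__ == "__main__":
    for package, installed in version_mismatches(KNOWN_GOOD).items():
        print(f"{package}: installed {installed}, expected {KNOWN_GOOD[package]}")
```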

oym1994 commented 1 year ago


Hi, could you please provide the complete code for py_tf_eager_policy.SavedModelPyTFEagerPolicy(...)? Thank you!