google-deepmind / open_x_embodiment

Apache License 2.0

Accessing tf.Module object for RT-1-X checkpoint #10

Open victor-agilerobots opened 8 months ago

victor-agilerobots commented 8 months ago

Hi, thanks for sharing the datasets and the models.

I am trying to get the RT-1-X model (rt_1_x_tf_trained_for_002272480_step.zip) working in PyTorch by transpiling the TensorFlow model with Ivy. I can load the model into a TF Agents object and run inference, but after inspecting the TF Agents and policy objects I can't find a way to extract a tf.Module object containing the RT-1-X model, which is what I need in order to transpile to PyTorch with Ivy. I can see the layers and weights only as tf.Variable objects and tensors, not as a tf.Module. Is there a way to get this object?

Also, what is the reason to use TF Agents in this case? The model is trained using imitation learning, not reinforcement learning, so TF Agents wouldn't be needed, right?

Thank you very much in advance.

Kind regards,

Victor Montesinos
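When only flat tf.Variable objects are visible, as described above, one possible fallback is to rebuild the layer hierarchy from the variable names so that weights can be mapped by hand into another framework. The following is a minimal sketch of that idea, not the project's method. It assumes the (name, value) pairs have already been collected from the loaded policy, and the variable names and shapes shown are hypothetical; the real RT-1-X checkpoint will use different ones.

```python
from typing import Any, Dict, Iterable, Tuple


def nest_variables(flat: Iterable[Tuple[str, Any]], sep: str = "/") -> Dict[str, Any]:
    """Group (name, value) pairs into a nested dict keyed by name components."""
    tree: Dict[str, Any] = {}
    for name, value in flat:
        # Drop a TensorFlow-style ":0" output suffix if present.
        name = name.split(":")[0]
        parts = name.split(sep)
        node = tree
        # Walk (and create) one dict level per scope component.
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        # The last component is the variable itself (e.g. kernel, bias).
        node[parts[-1]] = value
    return tree


# Hypothetical names and shapes, used only for illustration.
flat_vars = [
    ("encoder/dense/kernel:0", (512, 256)),
    ("encoder/dense/bias:0", (256,)),
    ("action_head/kernel:0", (256, 11)),
]

tree = nest_variables(flat_vars)
print(sorted(tree))
print(tree["encoder"]["dense"])
```

The nested dict mirrors the scope structure of the names, which makes it easier to see which weights belong to which layer before copying them into a hand-written PyTorch module. This sketch assumes no variable name is also a prefix of another variable name.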

quanvuong commented 7 months ago

I am not quite sure how to obtain the tf.Module object from the checkpoint. Are you transpiling the model to PyTorch in order to fine-tune it? We are planning to release a JAX version of the RT-1-X model, along with the code needed to fine-tune it.

victor-agilerobots commented 7 months ago

Thank you for your response.

Yes, we are trying to transpile the RT-1 and RT-1-X models to PyTorch for fine-tuning, mainly because (a) we were not able to fine-tune the model after loading the checkpoint, and (b) we are not able to access the tf.Module object within the checkpoint (for any of the three available RT-1 models or for the RT-1-X model).

Will the JAX version (weights and fine-tuning code) be open source? Do you have an idea of when it will be released?

Thank you very much again for sharing your work.

Guanbin-Huang commented 3 months ago

@victor-agilerobots Did you manage to transpile it from TensorFlow to PyTorch? If so, I would also like a PyTorch version, which is preferred for quantization.

psanketi commented 3 months ago

Hi, we recently open-sourced the RT-1-X JAX model along with sample training code. Could you please take a look at the Training Example Colab? The RT-1-X JAX checkpoint, which can be loaded by the Flax checkpoint loader in rt1_inference_example.py, can be downloaded with gsutil -m cp -r gs://gdm-robotics-open-x-embodiment/open_x_embodiment_and_rt_x_oss/rt_1_x_jax . (the trailing "." is the local destination directory).
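For anyone scripting the download, the gsutil invocation from the comment above can be assembled in Python. This is a small sketch rather than part of the repository: the helper name build_download_command and its destination parameter are made up for illustration, and gsutil needs a local destination argument, which defaults here to the current directory. The command is only built and printed, not executed.

```python
import shlex
from typing import List

# Bucket path quoted from the comment above.
CHECKPOINT_URI = (
    "gs://gdm-robotics-open-x-embodiment/"
    "open_x_embodiment_and_rt_x_oss/rt_1_x_jax"
)


def build_download_command(destination: str = ".") -> List[str]:
    """Return the gsutil argument list for copying the checkpoint directory.

    `destination` is a hypothetical parameter: the local directory to copy into.
    """
    return ["gsutil", "-m", "cp", "-r", CHECKPOINT_URI, destination]


cmd = build_download_command()
# Print a shell-safe version of the command for inspection.
print(shlex.join(cmd))

# To actually run the download (requires gsutil to be installed):
# import subprocess
# subprocess.run(cmd, check=True)
```

Passing the arguments as a list avoids shell-quoting issues if the destination path contains spaces.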