google-deepmind / open_x_embodiment


Overlaying inference results with training data using jax checkpoint #57

Closed zhiyuan-zhang0206 closed 4 days ago

zhiyuan-zhang0206 commented 2 months ago

Hi there, Thanks for your incredible work and open-sourcing. Could you give an example of running inference and overlaying the inference results using the jax checkpoint?

I recently ran into a problem trying to use the jax checkpoint to reproduce the inference example that uses tensorflow. When I run the tf inference example from Minimal_example_for_running_inference_using_RT_1_X_TF_using_tensorflow_datasets.ipynb on the first episode of bridge and visualize only the [15:] frames, I get results that look pretty good (attached image: inference_tf). But when I run rt1_inference_example.py on the same data, I get results that are drastically different from the tf checkpoint (attached image: inference_jax). The tf checkpoint comes without code (I think), so it's really hard for me to debug. So I am wondering if you could release an example of running inference with the jax checkpoint and overlaying the results with training data? That would really help. Thanks in advance!
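
For concreteness, here is a rough sketch of the kind of overlay I have in mind, assuming per-step world_vector predictions have already been collected from both checkpoints (all file names below are placeholders, not files from the repo):

import numpy as np
import matplotlib.pyplot as plt

# Placeholder arrays of shape (num_steps, 3): predicted world_vector from each
# checkpoint, plus the ground-truth actions taken from the training episode.
pred_tf = np.load('pred_world_vector_tf.npy')
pred_jax = np.load('pred_world_vector_jax.npy')
gt = np.load('gt_world_vector.npy')

fig, axes = plt.subplots(3, 1, sharex=True)
for i, name in enumerate(['x', 'y', 'z']):
    axes[i].plot(gt[:, i], label='training data')
    axes[i].plot(pred_tf[:, i], label='tf checkpoint')
    axes[i].plot(pred_jax[:, i], label='jax checkpoint')
    axes[i].set_ylabel(name)
axes[0].legend()
plt.show()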

jonathansalzer commented 6 days ago

Hi, I'm currently also struggling with this issue. Have you made any progress on this?

zhiyuan-zhang0206 commented 5 days ago

> Hi, I'm currently also struggling with this issue. Have you made any progress on this?

Hi Jonathan, vaguely, I remember it was because of the image preprocessing code: the image should be scaled and padded before being sent to the network. Try searching for "_add_crop_augmentation". I did a lot of work after resolving this issue, so I cannot remember it clearly.

Also, if you are going to finetune the model, I encountered a few other bugs here and there. A few things I learned: in Jax you can use a flag to turn off jit compilation; you can set os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" to limit tensorflow's GPU use (or keep tf on CPU only: tf.config.set_visible_devices([], 'GPU')); and you may need os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.40' to lower jax's GPU memory reservation and avoid OOM errors. Also, the jax code is written quite differently from the recommended style on the jax official site, so you may need some more jax knowledge. Good luck.
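
A minimal sketch of that setup, assuming the environment variables are set before tensorflow and jax touch the GPU (jax.disable_jit is one way to turn off jit while debugging; the memory fraction is just an example value):

import os

# Let tensorflow grow GPU memory on demand instead of grabbing it all up front.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
# Cap how much GPU memory jax pre-allocates, to avoid OOM when tf and jax share a GPU.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".40"

import tensorflow as tf
import jax

# Alternatively, hide the GPU from tensorflow entirely (it is only needed for data loading).
tf.config.set_visible_devices([], 'GPU')

# Disable jit so intermediate shapes and values are easy to inspect while debugging.
with jax.disable_jit():
    pass  # run the inference / debugging code here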

jonathansalzer commented 5 days ago

Hello,

Thank you very much for the detailed answer; I will have another look at the image processing.

Do you remember how you handled the language embedding? The model requires it to be of shape (15, 512) to match the sequence length of 15; however, if I simply repeat the language embedding 15 times, I get the same output for every inference step. I worked around this by using arrays of ones and replacing only the last element with the language embedding, like this:

embeddings = [jnp.ones((512,)) for _ in range(15)]
embeddings[-1] = self.language_embedding

observation = {
  'image': img_array,
  'natural_language_embedding': np.array(embeddings),
}

The inference output is no longer constant and "looks organic", but I am not sure whether this is the right approach.

zhiyuan-zhang0206 commented 5 days ago

> Hello,
>
> Thank you very much for the detailed answer; I will have another look at the image processing.
>
> Do you remember how you handled the language embedding? The model requires it to be of shape (15, 512) to match the sequence length of 15; however, if I simply repeat the language embedding 15 times, I get the same output for every inference step. I worked around this by using arrays of ones and replacing only the last element with the language embedding, like this:
>
> embeddings = [jnp.ones((512,)) for _ in range(15)]
> embeddings[-1] = self.language_embedding
>
> observation = {
>   'image': img_array,
>   'natural_language_embedding': np.array(embeddings),
> }
>
> The inference output is no longer constant and "looks organic", but I am not sure whether this is the right approach.

Say the observation is 3 images and you need to pad with 12 zero images. Then the language embedding should be [zeros] * 12 + [embedding] * 3: a zero image comes with a zero language embedding. This is quite reasonable if you think about how FiLM works.
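
As a minimal sketch of that padding (here the 512-d embedding is a placeholder and the window length of 15 comes from the model's sequence length):

import numpy as np

window = 15
num_real_frames = 3                  # frames observed so far
language_embedding = np.zeros(512)   # placeholder for the real sentence embedding

# Zero embeddings line up with the zero-padded images; the real embedding is
# repeated only for the frames that actually exist.
padded = [np.zeros(512)] * (window - num_real_frames) + [language_embedding] * num_real_frames
natural_language_embedding = np.stack(padded, axis=0)  # shape (15, 512)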

Below is the code that I used to manage the language embedding. I used it in a mujoco environment; at each step I call the function to update my language embedding.

def natural_language_embedding_step(self):
    # Keep a sliding window of the last 15 per-step embeddings, initialised with
    # zeros so that the zero-padded frames get zero embeddings.
    if not hasattr(self, 'language_instruction_embedding_list'):
        self.language_instruction_embedding_list = [
            np.zeros_like(self.language_instruction_embedding) for _ in range(15)
        ]
    self.language_instruction_embedding_list.append(copy.deepcopy(self.language_instruction_embedding))
    self.language_instruction_embedding_list.pop(0)
    # Returns an array of shape (15, embedding_dim) matching the image window.
    return copy.deepcopy(np.stack(self.language_instruction_embedding_list, axis=0))

jonathansalzer commented 4 days ago

Thank you for the hint and for sharing your code. I have tried something similar already, but unfortunately it does not work for me. This is what I think should work:

  def run_mock_inference(self, img, step_index):
    image = preprocess_image(img)
    emb = self.language_embedding

    if len(self.img_queue) == 0:
      self.img_queue.extend([jnp.zeros((300, 300, 3)) for _ in range(15)])
      self.emb_queue.extend([jnp.zeros((512,)) for _ in range(15)])

    self.img_queue.append(image)
    self.emb_queue.append(emb)
    # queues implemented as deques with maxlen 15, so no need to remove

    img_array = np.array(self.img_queue)
    emb_array = np.array(self.emb_queue)

    observation = {
      'image': img_array,
      'natural_language_embedding': emb_array,
    }

    return self.policy.action(observation)

The language embedding is calculated like this:

import tensorflow_hub as hub

self.embed = hub.load('https://tfhub.dev/google/universal-sentence-encoder-large/5')
self.language_embedding = self.embed(["Place the can to the left of the pot."])[0]

When running inference on images from the bridge dataset, I get constant outputs for all 38 inference steps, as shown here (attached image: constantinference).

The only way for me to get inference output that is NOT constant is to use jnp.ones((512,)) in at least one position of the language embedding array. I have tried many combinations, and this has been the only way to get the output to change. I have also experimented some more with the image processing, unfortunately without any success. I think I am doing something fundamentally wrong with the language embedding, but I cannot track down my mistake. I have been struggling with this issue for a while now, so any ideas would be greatly appreciated.

zhiyuan-zhang0206 commented 4 days ago

Maybe you should first try to use the jax ckpt to replicate the behaviour of the tf ckpt on the example data.

jonathansalzer commented 4 days ago

That's what I'm doing here; this is all with the bridge example data, using the published Jax checkpoint. The rest of the code is exactly as in rt1_inference.py.

zhiyuan-zhang0206 commented 4 days ago

Is your image preprocessing like this?

def prepare_image(image):
    image = tf.image.resize_with_pad(
        image,
        target_width=320,
        target_height=256,
    )
    image = tf.image.resize(image, size=(300, 300))
    return image.numpy()

And did you pass in the correct range for world_vector and rotation when initializing the policy? I browsed through the code again, and I seem to remember that the example code did not handle this normalization.
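
As an illustration only (the real parameter names and ranges should be checked against the policy code used by rt1_inference_example.py; the values below are placeholders), the normalization issue amounts to mapping an action predicted in a normalized range back into the range the dataset was collected with:

import numpy as np

def rescale_world_vector(world_vector, low=-0.05, high=0.05):
    # Hypothetical helper: map a world_vector predicted in [-1, 1]
    # back into the dataset's metric range [low, high].
    world_vector = np.clip(world_vector, -1.0, 1.0)
    return low + (world_vector + 1.0) * (high - low) / 2.0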

jonathansalzer commented 4 days ago

I finally caught the error: in my image preprocessing I normalized the image by dividing it by 255, but this is already done further down the pipeline, so it was accidentally done twice. I now get the expected results with the bridge data. Thank you very much for all your help, I really appreciate it.
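
For reference, a sketch of the preprocessing without the extra division, essentially the prepare_image from above with no /255 step (the scaling already happens further down the pipeline):

import tensorflow as tf

def prepare_image(image):
    # Resize with padding, then to the model's 300x300 input.
    # Do NOT divide by 255 here; that scaling is applied later in the pipeline.
    image = tf.image.resize_with_pad(image, target_width=320, target_height=256)
    image = tf.image.resize(image, size=(300, 300))
    return image.numpy()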

zhiyuan-zhang0206 commented 4 days ago

You are welcome. I think you should always visualize the intermediate results to debug faster. I remember I dumped all the inputs to the network while debugging, which is how I found the image padding bug.
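
Something simple like this is enough for the dumping (the file layout is arbitrary), so the observations fed to the tf and jax pipelines can be diffed offline:

import os
import numpy as np

def dump_observation(observation, step, out_dir='debug_dumps'):
    # Save every input to the network as an .npy file, one per key and step.
    os.makedirs(out_dir, exist_ok=True)
    for key, value in observation.items():
        np.save(f'{out_dir}/{step:04d}_{key}.npy', np.asarray(value))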