octo-models / octo

Octo is a transformer-based robot policy trained on a diverse mix of 800k robot trajectories.
https://octo-models.github.io/
MIT License

How to iterate deterministically with RLDS? What am I doing wrong? #131

Open peter-mitrano-bg opened 1 month ago

peter-mitrano-bg commented 1 month ago

I'm unable to iterate deterministically. I've looked at this page and I've tried to set up the config so it iterates deterministically but the order is still random.

Here's my test script.

```python
#!/usr/bin/env python
from pathlib import Path

import tensorflow as tf

from octo.data.dataset import make_single_dataset
from octo.utils.spec import ModuleSpec


def main():
    tf.config.set_visible_devices([], "GPU")

    dataset_kwargs = {
        "skip_norm": True,
        "shuffle": False,
        "batch_size": 1,
        "name": "bg_p2_dataset",
        "data_dir": str(Path("~/tensorflow_datasets").expanduser()),
        "image_obs_keys": {"primary": "scene", "wrist": "left"},
        "proprio_obs_key": "state",
        "language_key": "language_instruction",
        "action_proprio_normalization_type": "normal",
        "action_normalization_mask": [True, True, True, True, True, True, False],
        "standardize_fn": ModuleSpec.create(
            "octo.data.oxe.oxe_standardization_transforms:bg_p2_dataset_transform",
        ),
    }
    traj_transform_kwargs = {
        "window_size": 2,
        "action_horizon": 4,
        "goal_relabeling_strategy": None,
        "task_augment_strategy": "delete_task_conditioning",
        "task_augment_kwargs": {
            "keep_image_prob": 0,
        },
        "num_parallel_calls": 1,
    }
    primary_img_augment_kwargs = {
        "random_resized_crop": {"scale": [0.8, 1.0], "ratio": [0.9, 1.1]},
        "random_brightness": [0.1],
        "random_contrast": [0.9, 1.1],
        "random_saturation": [0.9, 1.1],
        "random_hue": [0.05],
        "augment_order": [
            "random_resized_crop",
            "random_brightness",
            "random_contrast",
            "random_saturation",
            "random_hue",
        ],
    }
    wrist_img_augment_kwargs = {
        "random_brightness": [0.1],
        "random_contrast": [0.9, 1.1],
        "random_saturation": [0.9, 1.1],
        "random_hue": [0.05],
        "augment_order": [
            "random_brightness",
            "random_contrast",
            "random_saturation",
            "random_hue",
        ],
    }
    frame_transform_kwargs = {
        "crop_size": {
            "wrist": (int(480 / 2) - 128, int(640 / 2) - 128, 256, 256),
        },
        "resize_size": {
            "primary": (256, 256),
            "wrist": (128, 128),
        },
        "image_augment_kwargs": {
            "primary": primary_img_augment_kwargs,
            "wrist": wrist_img_augment_kwargs,
        },
        "num_parallel_calls": 1,
    }
    dataset_kwargs["num_parallel_calls"] = 1
    dataset_kwargs["num_parallel_reads"] = 1

    dataset, full_dataset_name = make_single_dataset(
        dataset_kwargs,
        traj_transform_kwargs=traj_transform_kwargs,
        frame_transform_kwargs=frame_transform_kwargs,
        train=False,
    )
    for k in range(5):
        print(f"--- {k} ---")
        train_data_iter = dataset.iterator()
        for batch in train_data_iter:
            print(batch["action"].sum())


if __name__ == "__main__":
    main()
```

And the output:

```
WARNING:tensorflow:AutoGraph could not transform and will run it as-is. Cause: Unable to locate the source code of . Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.experimental.do_not_convert. Original error: could not get source code To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING:absl:Dataset normalization turned off -- set skip_norm=False to apply normalization.
--- 0 ---
184.0448
177.0222
198.7096
189.02994
183.54613
--- 1 ---
184.0448
177.0222
198.7096
189.02994
183.54613
--- 2 ---
184.0448
198.7096
177.0222
189.02994
183.54613
```
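For context, a tf.data pipeline can produce elements in a nondeterministic order for two separate reasons: a shuffle that re-draws its permutation every time a fresh iterator is created (`reshuffle_each_iteration`, which defaults to on), or parallel reads/interleaves racing to produce elements. Since `shuffle` is already off in the script above, the swapped pair in epoch 2 looks like the latter, but here is a pure-Python sketch of the reshuffling semantics for anyone debugging the same thing (illustrative only, not Octo/dlimp code):

```python
import random


def epoch_order(items, seed, reshuffle_each_iteration, num_epochs=3):
    """Yield the item order for each epoch, mimicking the semantics of
    tf.data's shuffle(seed=..., reshuffle_each_iteration=...)."""
    rng = random.Random(seed)
    base = list(items)
    fixed = None  # the permutation frozen when reshuffling is off
    for _ in range(num_epochs):
        if reshuffle_each_iteration or fixed is None:
            order = base[:]
            rng.shuffle(order)
            if fixed is None:
                fixed = order
        else:
            order = fixed
        yield list(order)


# Reshuffling off: every epoch repeats the first (seeded) permutation.
epochs = list(epoch_order(range(5), seed=0, reshuffle_each_iteration=False))
assert all(e == epochs[0] for e in epochs)
```

In tf.data terms the corresponding knobs are `ds.shuffle(buffer_size, seed=..., reshuffle_each_iteration=False)` for the shuffle side, and `options = tf.data.Options(); options.deterministic = True; ds = ds.with_options(options)` for the parallelism side.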
peter-mitrano-bg commented 1 month ago

The motivation here is to have the validation set iteration be deterministic so the plots and such can be compared. Right now the freeze_trajs argument doesn't actually seem to work, so you can't easily compare the visualizations over the course of training.

peter-mitrano-bg commented 2 weeks ago

Just coming back to say I've given this another go and still can't figure it out. Being able to iterate deterministically is also pretty essential for making comparisons between different visualization and analysis scripts, so I'm still very interested in help figuring this out!

peter-mitrano-bg commented 1 week ago

It seems like applying https://github.com/kvablack/dlimp/pull/4 solves this problem! But I found it's only a partial implementation.
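In case it helps anyone verify a fix on their end, a small generic harness can check determinism by creating two fresh iterators and comparing their first few elements (this is plain Python, not Octo's API; you'd wrap something like the `dataset.iterator()` call from the script above, reducing each batch to a comparable value such as `batch["action"].sum()`):

```python
import itertools
import random


def assert_deterministic(make_iterator, n=5):
    """Create two fresh iterators and check that their first n elements match."""
    a = list(itertools.islice(make_iterator(), n))
    b = list(itertools.islice(make_iterator(), n))
    assert a == b, f"iteration order differs:\n  {a}\n  {b}"


# Example: a seeded shuffle is reproducible across fresh iterators.
def seeded_iter():
    items = list(range(10))
    random.Random(42).shuffle(items)
    return iter(items)


assert_deterministic(seeded_iter)
```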