octo-models / octo

Octo is a transformer-based robot policy trained on a diverse mix of 800k robot trajectories.
https://octo-models.github.io/
MIT License

Random behavior after fine-tuning Octo on a new robot #29

Open apirrone opened 6 months ago

apirrone commented 6 months ago

Hi !

I managed to convert my (very small, for now) training dataset to RLDS format, then ran 02_finetune_new_observation_action.py after modifying mainly the following to match my robot's action space and image key name:

logging.info("Loading finetuning dataset...")
dataset = make_single_dataset(
    dataset_kwargs=dict(
        name="ReachyDataset",
        data_dir=FLAGS.data_dir,
        image_obs_keys={"primary": "head"},
        state_obs_keys=["state"],
        language_key="language_instruction",
        action_proprio_normalization_type=NormalizationType.NORMAL,
        absolute_action_mask=[True] * 19,
    ),
[...]

The training curves look like this: [screenshot of training curves, 2024-01-07]

My fine-tuning data was collected in the real world (20 examples of ~10 seconds each, with me teleoperating the robot to grab a small wooden cube). Below is an example of one recorded episode: [video]

Then I tried running inference, inspired by 01_inference_pretrained.ipynb. My code looks like this:

import time

import jax
import numpy as np

from octo.model.octo_model import OctoModel

model = OctoModel.load_pretrained("/data1/apirrone/octo/trainings/")
task = model.create_tasks(texts=["Grab the wooden cube"])

while True:
    observation = {
        "image_primary": get_image(),    # current camera frame
        "proprio": get_state(),          # current joint state
        "pad_mask": np.array([[True]]),  # window size 1: the single obs is valid
    }
    actions = model.sample_actions(observation, task, rng=jax.random.PRNGKey(0))[0]
    # Un-normalize the predicted actions using the dataset statistics
    actions = (
        actions * model.dataset_statistics["action"]["std"]
        + model.dataset_statistics["action"]["mean"]
    )
    # Execute the whole predicted action chunk
    for step in actions:
        set_joints(np.array(step))
        time.sleep(0.02)

    time.sleep(0.5)

I am running it in our simulator for now, as I don't currently have access to the real robot. This is what I get (I actually fixed the head in place in this example):

https://github.com/octo-models/octo/assets/6552564/881dee71-308d-4d97-8118-5ba6480657e8

Is this expected with such a small fine-tuning dataset? Would having a lot more data solve it on its own?

I also have a few warnings that are a little concerning :)

INFO:root:No task inputs matching image_primary were found. Replacing with zero padding.
WARNING:root:'observations' is missing items compared to example_batch: {'pad_mask_dict/timestep', 'timestep', 'pad_mask_dict/image_primary', 'pad_mask_dict/proprio'}
WARNING:root:No pad_mask_dict found. Nothing will be masked.

Overall I don't think I understand properly what I am doing yet. For example:

- I don't know what the role of pad_mask is in the observations dict in this context
- What is pad_mask_dict?
- When calling model.sample_actions([...]), does this return the full trajectory that is supposed to "solve" the task? Or should I sample it multiple times with new observations?

Any help would be greatly appreciated!

Thanks,

Antoine

PeterMitrano commented 6 months ago

My experience trying to use this has been similar to yours, including some confusion about the padding & masking. I haven't figured out whether those are warnings we need to fix or whether we can just ignore them. FWIW, I believe their example code with the ALOHA sim cube dataset prints the same warnings. Please respond here if you learn more!

Also, please take a look at my issue https://github.com/octo-models/octo/issues/28; I'm curious whether you had the same issues.

kvablack commented 6 months ago

Hi @apirrone,

Thanks for your interest in Octo! It's difficult to diagnose the problem outright, but in general, I would expect real2sim transfer like you're doing to be quite difficult. While Octo helps a lot with generalization, major changes in the scene, camera angle, object appearance, etc. are very difficult to adapt to. It looks like your data has some significant egocentric motion as well, which is going to make things even harder. As a first step, I would recommend finetuning on some simulation data with a fixed camera angle and then evaluating with the exact same simulated setup.

What I can do, at least, is answer your questions/explain the warnings! The pad_mask/pad_mask_dict stuff is admittedly confusing, but was unfortunately necessary to make Octo as flexible as possible.

I don't know what the role of pad_mask is in the observations dict in this context

pad_mask indicates which observations should be attended to, which is important when using multiple timesteps of observation history. Octo was trained with a history window size of 2, meaning the model can predict an action using both the current observation and the previous observation. However, at the very beginning of the trajectory, there is no previous observation, so we need to set pad_mask=False at the corresponding index. The great thing about transformers is that they're flexible --- at inference time, you don't have to use the previous observation if you don't want to! In your case, and in 01_inference_pretrained, we do inference with a window size of 1, meaning the model makes its prediction based on the current observation alone. So pad_mask should always just be [True], indicating that the one and only observation in the window should be attended to.
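Concretely, for a window size of 2 the mask marks which of the two history slots hold a real observation (a minimal numpy sketch, using the shapes discussed in this thread):

import numpy as np

# Window layout: [previous obs, current obs], shape (batch, window).
pad_mask_first_step = np.array([[False, True]])   # no previous obs yet
pad_mask_later_steps = np.array([[True, True]])   # full history available

# Window size 1, as in 01_inference_pretrained: the single obs is always valid.
pad_mask_window_1 = np.array([[True]])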

What is pad_mask_dict?

While pad_mask indicates which observations should be attended to on a timestep level, pad_mask_dict indicates which elements of the observation should be attended to within a single timestep. For example, for datasets without language labels, pad_mask_dict["language_instruction"] is set to False. For datasets without a wrist camera, pad_mask_dict["image_wrist"] is set to False. For convenience, if a key is missing from the observation dict, it is equivalent to setting pad_mask_dict to False for that key. (For example, if you omit image_wrist from the observation dict, even though Octo was pretrained with an image_wrist slot it will still work correctly by simply ignoring that slot.)
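As a minimal sketch, an observation dict with an explicit pad_mask_dict might look like this, assuming the nested layout implied by the warning messages above (image and state are placeholder arrays):

import numpy as np

observation = {
    "image_primary": image,           # placeholder, (batch, window, H, W, 3)
    "proprio": state,                 # placeholder, (batch, window, proprio_dim)
    "pad_mask": np.array([[True]]),   # timestep-level mask (window size 1)
    "pad_mask_dict": {                # element-level masks within each timestep
        "image_primary": np.array([[True]]),
        "proprio": np.array([[True]]),
    },
}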

Also, when calling model.sample_actions([...]), does this return the full trajectory that is supposed to "solve" the task? Or should I sample it multiple times with new observations?

This is up to you! It all depends on the "action chunking" size. Octo was pretrained with an action chunking size of 4, meaning it predicts the next 4 actions at once. You can choose to execute all these actions before sampling new ones, or only execute the first action before sampling new ones (also known as receding horizon control). You can also do something more advanced like temporal ensembling (see utils/gym_wrappers.py).
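Continuing from the inference snippet above, receding horizon control with a chunk size of 4 would look roughly like this (get_observation and execute are hypothetical stand-ins for your own robot interface):

while True:
    observation = get_observation()  # hypothetical helper returning the obs dict
    action_chunk = model.sample_actions(observation, task, rng=jax.random.PRNGKey(0))[0]
    execute(action_chunk[0])         # execute only the first of the 4 predicted actions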

In 02_finetune_new_observation_action.py, the action chunking size is increased dramatically to 50. I believe they found that this worked best for ALOHA, which uses very high-frequency control. If you're changing action spaces anyway, you can pick an action chunking size that works best for you, although 4 is probably a reasonable default.
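For reference, a sketch of how the fine-tuning script swaps in a new action head with a larger chunk size; the L1ActionHead and pred_horizon names reflect the script at the time of writing and may differ between Octo versions:

from octo.model.components.action_heads import L1ActionHead
from octo.utils.spec import ModuleSpec

# Replace the pretrained action head: predict a 50-step chunk of actions.
config["model"]["heads"]["action"] = ModuleSpec.create(
    L1ActionHead,
    pred_horizon=50,    # action chunk size (4 matches pretraining)
    action_dim=14,      # set to your robot's action dimension
    readout_key="readout_action",
)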

INFO:root:No task inputs matching image_primary were found. Replacing with zero padding.
WARNING:root:'observations' is missing items compared to example_batch: {'pad_mask_dict/timestep', 'timestep', 'pad_mask_dict/image_primary', 'pad_mask_dict/proprio'}
WARNING:root:No pad_mask_dict found. Nothing will be masked.

The first one is expected when you are doing language conditioning --- when you use create_tasks without an images argument, image-related keys are omitted. The Octo model then inserts zero padding instead, matching what was done during training for language conditioning. The second warning can be ignored --- we check the inputs to the Octo model against the inputs that were passed during training, but the check is a bit overly strict. timestep and proprio are unused anyway, and pad_mask_dict is not required (as indicated by the final warning). The final warning indicates that since you didn't include a pad_mask_dict in the input to the model, it will not mask out anything, and everything you passed in will be attended to.
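To make the first point concrete, this is the difference at task-creation time (the goal_image array is a hypothetical example):

# Language conditioning: image-related task keys are omitted and zero-padded
# internally, which is what triggers the INFO message above.
task = model.create_tasks(texts=["Grab the wooden cube"])

# Goal-image conditioning provides the image-related task keys instead:
# task = model.create_tasks(goals={"image_primary": goal_image[None]})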

apirrone commented 6 months ago

Thank you very much @kvablack for these very helpful explanations! Things are much clearer to me now :)

I will try some things on the real robot and report back!

Another question: you said

[...] timestep and proprio are unused anyway [...] 

This is consistent with the output I get when printing model.get_pretty_spec() (FYI, I have already made modifications to the fine-tuning script to use a window size of 2 and only 4 steps into the future):

This model is trained with a window size of 2, predicting 19 dimensional actions 4 steps into the future.
Observations and tasks conform to the following spec:

Observations: {
    image_primary: ('batch', 'history_window', 256, 256, 3),
}
Tasks: {
    language_instruction: {
        attention_mask: ('batch', 16),
        input_ids: ('batch', 16),
    },
}

How do I enable the use of proprio?
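(Looking at 02_finetune_new_observation_action.py, it seems proprio is enabled by adding a low-dimensional observation tokenizer to the model config, roughly like the sketch below; the exact names may differ between Octo versions.)

from octo.model.components.tokenizers import LowdimObsTokenizer
from octo.utils.spec import ModuleSpec

# Tokenize the "proprio" observation key so it is fed to the transformer.
config["model"]["observation_tokenizers"]["proprio"] = ModuleSpec.create(
    LowdimObsTokenizer,
    n_bins=256,
    bin_type="normal",
    low=-2.0,
    high=2.0,
    obs_keys=["proprio"],
)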

Thanks!

apirrone commented 6 months ago

@PeterMitrano, @kvablack's explanations should help you a lot as well!

I did not encounter the error you mention in #28. Were you able to get past it?

apirrone commented 6 months ago

So I fine-tuned with a window size of 2 and 4 future action steps, still with my very small dataset (which may be the bottleneck here).

The training curves look like this: [screenshot of training curves]

And now my inference code looks like this:

prev_im = None
prev_state = None
while True:
    im = get_image()
    state = get_state()

    # Build a window of 2 observations (duplicating the current frame at episode start)
    if prev_im is None:
        ims = np.expand_dims(np.stack((im, im)), axis=0)
        states = np.expand_dims(np.stack((state, state)), axis=0)
    else:
        ims = np.expand_dims(np.stack((prev_im, im)), axis=0)
        states = np.expand_dims(np.stack((prev_state, state)), axis=0)

    # Mask out the duplicated slot on the very first step
    pad_mask = np.array([[prev_state is not None, True]])

    observations = {
        "image_primary": ims,
        "proprio": states,
        "pad_mask": pad_mask,
    }

    start = time.time()
    actions = model.sample_actions(observations, task, rng=jax.random.PRNGKey(0))[0]
    print("Sampling actions took", time.time() - start, "seconds")

    prev_state = state
    prev_im = im

    # Un-normalize the predicted actions using the dataset statistics
    actions = (
        actions * model.dataset_statistics["action"]["std"]
        + model.dataset_statistics["action"]["mean"]
    )

    # Receding horizon: execute only the first of the 4 predicted actions
    set_joints(np.array(actions[0]))
    time.sleep(0.02)

I only execute the first action of the 4 predicted (this eliminates the jerkiness of the motion). Now the robot behaves like this: https://photos.app.goo.gl/22tBp5kApTL1sgwWA (video too large for GitHub embedding)

Its behavior is definitely affected by what it sees, but it does not really try to grab the cube :)

Any ideas?

Thanks!

PeterMitrano commented 6 months ago

@apirrone I was able to get past most of my problems. @kvablack, perhaps you could copy your explanation of the padding into a "Padding explained" section in a readme somewhere for increased visibility. For instance, I had assumed that pad_mask being 1 meant "mask it out", not "attend to it", so it makes way more sense with your explanation.

I'm currently using this:

env = ConqGymEnv(clients, FLAGS.im_size, FLAGS.blocking)
env = add_octo_env_wrappers(
    env,
    model.config,
    dict(model.dataset_statistics),
    normalization_type="normal",
    resize_size=(FLAGS.im_size, FLAGS.im_size),
    exec_horizon=FLAGS.exec_horizon,
)

to address the image size issues I was having. Figuring out exactly what my obs dict needed to contain was tricky, but it seems to be working now. The robot's behavior is still... silly. But I'm thinking that's an issue with my data or action space or something.

mees commented 6 months ago

@apirrone I was able to get past most of my problems. @kvablack, perhaps you could copy your explanation of the padding into a "Padding explained" section in a readme somewhere for increased visibility. For instance, I had assumed that pad_mask being 1 meant "mask it out", not "attend to it", so it makes way more sense with your explanation.

Yeah, that's a great suggestion; we will merge it in PR #30.

apirrone commented 6 months ago

Hi @kvablack and @mees

I still can't get my robot to do anything relevant. One time it almost grabbed the object after a few minutes of hesitation, which was very exciting, but I think it was a fluke :) Mostly, it moves to a position and then barely moves after that.

Here is what I tried:

What I haven't tried yet:

By the way, how should the example episodes be structured? Should an episode contain only one movement, or a sequence of movements? My understanding is that by fine-tuning the model this way, I mostly train it to yield coherent outputs with respect to the robot's geometry, is that right? After that, could I get it to do different tasks that were trained into the pre-trained model? Or am I mistaken?

Also, maybe there is a problem with the way I collect data. I originally made the data collection script for the mobile-aloha model, so I sample images and joint positions at 30 Hz. Maybe this is not what Octo is supposed to work with?

Thanks!

kpertsch commented 6 months ago

Hi Antoine,

A few recommendations:

apirrone commented 6 months ago

Great, I will try your recommendations. Thanks!

Another question: do I need 50-100 demos per task I want to learn? Octo is described as a "generalist policy", so I understood the fine-tuning step to be required only for the model to learn to move a new robot properly. For example, can I change the text in the model.create_tasks() function to "grab the bottle" even though there was no bottle grabbing in the fine-tuning dataset?

apirrone commented 6 months ago

Update:

I followed (most of) your recommendations:

The robot still does not do anything really relevant. For example: https://photos.app.goo.gl/YrY24ch7UQtc2Qn19

It mostly stays around a position that seems to somewhat depend on the position of the object, and does not do much more.

Does that sound like an issue with the fine-tuning pipeline or the inference pipeline?

What I will try next:

Thanks!

PeterMitrano commented 6 months ago

I've had some initial success with my setup! I also copied the settings suggested by Kevin and Karl. I haven't run enough tests to see which of the settings changes caused things to start working (dataset, model, fine-tuning hyper-parameters, etc.).

Success here means "it moves down". Before it wasn't even moving or would move in bizarre ways. Now I'm going to hack around some more and see if I can get it to actually grasp the thing :)

https://github.com/octo-models/octo/assets/4010770/25a0859d-b7c9-4b77-a02a-119ccdc5568f

HM102 commented 5 months ago

Any update? Does the fine-tuning work?

PeterMitrano commented 5 months ago

Yes, it is working for me. I had to fix a few issues with my data.

HM102 commented 1 month ago

@PeterMitrano what kind of issues?

PeterMitrano commented 1 month ago

I think the main one was the action space being in the wrong coordinate frame. I advise checking that replaying your training data through your execution pipeline works before you bother training a model.
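A minimal sketch of that sanity check, assuming a hypothetical load_rlds_episodes loader and the set_joints executor used earlier in this thread:

import time
import numpy as np

for episode in load_rlds_episodes(data_dir):  # hypothetical RLDS loader
    for step in episode["steps"]:
        # Replay the recorded actions open-loop; if this does not reproduce
        # the demo, the action space / coordinate frame is wrong.
        set_joints(np.array(step["action"]))
        time.sleep(1 / 30)  # match the 30 Hz recording rate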