octo-models / octo

Octo is a transformer-based robot policy trained on a diverse mix of 800k robot trajectories.
https://octo-models.github.io/
MIT License

Question: What is the intended use case for task stack keys? #25

Closed: truncs closed 10 months ago

truncs commented 10 months ago

I am trying to pretrain on a dataset, and my intended use case is to have three images tokenized as inputs to the transformer and an action head with 2 outputs. When I run the script, I do see that this is indeed the case:

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃                                  ┃ t=0 obs_instruction (16 tokens)  ┃ t=0 obs_primary (16 tokens)  ┃ t=0 obs_secondary (16 tokens)  ┃ t=0 readout_action (1 tokens)  ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ t=0 obs_instruction (16 tokens)  │ x                                │ x                            │ x                              │ x                              │
├──────────────────────────────────┼──────────────────────────────────┼──────────────────────────────┼────────────────────────────────┼────────────────────────────────┤
│ t=0 obs_primary (16 tokens)      │ x                                │ x                            │ x                              │ x                              │
├──────────────────────────────────┼──────────────────────────────────┼──────────────────────────────┼────────────────────────────────┼────────────────────────────────┤
│ t=0 obs_secondary (16 tokens)    │ x                                │ x                            │ x                              │ x                              │
├──────────────────────────────────┼──────────────────────────────────┼──────────────────────────────┼────────────────────────────────┼────────────────────────────────┤
│ t=0 readout_action (1 tokens)    │                                  │                              │                                │ x                              │
└──────────────────────────────────┴──────────────────────────────────┴──────────────────────────────┴────────────────────────────────┴────────────────────────────────┘

But I get these INFO messages:

I1229 11:29:38.141532 140098393864000 tokenizers.py:123] No task inputs matching image_instruction were found. Replacing with zero padding.
I1229 11:29:38.191684 140098393864000 tokenizers.py:123] No task inputs matching image_primary were found. Replacing with zero padding.
I1229 11:29:38.241426 140098393864000 tokenizers.py:123] No task inputs matching image_secondary were found. Replacing with zero padding.

My input config for the observation tokenizers is the following:

    config["model"]["observation_tokenizers"] = {
        "primary": ModuleSpec.create(
            ImageTokenizer,
            obs_stack_keys=["image_primary"],
            task_stack_keys=["image_primary"],
            encoder=ModuleSpec.create(SmallStem16),
        ),
        "secondary": ModuleSpec.create(
            ImageTokenizer,
            obs_stack_keys=["image_secondary"],
            task_stack_keys=["image_secondary"],
            encoder=ModuleSpec.create(SmallStem16),
        ),
        "instruction": ModuleSpec.create(
            ImageTokenizer,
            obs_stack_keys=["image_instruction"],
            task_stack_keys=["image_instruction"],
            encoder=ModuleSpec.create(SmallStem16),
        ),
    }

Going through the image tokenizer code, it seems like obs_stack_keys is for the case where I want to stack the inputs? And then there are the task inputs, which I am not sure what they are meant for. Am I doing this the right way?

dibyaghosh commented 10 months ago

Thanks for the question! We use task_stack_keys as a mechanism to do goal-image conditioning.

The image tokenizer roughly implements the following logic:

# stack observation frames and task (goal) frames along the channel axis
inputs = jnp.concatenate(
    [observations[k] for k in obs_stack_keys] +
    [tasks[k] for k in task_stack_keys],
    axis=-1,
)
tokens = encoder(inputs)
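
For concreteness, here is a runnable toy version of that logic (dummy shapes and keys for illustration, not Octo's actual batch layout):

import jax.numpy as jnp

observations = {"image_primary": jnp.zeros((256, 256, 3))}  # current camera view
tasks = {"image_primary": jnp.ones((256, 256, 3))}          # desired goal view

obs_stack_keys, task_stack_keys = ["image_primary"], ["image_primary"]
inputs = jnp.concatenate(
    [observations[k] for k in obs_stack_keys] +
    [tasks[k] for k in task_stack_keys],
    axis=-1,
)
print(inputs.shape)  # (256, 256, 6): goal stacked channel-wise onto the observation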

So, when you configure the tokenizer this way:

"primary": ModuleSpec.create(
    ImageTokenizer,
    obs_stack_keys=["image_primary"],
    task_stack_keys=["image_primary"],
    encoder=ModuleSpec.create(SmallStem16),
),

Inside the tokenizer, the "image_primary" key is extracted from the observations dictionary, the "image_primary" key is extracted from the tasks dictionary, and the two are concatenated channel-wise before being passed into the conv layers. This is known as early goal fusion: from the very beginning of the network, the model can make pixel-wise comparisons between the camera view at the current timestep and the desired goal camera view (a typically useful inductive bias for goal-reaching tasks).
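
To see why early fusion enables pixel-wise comparison, here is a toy Flax stem (not Octo's actual SmallStem16): the very first convolution's kernels span all six channels, so each kernel can weigh observation pixels against goal pixels at the same spatial location.

import jax
import jax.numpy as jnp
import flax.linen as nn

class ToyStem(nn.Module):
    @nn.compact
    def __call__(self, x):
        # The first conv's kernels span all 6 input channels, so they can
        # directly contrast the current view against the goal view.
        return nn.Conv(features=32, kernel_size=(3, 3), strides=(2, 2))(x)

fused = jnp.zeros((1, 256, 256, 6))  # observation (3 ch) + goal (3 ch)
params = ToyStem().init(jax.random.PRNGKey(0), fused)
out = ToyStem().apply(params, fused)  # shape (1, 128, 128, 32)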


If you don't care about goal-image task conditioning (e.g. you only want language-conditioned training), then you should simply omit the task_stack_keys argument (the same applies if you want to do goal-image conditioning but would prefer to encode / tokenize the goal image and the current observation separately).
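
For example, a sketch of the same "primary" entry with goal-image fusion disabled (task_stack_keys simply omitted, everything else unchanged):

"primary": ModuleSpec.create(
    ImageTokenizer,
    obs_stack_keys=["image_primary"],
    encoder=ModuleSpec.create(SmallStem16),
),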

In any case, what is happening in your current code is that the tokenizer expects a goal image corresponding to "image_primary" in tasks["image_primary"], does not find it in the tasks dictionary, and just inserts a black image in its place (effectively a no-op).
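
In pseudocode, the fallback behaves roughly like this (an assumption based on the log message above, not the exact Octo source):

goal = tasks.get("image_primary")
if goal is None:
    # No matching task input: substitute an all-zero (black) image,
    # which makes the goal channels an effective no-op.
    goal = jnp.zeros_like(observations["image_primary"])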

truncs commented 10 months ago

I actually don't have a goal image; instead I have an image trajectory prompt (similar to https://arxiv.org/abs/2311.01977), which I call 'image_instruction' and which is part of the observation. I think I can still use the same architecture while removing the keys in task_stack_keys. Thanks for clarifying! This was really useful!

zwbx commented 7 months ago

Useful, marked