Thanks for the question! We use `task_stack_keys` as a mechanism to do goal-image conditioning.
The image tokenizer roughly implements the following logic:

```python
import jax.numpy as jnp

# Stack the selected observation and task images along the channel axis,
# then pass the fused input through a single encoder.
inputs = jnp.concatenate(
    [observations[k] for k in obs_stack_keys] +
    [tasks[k] for k in task_stack_keys],
    axis=-1,
)
tokens = encoder(inputs)
```
So, when you configure the tokenizer this way:

```python
"primary": ModuleSpec.create(
    ImageTokenizer,
    obs_stack_keys=["image_primary"],
    task_stack_keys=["image_primary"],
    encoder=ModuleSpec.create(SmallStem16),
),
```
Inside the tokenizer, the `"image_primary"` key is extracted from the `observations` dictionary and from the `tasks` dictionary, and the two images are concatenated channel-wise before being passed into the conv layers. This is known as early goal fusion: from the very beginning of the network, the model can do pixel-wise comparisons between the camera view at the current timestep and the desired goal camera view (typically a useful inductive bias for goal-reaching tasks).
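To make the fusion concrete, here is a minimal, runnable sketch of the channel-wise stacking (the resolution and batch size are assumptions for illustration, not Octo defaults):

```python
import jax.numpy as jnp

# Illustrative shapes only; real resolutions depend on your dataset config.
obs_image = jnp.zeros((1, 64, 64, 3))   # current camera view, RGB
goal_image = jnp.zeros((1, 64, 64, 3))  # desired goal camera view, RGB

# Channel-wise concatenation: the first conv layer already sees the
# current pixel and the goal pixel at every spatial location.
fused = jnp.concatenate([obs_image, goal_image], axis=-1)
print(fused.shape)  # (1, 64, 64, 6)
```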
If you don't care about goal-image task conditioning (e.g. you only want language-conditioned training), then you should simply omit the `task_stack_keys` argument (likewise if you want goal-image conditioning but would prefer to encode/tokenize the goal image and the current observation separately).
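For example, a language-only variant of the config above would just drop the task keys; and as a sketch of the late-fusion alternative, you could give the goal image its own tokenizer (the `"goal"` entry below is a hypothetical name for illustration, and whether `ImageTokenizer` accepts an empty `obs_stack_keys` here is an assumption):

```python
# Language-only: no task_stack_keys, so no goal image is fused in.
"primary": ModuleSpec.create(
    ImageTokenizer,
    obs_stack_keys=["image_primary"],
    encoder=ModuleSpec.create(SmallStem16),
),
# Hypothetical late-fusion sketch: tokenize the goal image separately
# with its own encoder instead of concatenating it channel-wise.
"goal": ModuleSpec.create(
    ImageTokenizer,
    obs_stack_keys=[],
    task_stack_keys=["image_primary"],
    encoder=ModuleSpec.create(SmallStem16),
),
```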
In any case, what is happening in your current code is that the config expects a goal image under `tasks["image_primary"]`, does not find it in the tasks dictionary, and falls back to inserting a black image in its place (effectively a no-op).
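A minimal sketch of that fallback behavior (my paraphrase, not the library's exact code):

```python
import jax.numpy as jnp

def get_goal_or_black(tasks, observations, key="image_primary"):
    # If the goal image is missing from the tasks dictionary, substitute
    # an all-black image of the matching shape (effectively a no-op).
    if key in tasks:
        return tasks[key]
    return jnp.zeros_like(observations[key])
```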
I actually don't have a goal image; instead I have an image trajectory prompt (similar to https://arxiv.org/abs/2311.01977), which I call 'image_instruction' and which is part of the observation. I think I can still use the same architecture while removing the keys from `task_stack_keys`. Thanks for clarifying! This was really useful!
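Concretely, I think my tokenizer config would look something like this (a sketch; stacking the prompt frames channel-wise with the current view as plain observations is my own assumption):

```python
# "image_instruction" is my trajectory-prompt key; treating it as a regular
# observation fuses it with the current view channel-wise, with no task keys.
"primary": ModuleSpec.create(
    ImageTokenizer,
    obs_stack_keys=["image_primary", "image_instruction"],
    encoder=ModuleSpec.create(SmallStem16),
),
```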
I am trying to pretrain on a dataset, and my intended use case is to have three images tokenized as inputs to the transformer and an action head with 2 outputs. When I run the script I do see that this is indeed the case -
But I get these INFO messages
My input config for observations is the following
Going through the image tokenizer code, it seems like there is `obs_stack_keys` for the case where I want to stack the inputs? And then there is the task input, which I am not sure what it is meant for. Am I doing this the right way?