ManifoldRG / NEKO

In-progress implementation of a Gato-style generalist multimodal model capable of image, text, RL, and robotics tasks
https://discord.gg/brsPnzNd8h
GNU General Public License v3.0

General Image encoding #9

Open daniellawson9999 opened 1 year ago

daniellawson9999 commented 1 year ago

This issue covers how we embed each 16x16 image patch, as well as the positional encodings we add to each patch's embedding.

Patch Embedding

To review, we split each image into patches and embed each patch: https://github.com/ManifoldRG/gato-control/blob/2deb510246ebd6b13dd53199f8de7df4e0b96f34/gato/policy/embeddings.py#L43-L47
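Here is a minimal NumPy sketch of the patching step. It uses channels-last (H, W, C) images and zero-pads them so both sides are divisible by 16, for example 77x77 -> 80x80. Padding on the bottom and right edges is an assumption, and embeddings.py may pad differently. `pad_to_multiple` and `image_to_patches` are hypothetical names, not functions from the repo.

```python
import numpy as np

PATCH_SIZE = 16  # patch side length used throughout this issue


def pad_to_multiple(image, patch_size=PATCH_SIZE):
    # Zero-pad an (H, W, C) image on the bottom/right (assumed placement)
    # so that H and W become multiples of patch_size; no resizing.
    h, w = image.shape[:2]
    pad_h = (-h) % patch_size  # e.g. 77 -> 3 rows of padding
    pad_w = (-w) % patch_size
    return np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)))


def image_to_patches(image, patch_size=PATCH_SIZE):
    # Split into (num_patches, patch_size, patch_size, C) in row-major patch order.
    image = pad_to_multiple(image, patch_size)
    h, w, c = image.shape
    rows, cols = h // patch_size, w // patch_size
    patches = image.reshape(rows, patch_size, cols, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4)  # (rows, cols, p, p, C)
    return patches.reshape(rows * cols, patch_size, patch_size, c)


patches = image_to_patches(np.zeros((77, 77, 3)))
print(patches.shape)  # (25, 16, 16, 3): a 5x5 grid after padding to 80x80
```

Row-major ordering matters later. Both positional schemes discussed below assign positions based on where each patch sits in this flattened sequence.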

Each patch is embedded using a single ResNet style block: https://github.com/ManifoldRG/gato-control/blob/2deb510246ebd6b13dd53199f8de7df4e0b96f34/gato/policy/embeddings.py#L111-L131
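To give the architecture review something concrete to vary, here is a minimal NumPy sketch of a residual patch embedder. It has three stages:

- a stem 3x3 convolution;
- one pre-activation residual block (GroupNorm, GELU, 3x3 conv, applied twice, plus a skip connection);
- flattening and a linear projection to the token width.

The sketch is channels-first, so an (H, W, C) patch would be transposed first. The stem convolution, channel counts, group count, and final projection are illustrative assumptions, not read from embeddings.py. All function names are hypothetical.

```python
import numpy as np


def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def group_norm(x, num_groups, eps=1e-5):
    # GroupNorm over a (C, H, W) array; learnable scale/shift omitted for brevity
    c, h, w = x.shape
    g = x.reshape(num_groups, c // num_groups, h, w)
    mean = g.mean(axis=(1, 2, 3), keepdims=True)
    var = g.var(axis=(1, 2, 3), keepdims=True)
    return ((g - mean) / np.sqrt(var + eps)).reshape(c, h, w)


def conv3x3(x, weight):
    # 'same'-padded 3x3 convolution; x: (C_in, H, W), weight: (C_out, C_in, 3, 3)
    _, h, w = x.shape
    xp = np.pad(x, ((0, 0), (1, 1), (1, 1)))
    out = np.zeros((weight.shape[0], h, w))
    for ky in range(3):
        for kx in range(3):
            out += np.einsum("oc,chw->ohw", weight[:, :, ky, kx], xp[:, ky:ky + h, kx:kx + w])
    return out


def resnet_patch_embed(patch, params, num_groups=4):
    # patch: (C_in, 16, 16) -> embedding of shape (embed_dim,)
    x = conv3x3(patch, params["stem"])  # lift C_in to hidden channels (assumed stem)
    h = conv3x3(gelu(group_norm(x, num_groups)), params["conv1"])
    h = conv3x3(gelu(group_norm(h, num_groups)), params["conv2"])
    x = x + h  # residual (skip) connection
    return x.reshape(-1) @ params["proj"]  # flatten and project to token width


# Illustrative sizes (assumed): 3 input channels, 8 hidden channels, 32-dim tokens.
C_IN, HIDDEN, PATCH, EMBED = 3, 8, 16, 32
rng = np.random.default_rng(0)
params = {
    "stem": rng.normal(0.0, 0.1, (HIDDEN, C_IN, 3, 3)),
    "conv1": rng.normal(0.0, 0.1, (HIDDEN, HIDDEN, 3, 3)),
    "conv2": rng.normal(0.0, 0.1, (HIDDEN, HIDDEN, 3, 3)),
    "proj": rng.normal(0.0, 0.02, (HIDDEN * PATCH * PATCH, EMBED)),
}
embedding = resnet_patch_embed(rng.random((C_IN, PATCH, PATCH)), params)
print(embedding.shape)  # (32,)
```

Natural variations to try include more residual blocks, different hidden widths, or replacing the conv stack with the plain linear patch projection used by standard ViTs.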

Task: 1) Review whether this architecture makes sense, and assess variations on image-based tasks (e.g. Atari).

Positional Embeddings

First, our implementation aims to follow the methodology in Gato's appendix, C.3 Position Encodings: https://github.com/ManifoldRG/gato-control/blob/2deb510246ebd6b13dd53199f8de7df4e0b96f34/gato/policy/embeddings.py#L63-L74
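Here is a minimal sketch of the C.3-style scheme as we understand it. Each patch row (and column) covers a normalized interval [i/n, (i+1)/n), which is quantized into a fixed position vocabulary. During training, a random index is sampled uniformly within the quantized interval; at inference, the midpoint is used instead. The following are assumptions to check against the paper and embeddings.py, not a copy of either:

- the vocabulary size of 128;
- the exact rounding;
- separate row and column embedding tables summed into each patch embedding.

Function names are hypothetical.

```python
import numpy as np

POS_VOCAB = 128  # quantized position vocabulary size (assumed value)


def patch_position_indices(n, training, rng, vocab=POS_VOCAB):
    # Patch i along this axis covers the normalized interval [i/n, (i+1)/n),
    # which maps to the integer bins [low, high). Requires n <= vocab.
    i = np.arange(n)
    low = (i * vocab) // n
    high = ((i + 1) * vocab) // n
    if training:
        # Training: sample one index uniformly from each patch's interval.
        return rng.integers(low, high)
    # Inference: use the interval midpoint (rounded down) instead of sampling.
    return (low + high) // 2


def add_patch_positions(patch_embeds, rows, cols, row_table, col_table, training, rng):
    # patch_embeds: (rows * cols, D) in row-major order; tables: (POS_VOCAB, D).
    r_idx = patch_position_indices(rows, training, rng)
    c_idx = patch_position_indices(cols, training, rng)
    pos = row_table[r_idx][:, None, :] + col_table[c_idx][None, :, :]  # (rows, cols, D)
    return patch_embeds + pos.reshape(rows * cols, -1)


rng = np.random.default_rng(0)
print(patch_position_indices(5, training=False, rng=rng).tolist())  # [12, 38, 63, 89, 115]
print(patch_position_indices(5, training=True, rng=rng).tolist())  # random, one per bin
```

Because positions are normalized, a 5-patch-wide and a 10-patch-wide image both span the whole vocabulary. This is what lets the scheme handle varying native image sizes without resizing.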

Gato's approach is a fairly complicated way of handling patch positional encodings. For example, during training it samples random position indices within certain intervals for each patch position, while during inference it uses the mean. This strategy potentially helps the model train well on images of varying sizes: Gato does not need to resize images and can handle them at their native size, as long as their dimensions are divisible by patch_size (16). Our current implementation pads images to make them divisible by 16 but does not resize them, e.g. 77x77 -> 80x80, with 3 columns/rows of padding.

However, most ViTs use a simpler strategy after embedding patches: they add a learned positional encoding to each patch. That is, you flatten all image patches into a 1-dimensional ordering and add a learned position encoding based on each patch's position. If we follow the instructions in Gato's appendix, this actually happens in addition to the randomized patch position strategy, because we add learned positional embeddings to each entry in the observation token sequence. Thus, even if we disable the randomized strategy, we still get positional information from causal attention masking and from the learned positional encodings added to each observation:

https://github.com/ManifoldRG/gato-control/blob/2deb510246ebd6b13dd53199f8de7df4e0b96f34/gato/policy/gato_policy.py#L361-L365
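For comparison, here is a minimal sketch of the simpler ViT-style alternative. Flatten an observation's tokens (e.g. row-major patches) and add a learned embedding indexed by each token's position within the observation. In this sketch, the same table is reused for every timestep. `add_learned_positions` and the table size are hypothetical. In practice the table would be a trained parameter rather than the random initialization used here.

```python
import numpy as np


def add_learned_positions(tokens, pos_table):
    # tokens: (..., n, D) observation tokens, e.g. row-major image patches per timestep.
    # pos_table: (max_tokens, D) learned parameter; entry k is shared by position k
    # of every observation, regardless of timestep.
    n = tokens.shape[-2]
    if n > pos_table.shape[0]:
        raise ValueError(f"{n} tokens exceed {pos_table.shape[0]} learned positions")
    return tokens + pos_table[:n]  # broadcasts over any leading (e.g. time) axes


rng = np.random.default_rng(0)
pos_table = rng.normal(0.0, 0.02, (32, 16))  # stands in for a trained table
obs_tokens = np.zeros((4, 25, 16))  # 4 timesteps x 25 patches x 16 dims
print(add_learned_positions(obs_tokens, pos_table).shape)  # (4, 25, 16)
```

Unlike the randomized scheme, a given index always means the same flattened slot. As a result, the table size caps the number of patches, and patch k of a 5x5 grid and patch k of a 5x6 grid receive the same embedding even though they sit at different 2D locations.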

For initial Atari testing, we have in fact trained with --disable_patch_pos_encoding, which sets use_pos_encoding=False only within ImageEmbedding. Patch positional information then comes solely from the observation positional encodings (disable_inner_pos_encoding is not enabled) and causal attention, rather than from Gato's randomized patch position encodings:

https://github.com/ManifoldRG/gato-control/blob/2deb510246ebd6b13dd53199f8de7df4e0b96f34/gato/policy/embeddings.py#L55C1-L58C1

Thus, one task is to: 1) Review our current implementation of the full Gato-style randomized positional encoding, and perform comparison runs against the simplified ViT-style positional encodings. We will likely not see a benefit from Gato's full strategy when only assessing images of one size, or without evaluating generalization.

2) Then, decide which strategy to prioritize for other image experiments. If we go with another approach, such as the VQVAE discussed next, the ViT-style learned positional embedding strategy should be sufficient, and we would not need Gato's randomized strategy.

Assess alternative formulations for embedding images

In particular, we are interested in exploring VQVAEs, following a methodology similar to recent work like RoboCat https://www.deepmind.com/blog/robocat-a-self-improving-robotic-agent. First, you pretrain a VQVAE on your image data, and then use the frozen VQVAE encoder to tokenize/embed images. After encoding, each image is represented by several discrete tokens, where each token corresponds to a continuous vector in the learned codebook. ResNet-style patch embedding gives us only a sequence of continuous vectors. A VQVAE gives us continuous representations too, obtained by looking up each token in the codebook before input to the transformer. It also gives us discrete image tokens that the model can learn to predict. This potentially allows our model to conditionally generate images or video, or to plan in image space. We may also see our main transformer train faster, or get other benefits discussed in RoboCat.

We would need to implement VQVAE pretraining. We may benefit from pretraining on diverse data (e.g. ImageNet, as performed in RoboCat) plus in-domain data (control tasks like Atari). It may also be worthwhile to check whether pretrained VQVAEs from open-source VQGAN implementations or others work out of the box, or with minimal fine-tuning: https://github.com/CompVis/taming-transformers
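To make the tokenization step concrete, here is a minimal sketch of the vector-quantization lookup at the core of a VQVAE. It assumes the frozen encoder has already produced one D-dimensional vector per spatial cell. Each vector is mapped to its nearest codebook entry. The resulting ids are the discrete image tokens the transformer could predict, and the looked-up code vectors are the continuous inputs. Encoder/decoder training (reconstruction and commitment losses) is omitted. `vq_tokenize` is a hypothetical name, not an API from taming-transformers or RoboCat.

```python
import numpy as np


def vq_tokenize(z, codebook):
    # z: (N, D) frozen-encoder outputs for one image (one vector per spatial cell).
    # codebook: (K, D) learned code vectors.
    # Squared L2 distance from every z_i to every code e_k: |z|^2 - 2 z.e + |e|^2
    d = (z**2).sum(axis=1, keepdims=True) - 2.0 * z @ codebook.T + (codebook**2).sum(axis=1)
    tokens = d.argmin(axis=1)  # discrete image tokens in [0, K)
    return tokens, codebook[tokens]  # ids and their continuous code vectors


codebook = np.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
z = np.array([[0.1, -0.2], [4.6, 5.3], [0.9, 1.2]])
tokens, quantized = vq_tokenize(z, codebook)
print(tokens.tolist())  # [0, 2, 1]
```

For prediction, `tokens` would serve as cross-entropy targets over a K-way output head, the same way text tokens are predicted. `quantized` (or a learned embedding of `tokens`) would replace the ResNet patch embeddings as transformer input.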

daniellawson9999 commented 1 year ago

Look into this: https://github.com/AILab-CVC/SEED https://arxiv.org/abs/2307.08041