Closed IgorSusmelj closed 2 years ago
I would like to work on this. Do we have reference code implementations for it? I am not sure I will be able to reproduce the results since I do not have that much hardware, but this seems interesting.
You can check out Papers with Code. They reference this repo by Facebook Research.
It'd be great if you could share your thoughts on how you would integrate it best into the current lightly package structure 🙂
@philippmwirth Thanks for the references. I will give the paper a read and then we can discuss ideas on how to integrate it with lightly. This will be fun to do.
I looked at the code. It seems simple enough. A few things I would like to highlight.
All in all, this seems like a nice implementation for lightly. I can implement it and bring it up to the lightly code standards. What I would need help with from the lightly team is experimentation (I do not have that much hardware to replicate the results), and I need guidance on how to write tests for this implementation.
Please let me know your thoughts @philippmwirth @IgorSusmelj.
Hi @Atharva-Phatak, thanks for the summary!
That looks great.
Augmentations: should be used from torchvision whenever possible.
Model Architecture: I'd try to avoid adding timm as a dependency. Any thoughts @guarin and @philippmwirth?
Visualization-Utils: we can add this later or one could use the colab from facebook :)
Criterion/Loss Function: this should be separate from the model.
Overall, we try to make lightly rather modular. That will make it easier to combine different architectures, training procedures, and loss functions.
I guess the key question is how to split up the MaskedAutoencoderViT into good independent pieces.
Btw. torchvision just added lots of new augmentations and ViT models. Maybe we could build on top of them?
We can test the whole implementation on our hardware.
@IgorSusmelj We can adapt the code from timm; it seems pretty easy to integrate, and this would remove the additional dependency on timm.
Please correct me if I'm wrong, but wouldn't it be enough to e.g. inherit from the torchvision ViT implementation (link) and simply override the forward function? Something like this:
class MaskedAutoencoderViT(torchvision.models.VisionTransformer):
    def forward(self, x: torch.Tensor, mask_ratio: float):
        x = self._process_input(x)
        n = x.shape[0]

        # new: random masking
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]
        x = self.heads(x)

        return x, mask, ids_restore
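The random_masking call above is not defined anywhere yet; a minimal sketch of such a helper (written here as a standalone function, following the shuffle-and-keep logic from the facebookresearch/mae repo; the exact signature and return values are my assumption) could look like this:

import torch

def random_masking(x, mask_ratio):
    # x: (batch_size, seq_length, dim) patch embeddings without the class token
    # keeps a random subset of tokens per sample and returns the binary mask
    # plus the indices needed to restore the original token order
    n, s, d = x.shape
    num_keep = int(s * (1 - mask_ratio))
    noise = torch.rand(n, s, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)        # random permutation per sample
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation
    ids_keep = ids_shuffle[:, :num_keep]
    x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, d))
    # mask: 0 = keep, 1 = masked (in the original token order)
    mask = torch.ones(n, s, device=x.device)
    mask[:, :num_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)
    return x_masked, mask, ids_restore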
And then we could add modules for the decoder and for the loss to lightly. Sharing the architecture with torchvision has the simple advantage that exporting and importing weights is very convenient.
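For example (hypothetical file name, just to illustrate the point):

# after pre-training, the encoder weights drop straight back into a stock torchvision ViT
torch.save(vit.state_dict(), "vit_b_16_mae_pretrained.pth")  # hypothetical path

vit_finetune = torchvision.models.vit_b_16(pretrained=False)
vit_finetune.load_state_dict(torch.load("vit_b_16_mae_pretrained.pth"))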
@guarin I think you looked into this before. Anything I'm missing?
I remember that the masking + positional encoding part was not trivial, especially because it is needed in both the encoder and the decoder. And it did not fit nicely into our workflow where we always return (image, target, filename) tuples. But I am sure we can figure it out.
I would also use the pytorch vit and first focus on implementing the model without any visualizations.
@philippmwirth I agree, the torchvision version looks simple enough. If everyone concurs, I will study the code in torchvision and we can implement MAE accordingly.
Please let me know.
Sounds great @Atharva-Phatak! Let us know if you need support 🙂
I am sorry this is taking time from my end as I am busy with my final exams 😭. I will try to create a PR ASAP.
No worries, good luck with your exams!
I've been doing some investigations and this is the best I've come up with so far for making the torchvision ViTs work with Lightly and the MAE setup. @guarin @Atharva-Phatak @IgorSusmelj I'd love to hear your opinions. IMO it's not a very clean approach but I think it should work... If you have better ideas let me know 🙂
Possible approach for an MAE implementation (or rather, how to pretrain a torchvision.models.VisionTransformer with MAE):
# initialize ViT
vit = torchvision.models.vit_b_16(pretrained=False)

# use a lightly MaskedEncoder which inherits from torchvision.models.vision_transformer.Encoder
encoder = lightly.modules.MaskedEncoder.from_encoder(vit.encoder)

# use a lightly MaskedDecoder
decoder = lightly.modules.MaskedDecoder()

# use the loss implemented by lightly
loss = MAELoss()

# pre-training
for i in range(epochs):
    for x in dataloader:
        # x is a batch of images (bsz, 3, w, h)

        # need to process the input (patchify & embed)
        x_processed = vit._process_input(x)

        # manually add the cls token
        n = x_processed.shape[0]
        batch_class_token = vit.class_token.expand(n, -1, -1)
        x_processed = torch.cat([batch_class_token, x_processed], dim=1)

        # forward pass encoder
        x_encoded, mask, ids_restore = encoder(x_processed)

        # forward pass decoder
        x_decoded = decoder(x_encoded, mask, ids_restore)

        # possibly convert x_decoded to image patches here

        # loss calculation
        l = loss(x, x_decoded)

        # backwards pass etc.

# restore original encoder with pretrained weights
vit.encoder = encoder.strip_mask()
This would require us to add the following things to lightly:
MaskedEncoder
MaskedDecoder
MAELoss (a rough sketch of this one is below)
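For the loss, the paper computes a mean squared error between the reconstructed and original pixels of the masked patches only, optionally on per-patch normalized pixel values. A minimal sketch of such a module could look like the following; the interface (patch tensors in, scalar loss out) is just my assumption:

import torch

class MAELoss(torch.nn.Module):
    # Sketch: MSE on the masked patches only, with optional per-patch pixel
    # normalization as described in the MAE paper. Not a final interface.
    def __init__(self, normalize_target: bool = True):
        super().__init__()
        self.normalize_target = normalize_target

    def forward(self, pred_patches, target_patches):
        # pred_patches, target_patches: (batch_size, num_masked_patches, patch_dim)
        if self.normalize_target:
            mean = target_patches.mean(dim=-1, keepdim=True)
            var = target_patches.var(dim=-1, keepdim=True)
            target_patches = (target_patches - mean) / (var + 1e-6) ** 0.5
        return torch.nn.functional.mse_loss(pred_patches, target_patches)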
This looks great!
Based on your proposal I was able to write the following draft that successfully runs; I'm not sure if it actually works though 🙂
The code is pretty verbose as I have not yet figured out an optimal structure but all the building blocks are there. The main issue is that encoding, masking, and decoding are pretty interleaved and have to share a lot of information between each other. So maybe an overall MAE class that holds the encoder, decoder, class token, and mask token could be a good solution, although this would be a bit against our "low-level" building blocks principle.
This code should also be pretty easy to adapt to the SimMIM and SplitMask models.
The code is adapted from: https://github.com/facebookresearch/mae
from typing import Optional

import torch
import torchvision
import lightly
import tqdm


def repeat_token_like(token, input):
    # repeats token to have same shape as input
    N, S, _ = input.shape
    return token.repeat(N, S, 1)


def expand_index_like(idx, input):
    # expands the index along the feature dimension of input
    # returns idx with shape (N_idx, S_idx, D_input)
    D = input.shape[-1]
    idx = idx.unsqueeze(-1).expand(-1, -1, D)
    return idx


def get_at_index(input, idx):
    # gets tokens at index
    idx = expand_index_like(idx, input)
    return torch.gather(input, 1, idx)


def set_at_index(input, idx, value):
    # sets tokens at index to value
    idx = expand_index_like(idx, input)
    return torch.scatter(input, 1, idx, value)


def prepend_class_token(input, class_token):
    # prepends class token to input
    N = input.shape[0]
    batch_class_token = class_token.expand(N, -1, -1)
    return torch.cat([batch_class_token, input], dim=1)


def create_random_mask(input, mask_ratio=0.6):
    # creates random masks for input
    # returns (idx_keep, idx_mask) tuple
    # idx_keep has shape (N, num_keep)
    # idx_mask has shape (N, S - num_keep)
    # S = sequence length
    N, S, _ = input.shape
    num_keep = int(S * (1 - mask_ratio))
    noise = torch.rand(N, S, device=input.device)
    # make sure that the class token is not masked
    noise[:, 0] = -1
    # get indices of tokens to keep
    indices = torch.argsort(noise, dim=1)
    idx_keep = indices[:, :num_keep]
    idx_mask = indices[:, num_keep:]
    return idx_keep, idx_mask


def patchify(imgs, patch_size):
    # converts images into patches
    # output has shape (N, num_patches, patch_size ** 2 * C)
    N, C, H, W = imgs.shape
    assert H == W and H % patch_size == 0
    patch_h = patch_w = H // patch_size
    num_patches = patch_h * patch_w
    patches = imgs.reshape(shape=(N, C, patch_h, patch_size, patch_w, patch_size))
    patches = torch.einsum('nchpwq->nhwpqc', patches)
    patches = patches.reshape(shape=(N, num_patches, patch_size ** 2 * C))
    return patches
class MAEEncoder(torchvision.models.vision_transformer.Encoder):
    def forward(self, input: torch.Tensor, idx_keep: Optional[torch.Tensor] = None):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        input = input + self.pos_embedding
        # only encode the non-masked tokens
        if idx_keep is not None:
            input = get_at_index(input, idx_keep)
        return self.ln(self.layers(self.dropout(input)))

    @classmethod
    def from_vit_encoder(cls, vit_encoder):
        # build a dummy encoder and then take over all modules from the ViT encoder
        encoder = cls(
            seq_length=1,
            num_layers=1,
            num_heads=1,
            hidden_dim=1,
            mlp_dim=1,
            dropout=0,
            attention_dropout=0,
        )
        encoder.pos_embedding = vit_encoder.pos_embedding
        encoder.dropout = vit_encoder.dropout
        encoder.layers = vit_encoder.layers
        encoder.ln = vit_encoder.ln
        return encoder
class MAEDecoder(torchvision.models.vision_transformer.Encoder):
    def __init__(
        self,
        embed_input_dim,
        patch_size,
        hidden_dim,
        **kwargs,
    ):
        super().__init__(hidden_dim=hidden_dim, **kwargs)
        # projects encoder output to the decoder dimension
        self.decoder_embed = torch.nn.Linear(embed_input_dim, hidden_dim, bias=True)
        # predicts pixel values for every patch
        self.prediction_head = torch.nn.Linear(hidden_dim, patch_size ** 2 * 3)

    def forward(self, input):
        return self.decode(input)

    def embed(self, input):
        return self.decoder_embed(input)

    def decode(self, input):
        return super().forward(input)

    def predict(self, input):
        return self.prediction_head(input)
vit = torchvision.models.vit_b_32(pretrained=True)
decoder_dim = 512

class_token = vit.class_token
mask_token = torch.nn.Parameter(torch.zeros(1, 1, decoder_dim))

encoder = MAEEncoder.from_vit_encoder(vit.encoder)
decoder = MAEDecoder(
    embed_input_dim=vit.hidden_dim,
    patch_size=vit.patch_size,
    seq_length=vit.seq_length,
    num_layers=1,
    num_heads=4,
    hidden_dim=decoder_dim,
    mlp_dim=decoder_dim * 4,
    dropout=0,
    attention_dropout=0,
)

transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop((vit.image_size, vit.image_size)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        lightly.data.collate.imagenet_normalize['mean'],
        lightly.data.collate.imagenet_normalize['std'],
    ),
])

dataset = lightly.data.LightlyDataset('/datasets/aquarium', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, num_workers=4, batch_size=4, drop_last=True)

optimizer = torch.optim.Adam(
    params=(
        [class_token, mask_token]
        + list(encoder.parameters())
        + list(decoder.parameters())
    ),
    lr=0.06,
)

criterion = torch.nn.MSELoss()
# pre-training
for epoch in range(10):
    epoch_loss = 0
    for imgs, targets, filenames in tqdm.tqdm(dataloader):
        # imgs is a batch of images (bsz, 3, w, h)

        # need to process the input (patchify & embed)
        x_processed = vit._process_input(imgs)
        # add the cls token
        x_processed = prepend_class_token(x_processed, class_token)

        # get mask indices
        idx_keep, idx_mask = create_random_mask(x_processed)

        # forward pass encoder, only non-masked tokens are encoded
        x_encoded_keep = encoder(x_processed, idx_keep)
        # project to decoder input dimension
        x_decode_embed_keep = decoder.embed(x_encoded_keep)

        # build masked decoder input
        # masked tokens are set to the mask_token
        # non-masked tokens are set to the embedded encoder tokens
        x_masked = repeat_token_like(mask_token, x_processed)
        x_masked = set_at_index(x_masked, idx_keep, x_decode_embed_keep)

        # forward pass decoder
        x_decoded = decoder(x_masked)

        # predict pixel values for masked tokens
        x_pred = get_at_index(x_decoded, idx_mask)
        x_pred = decoder.predict(x_pred)

        # get image patches for masked tokens
        # must adjust idx_mask for the missing class token
        patches = patchify(imgs, vit.patch_size)
        target = get_at_index(patches, idx_mask - 1)

        loss = criterion(x_pred, target)

        # backwards pass etc.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        epoch_loss += loss.detach()

    print(epoch, epoch_loss)
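If we went with the overall MAE class mentioned above, it could roughly look like the sketch below; it only wraps the building blocks from the draft, and the class name, constructor arguments, and method split are placeholders rather than a final design:

class MAE(torch.nn.Module):
    # Sketch only: holds encoder, decoder, class token, and mask token in one module.
    def __init__(self, vit, decoder_dim=512, mask_ratio=0.6):
        super().__init__()
        self.vit = vit
        self.mask_ratio = mask_ratio
        self.class_token = vit.class_token
        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, decoder_dim))
        self.encoder = MAEEncoder.from_vit_encoder(vit.encoder)
        self.decoder = MAEDecoder(
            embed_input_dim=vit.hidden_dim,
            patch_size=vit.patch_size,
            seq_length=vit.seq_length,
            num_layers=1,
            num_heads=4,
            hidden_dim=decoder_dim,
            mlp_dim=decoder_dim * 4,
            dropout=0,
            attention_dropout=0,
        )

    def forward(self, imgs):
        # returns (predicted patches, target patches) for the masked positions
        x = prepend_class_token(self.vit._process_input(imgs), self.class_token)
        idx_keep, idx_mask = create_random_mask(x, self.mask_ratio)
        x_encoded = self.encoder(x, idx_keep)
        x_embed = self.decoder.embed(x_encoded)
        x_masked = repeat_token_like(self.mask_token, x)
        x_masked = set_at_index(x_masked, idx_keep, x_embed)
        x_decoded = self.decoder(x_masked)
        x_pred = self.decoder.predict(get_at_index(x_decoded, idx_mask))
        target = get_at_index(patchify(imgs, self.vit.patch_size), idx_mask - 1)
        return x_pred, target

The training loop would then reduce to a forward pass through this module plus the loss on the returned (prediction, target) pair.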
@guarin I was implementing the encoder and decoder as different classes, similar to heads, and then I was going to combine them in a single class so that information sharing between the modules is a bit easier. That being said, your approach seems more elegant and easier to adapt, so which approach should I follow?
Great draft @guarin! 🙂
To answer both your questions: I'd suggest we start building the low-level blocks first (i.e. MAEEncoder and MAEDecoder) similar to what @guarin used above. We can always add a high-level interface which connects the two later. For example, we can add lightly.models.modules.encoders.MAEEncoder and lightly.models.modules.encoders.MAEDecoder in a first step and then later (if necessary) we'll work on lightly.models.mae.MAE.
Ideally, we'd have a working version of the encoder and decoder relatively soon so we can run a quick benchmark on e.g. Imagenette to see if it works as expected. We can then work on the final implementation together 👍
There might be some differences between the original paper and your implementation, @guarin:
MAE uses sine-cosine positional embeddings while torchvision uses learned ones (I believe).
That's not dramatic but we should make sure to note it somewhere.
Aaah good catch, I didn't notice that! Yes, we either have to add a note or we can overwrite the positional embedding with a sine-cosine one. Although overwriting would break pretrained ViTs, so maybe that is not the best idea.
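For reference, if we ever wanted to overwrite it, a fixed sine-cosine embedding could be generated roughly like this (a 1D variant for illustration; the facebookresearch/mae repo actually builds a 2D version over the patch grid, and the function name here is just a placeholder):

import torch

def sincos_pos_embedding(seq_length, dim, temperature=10000.0):
    # fixed 1D sine-cosine positional embedding of shape (1, seq_length, dim)
    pos = torch.arange(seq_length, dtype=torch.float32).unsqueeze(1)   # (S, 1)
    omega = torch.arange(dim // 2, dtype=torch.float32) / (dim // 2)
    omega = 1.0 / temperature ** omega                                 # (dim / 2,)
    out = pos * omega                                                  # (S, dim / 2)
    emb = torch.cat([torch.sin(out), torch.cos(out)], dim=1)           # (S, dim)
    return emb.unsqueeze(0)

# hypothetical usage: replace the learned embedding with a fixed, non-trainable one
# encoder.pos_embedding = torch.nn.Parameter(
#     sincos_pos_embedding(vit.seq_length, vit.hidden_dim), requires_grad=False
# )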
Regarding cleanup / code structure:
Next steps would be:
@Atharva-Phatak do you already have some example code/draft? Would be great if we can compare :)
Hi @guarin, I am adding my implementation of the encoder; my decoder is the same as yours. So all in all, @guarin, we should move ahead with your structure; the only thing is that we need to structure the helper utilities properly.
from utils import random_masking

class MaskedEncoderVIT(torchvision.models.VisionTransformer):
    def forward(self, x: torch.Tensor, mask_ratio: float):
        x = self._process_input(x)
        n = x.shape[0]

        # new: random masking (added random_masking in utils.py)
        x, mask, ids_to_restore = random_masking(x, mask_ratio)

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)
        x = x[:, 0]
        x = self.heads(x)

        return x, mask, ids_to_restore
The paper Masked Autoencoders Are Scalable Vision Learners (https://arxiv.org/abs/2111.06377) suggests that a masked autoencoder (similar to masked pre-training in NLP) works very well as a pretext task for self-supervised learning. Let's add it to Lightly.