lightly-ai / lightly

A python library for self-supervised learning on images.
https://docs.lightly.ai/self-supervised-learning/
MIT License
2.93k stars 253 forks

Add Masked Autoencoder implementation #721

Closed IgorSusmelj closed 2 years ago

IgorSusmelj commented 2 years ago

The paper Masked Autoencoders Are Scalable Vision Learners (https://arxiv.org/abs/2111.06377) suggests that a masked autoencoder (similar to pre-training in NLP) works very well as a pretext task for self-supervised learning. Let's add it to Lightly.

Atharva-Phatak commented 2 years ago

I would like to work on this. Do we have reference code implementations for it? I am not sure I will be able to reproduce the results as I do not have that much hardware, but this seems interesting.

philippmwirth commented 2 years ago

You can check out Papers with Code, which references the repo by Facebook Research.

It'd be great if you could share your thoughts on how you would integrate it best into the current lightly package structure 🙂

Atharva-Phatak commented 2 years ago

@philippmwirth Thanks for the references. I will give the paper a read and then we can discuss ideas for integrating it with lightly. This will be fun to do.

Atharva-Phatak commented 2 years ago

I looked at the code. It seems simple enough. A few things I would like to highlight.

All in all, this seems like a nice implementation for lightly. I can implement it and bring it up to the lightly code standards. Where I would need help from the lightly team is with experimentation (I do not have enough hardware to replicate the results) and with guidance on how to write tests for this implementation.

Please let me know your thoughts @philippmwirth @IgorSusmelj.

IgorSusmelj commented 2 years ago

Hi @Atharva-Phatak, thanks for the summary!

That looks great.

Augmentations: should be used from torchvision whenever possible.
Model Architecture: I'd try to avoid adding timm as a dependency. Any thoughts @guarin and @philippmwirth?
Visualization-Utils: we can add this later, or one could use the colab from facebook :)
Criterion/Loss Function: this should be separate from the model.

Overall, we try to make lightly rather modular. That will make it easier to combine different architectures, training procedures, and loss functions. I guess the key question is how to split up the MaskedAutoencoderViT into good independent pieces.

Btw. torchvision just added lots of new augmentations and vit models. Maybe we could build on top of it?
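To illustrate the point about keeping the criterion separate from the model, here is a minimal sketch of a standalone loss module. The class name `MaskedPatchMSELoss` and its argument layout are assumptions made for illustration, not an existing lightly API. It computes the mean squared error per patch and averages it over masked patches only, which is how the MAE paper describes its reconstruction loss.

```python
import torch
from torch import nn


class MaskedPatchMSELoss(nn.Module):
    """Hypothetical name: MSE averaged over masked patches only."""

    def forward(
        self, pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        # pred, target: (N, L, D) predicted and true pixel values per patch
        # mask: (N, L) with 1 for masked patches and 0 for visible ones
        per_patch = ((pred - target) ** 2).mean(dim=-1)  # (N, L)
        # clamp avoids a division by zero if nothing is masked
        return (per_patch * mask).sum() / mask.sum().clamp(min=1)
```

Because the module only sees tensors, it could be combined with any encoder/decoder pair that produces per-patch predictions and a binary mask.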

IgorSusmelj commented 2 years ago

We can test the whole implementation on our hardware.

Atharva-Phatak commented 2 years ago

@IgorSusmelj We can adapt the code from timm; it seems pretty easy to integrate. This would remove the additional dependency on timm.

philippmwirth commented 2 years ago

Please correct me if I'm wrong, but wouldn't it be enough to e.g. inherit from the torchvision ViT implementation and simply override the forward function? Something like this:

class MaskedAutoencoderViT(torchvision.models.VisionTransformer):

    def forward(self, x: torch.Tensor, mask_ratio: float):

        x = self._process_input(x)
        n = x.shape[0]
        # new: random masking
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]

        x = self.heads(x)

        return x, mask, ids_restore

And then we could add modules for the decoder and for the loss to lightly. Sharing the architecture with torchvision has the simple advantage that exporting and importing weights is very convenient.
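The `random_masking` call in the snippet above is not defined anywhere in this thread. As a rough sketch of what it could look like, the function below follows the argsort-based shuffling used in the facebookresearch/mae reference code. It is written as a standalone function rather than a method, and the exact signature and return layout are assumptions.

```python
import torch


def random_masking(x: torch.Tensor, mask_ratio: float):
    """Sketch: randomly keep a subset of patch tokens per sample.

    x has shape (N, L, D) and is assumed to hold patch tokens only
    (no class token). Returns the kept tokens, a binary mask in the
    original patch order (0 = kept, 1 = masked), and the indices
    needed to undo the shuffle.
    """
    N, L, D = x.shape
    num_keep = int(L * (1 - mask_ratio))

    # sorting random noise yields an independent permutation per sample
    noise = torch.rand(N, L, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    # argsort of a permutation is its inverse
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # keep the first num_keep tokens of the shuffled sequence
    ids_keep = ids_shuffle[:, :num_keep]
    x_kept = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

    # build the mask in shuffled order, then unshuffle it
    mask = torch.ones(N, L, device=x.device)
    mask[:, :num_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)
    return x_kept, mask, ids_restore
```

A decoder could later append mask tokens to the kept tokens and gather with `ids_restore` to put every token back at its original position.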

@guarin I think you looked into this before. Anything I'm missing?

guarin commented 2 years ago

I remember that the masking + positional encoding part was not trivial, especially because it is needed in both the encoder and the decoder. It also did not fit nicely into our workflow where we always return (image, target, filename) tuples. But I am sure we can figure it out.

I would also use the pytorch vit and first focus on implementing the model without any visualizations.

Atharva-Phatak commented 2 years ago

@philippmwirth I agree, the torchvision version looks simple enough. If everyone concurs, I will study the code in torchvision and we can implement MAE accordingly.

Please let me know.

philippmwirth commented 2 years ago

Sounds great @Atharva-Phatak! Let us know if you need support 🙂

Atharva-Phatak commented 2 years ago

I am sorry this is taking time from my end as I am busy with my final exams 😭 . I will try to create a PR ASAP.

philippmwirth commented 2 years ago

No worries, good luck with your exams!

philippmwirth commented 2 years ago

I've been doing some investigations and somehow this is the best I've come up with on how to make the torchvision ViTs work with Lightly and the MAE setup. @guarin @Atharva-Phatak @IgorSusmelj I'd love to hear your opinions. IMO it's not a very clean approach but I think it should work... If you have better ideas let me know 🙂

Possible approach for an MAE implementation (or rather, how to pretrain a torchvision.models.VisionTransformer with MAE):


# initialize ViT
vit = torchvision.models.vit_b_16(pretrained=False)

# use a lightly MaskedEncoder which inherits from torchvision.models.vision_transformer.Encoder
encoder = lightly.modules.MaskedEncoder.from_encoder(vit.encoder)

# use a lightly MaskedDecoder
decoder = lightly.MaskedDecoder()

# use the loss implemented by lightly
loss = MAELoss()

# pre-training
for i in range(epochs):
    for x in dataloader:
        # x is a batch of images (bsz, 3, w, h)

        # need to process the input (patchify & embed)
        x_processed = vit._process_input(x)

        # manually add the cls token
        n = x_processed.shape[0]
        batch_class_token = vit.class_token.expand(n, -1, -1)
        x_processed = torch.cat([batch_class_token, x_processed], dim=1)

        # forward pass encoder
        x_encoded, mask, ids_restore = encoder(x_processed)

        # forward pass decoder
        x_decoded = decoder(x_encoded, mask, ids_restore)

        # possibly convert x_decoded to image patches here

        # loss calculation
        l = loss(x, x_decoded)
        # backwards pass etc

# restore original encoder with pretrained weights
vit.encoder = encoder.strip_mask()

This would require us to add the following things to lightly:

guarin commented 2 years ago

This looks great!

Based on your proposal I was able to write the following draft that successfully runs, not sure if it actually works though 🙂

The code is pretty verbose as I have not yet figured out an optimal structure but all the building blocks are there. The main issue is that encoding, masking, and decoding are pretty interleaved and have to share a lot of information between each other. So maybe an overall MAE class that holds the encoder, decoder, class token, and mask token could be a good solution, although this would be a bit against our "low-level" building blocks principle.

This code should also be pretty easy to adapt to the SimMIM and SplitMask models.

The code is adapted from: https://github.com/facebookresearch/mae

from typing import Optional

import torch
import torchvision
import lightly
import tqdm

def repeat_token_like(token, input):
    # repeats token to have same shape as input
    N, S, _ = input.shape
    return token.repeat(N, S, 1)

def expand_index_like(idx, input):
    # expands the index along the feature dimension of input
    # returns idx with shape (N_idx, S_idx, D_input)
    D = input.shape[-1]
    idx = idx.unsqueeze(-1).expand(-1, -1, D)
    return idx

def get_at_index(input, idx):
    # gets tokens at index
    idx = expand_index_like(idx, input)
    return torch.gather(input, 1, idx)

def set_at_index(input, idx, value):
    # sets tokens at index to value
    idx = expand_index_like(idx, input)
    return torch.scatter(input, 1, idx, value)

def prepend_class_token(input, class_token):
    # prepends class token to input
    N = input.shape[0]
    batch_class_token = class_token.expand(N, -1, -1)
    return torch.cat([batch_class_token, input], dim=1)

def create_random_mask(input, mask_ratio=0.6):
    # creates random masks for input
    # returns idx_keep, idx_mask tuple
    # idx_keep has shape (N, num_keep)
    # idx_mask has shape (N, S - num_keep)

    # S = sequence length
    N, S, _ = input.shape
    num_keep = int(S * (1 - mask_ratio))

    noise = torch.rand(N, S, device=input.device)
    # make sure that class token is not masked
    noise[:, 0] = -1

    # get indices of tokens to keep
    indices = torch.argsort(noise, dim=1)
    idx_keep = indices[:, :num_keep]
    idx_mask = indices[:, num_keep:]

    return idx_keep, idx_mask

def patchify(imgs, patch_size):
    # converts images into patches
    # output has shape (N, num_patches, patch_size ** 2 * C)
    N, C, H, W = imgs.shape
    assert H == W and H % patch_size == 0

    patch_h = patch_w = H // patch_size
    num_patches = patch_h * patch_w
    patches = imgs.reshape(shape=(N, C, patch_h, patch_size, patch_w, patch_size))
    patches = torch.einsum('nchpwq->nhwpqc', patches)
    patches = patches.reshape(shape=(N, num_patches, patch_size ** 2 * C))
    return patches

class MAEEncoder(torchvision.models.vision_transformer.Encoder):        

    def forward(self, input: torch.Tensor, idx_keep: Optional[torch.Tensor] = None):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        input = input + self.pos_embedding
        if idx_keep is not None:
            input = get_at_index(input, idx_keep)
        return self.ln(self.layers(self.dropout(input)))

    @classmethod
    def from_vit_encoder(cls, vit_encoder):
        encoder = cls(
            seq_length=1,
            num_layers=1,
            num_heads=1,
            hidden_dim=1,
            mlp_dim=1,
            dropout=0,
            attention_dropout=0,
        )
        encoder.pos_embedding = vit_encoder.pos_embedding
        encoder.dropout = vit_encoder.dropout
        encoder.layers = vit_encoder.layers
        encoder.ln = vit_encoder.ln
        return encoder

class MAEDecoder(torchvision.models.vision_transformer.Encoder):
    def __init__(
        self, 
        embed_input_dim, 
        patch_size,
        hidden_dim,
        **kwargs,
    ):
        super().__init__(hidden_dim=hidden_dim, **kwargs)
        self.decoder_embed = torch.nn.Linear(embed_input_dim, hidden_dim, bias=True)
        # use hidden_dim here; the global decoder_dim is only defined further below
        self.prediction_head = torch.nn.Linear(hidden_dim, patch_size ** 2 * 3)

    def forward(self, input):
        return self.decode(input)

    def embed(self, input):
        return self.decoder_embed(input)

    def decode(self, input):
        return super().forward(input)

    def predict(self, input):
        return self.prediction_head(input)

vit = torchvision.models.vit_b_32(pretrained=True)

decoder_dim = 512
class_token = vit.class_token
mask_token = torch.nn.Parameter(torch.zeros(1, 1, decoder_dim))

encoder = MAEEncoder.from_vit_encoder(vit.encoder)
decoder = MAEDecoder(
    embed_input_dim=vit.hidden_dim,
    patch_size=vit.patch_size,
    seq_length=vit.seq_length,
    num_layers=1,
    num_heads=4,
    hidden_dim=decoder_dim,
    mlp_dim=decoder_dim * 4,
    dropout=0,
    attention_dropout=0,
)

transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop((vit.image_size, vit.image_size)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        lightly.data.collate.imagenet_normalize['mean'],
        lightly.data.collate.imagenet_normalize['std'],
    )
])

dataset = lightly.data.LightlyDataset('/datasets/aquarium', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, num_workers=4, batch_size=4, drop_last=True)

optimizer = torch.optim.Adam(
    params=(
        [class_token, mask_token]
        + list(encoder.parameters())
        + list(decoder.parameters())
    ),
    lr=0.06,
)
criterion = torch.nn.MSELoss()

# pre-training
for epoch in range(10):
    epoch_loss = 0
    for imgs, targets, filenames in tqdm.tqdm(dataloader):
        # imgs is a batch of images (bsz, 3, w, h)

        # need to process the input (patchify & embed)
        x_processed = vit._process_input(imgs)

        # add the cls token
        x_processed = prepend_class_token(x_processed, class_token)

        # get mask indices
        idx_keep, idx_mask = create_random_mask(x_processed)

        # forward pass encoder, only non-masked tokens are encoded
        x_encoded_keep = encoder(x_processed, idx_keep)

        # project to decoder input dimension
        x_decode_embed_keep = decoder.embed(x_encoded_keep)

        # build masked decoder input
        # masked tokens are set to the mask_token
        # non-masked tokens are set to the embedded encoder tokens
        x_masked = repeat_token_like(mask_token, x_processed)
        x_masked = set_at_index(x_masked, idx_keep, x_decode_embed_keep)

        # forward pass decoder
        x_decoded = decoder(x_masked)

        # predict pixel values for masked tokens
        x_pred = get_at_index(x_decoded, idx_mask)
        x_pred = decoder.predict(x_pred)

        # get image patches for masked tokens
        # must adjust idx_mask for missing class token
        patches = patchify(imgs, vit.patch_size)
        target = get_at_index(patches, idx_mask - 1)

        loss = criterion(x_pred, target)

        # backwards pass etc
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        epoch_loss += loss.detach()

    print(epoch, epoch_loss)

Atharva-Phatak commented 2 years ago

@guarin I was implementing the encoder and decoder as different classes, similar to heads, and then I was going to combine them in a single class so that sharing information between the modules would be a bit easier. I think @guarin's approach seems more elegant and easier to adapt. That being said, I would like to know which approach I should follow.

philippmwirth commented 2 years ago

Great draft @guarin! 🙂

To answer both your questions: I'd suggest we start building the low-level blocks first (i.e. MAEEncoder and MAEDecoder) similar to what @guarin used above. We can always add a high-level interface which connects the two later. For example, we can add lightly.models.modules.encoders.MAEEncoder and lightly.models.modules.encoders.MAEDecoder in a first step and then later (if necessary) we'll work on lightly.models.mae.MAE.

Ideally, we'd have a working version of the encoder and decoder relatively soon so we can run a quick benchmark on e.g. Imagenette to see if it works as expected. We can then work on the final implementation together 👍

There might be some differences between the original paper and your implementation, @guarin:

That's not dramatic but we should make sure to note it somewhere.

guarin commented 2 years ago

"MAE uses sine-cosine positional embeddings while torchvision uses learned ones (I believe)"

Aaah good catch, I didn't notice that! Yes we either have to add a note or can overwrite the positional embedding with a sine-cosine one. Although overwriting would break pretrained vits, so maybe that is not the best idea.
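For reference, a fixed 2D sine-cosine embedding could be sketched as follows. The function names `sincos_1d` and `sincos_2d` are hypothetical, and the ordering of the height and width halves is a choice made here; it is not guaranteed to match the reference implementation element by element. The output has one row per token, with an all-zero row for the class token, so after adding a batch dimension it would have the same `(1, seq_length, hidden_dim)` layout as the learned `pos_embedding` in torchvision's encoder.

```python
import torch


def sincos_1d(positions: torch.Tensor, dim: int) -> torch.Tensor:
    # positions: (M,) coordinates; returns (M, dim) laid out as [sin | cos]
    assert dim % 2 == 0
    omega = torch.arange(dim // 2, dtype=torch.float64) / (dim / 2)
    omega = 1.0 / (10000 ** omega)  # geometric range of frequencies
    angles = positions.double()[:, None] * omega[None, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)


def sincos_2d(grid_size: int, dim: int, class_token: bool = True) -> torch.Tensor:
    # half of the channels encode the row, the other half the column
    assert dim % 4 == 0
    coords = torch.arange(grid_size, dtype=torch.float64)
    grid_h, grid_w = torch.meshgrid(coords, coords, indexing="ij")
    emb_h = sincos_1d(grid_h.reshape(-1), dim // 2)
    emb_w = sincos_1d(grid_w.reshape(-1), dim // 2)
    emb = torch.cat([emb_h, emb_w], dim=1)  # (grid_size ** 2, dim)
    if class_token:
        # the class token has no spatial position, so it gets zeros
        emb = torch.cat([torch.zeros(1, dim, dtype=emb.dtype), emb], dim=0)
    return emb.float()
```

For `vit_b_32` with 224x224 inputs this would be `sincos_2d(7, 768)`, giving 50 rows. Whether to copy it into the encoder and freeze it is the open question above, since doing so would discard the learned embedding of a pretrained ViT.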

Regarding cleanup / code structure:

Next steps would be:

@Atharva-Phatak do you already have some example code/draft? Would be great if we can compare :)

Atharva-Phatak commented 2 years ago

Hi @guarin, I am adding my implementation of the encoder. My decoder is the same as yours. So all in all, @guarin, we should move ahead with your structure; the only thing is that we need to structure the helper utilities properly.


import torch
import torchvision

from utils import random_masking  # random masking helper added in utils.py


class MaskedEncoderVIT(torchvision.models.VisionTransformer):

    def forward(self, x: torch.Tensor, mask_ratio: float):

        x = self._process_input(x)
        n = x.shape[0]

        # new: random masking
        x, mask, ids_to_restore = random_masking(x, mask_ratio)

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)
        x = x[:, 0]
        x = self.heads(x)
        return x, mask, ids_to_restore