lightly-ai / lightly

A python library for self-supervised learning on images.
https://docs.lightly.ai/self-supervised-learning/
MIT License
3.14k stars, 277 forks

Add MaskedVisionTransformerDecoder #1615

Open guarin opened 2 months ago

guarin commented 2 months ago

We already have MaskedVisionTransformer classes for TIMM and torchvision that take images as input and output a token for every image patch. These classes are versatile as they can be used for many different methods (I-JEPA, DINOv2, MAE, etc.). Some of these methods also require a ViT decoder (I-JEPA and MAE) which we currently implement independently for every method (we have IJEPAPredictor and MAEDecoder). These decoders have very similar structures and I think we could cover them in a single MaskedVisionTransformerDecoder class.

The class would have the following interface:

from typing import Optional

from torch import Tensor
from torch.nn import Module

class MaskedVisionTransformerDecoder(Module):
    def __init__(
        self,
        embed_dim: int,
        depth: int,
        # ... other args from VisionTransformer
    ):
        super().__init__()
        self.mask_token = ...
        self.pos_embed = ...
        self.blocks = ...
        self.norm = ...

    def forward(
        self,
        x: Tensor,
        idx_keep: Optional[Tensor] = None,
        idx_mask: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        x = self.preprocess(x, idx_keep=idx_keep, idx_mask=idx_mask, mask=mask)
        x = self.blocks(x)
        x = self.norm(x)
        return x

    def preprocess(
        self,
        x: Tensor,
        idx_keep: Optional[Tensor] = None,
        idx_mask: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        # Same as in MaskedVisionTransformer.preprocess: add the positional
        # embedding and apply the masking.
        ...
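To make the proposed interface concrete, here is a minimal, runnable sketch in plain PyTorch. It is not Lightly's implementation: `nn.TransformerEncoder` stands in for the TIMM/torchvision ViT blocks, the `mask` argument is left out, and the constructor arguments (`seq_len`, `embed_dim`, `depth`, `num_heads`) are assumptions. `preprocess` writes the mask token into the `idx_mask` positions, adds the positional embedding, and then keeps only the `idx_keep` tokens.

```python
from typing import Optional

import torch
from torch import Tensor, nn


class MaskedVisionTransformerDecoder(nn.Module):
    """Hypothetical sketch of a ViT decoder that takes tokens, not images, as input."""

    def __init__(self, seq_len: int, embed_dim: int, depth: int, num_heads: int) -> None:
        super().__init__()
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        # Assumption: plain PyTorch encoder layers stand in for the TIMM/torchvision blocks.
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=4 * embed_dim,
            batch_first=True,
        )
        self.blocks = nn.TransformerEncoder(layer, num_layers=depth)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(
        self,
        x: Tensor,
        idx_keep: Optional[Tensor] = None,
        idx_mask: Optional[Tensor] = None,
    ) -> Tensor:
        x = self.preprocess(x, idx_keep=idx_keep, idx_mask=idx_mask)
        x = self.blocks(x)
        return self.norm(x)

    def preprocess(
        self,
        x: Tensor,
        idx_keep: Optional[Tensor] = None,
        idx_mask: Optional[Tensor] = None,
    ) -> Tensor:
        # x: (batch, seq_len, embed_dim); idx_keep / idx_mask: (batch, num_indices).
        dim = x.shape[-1]
        if idx_mask is not None:
            # Overwrite the masked positions with the learnable mask token.
            index = idx_mask.unsqueeze(-1).expand(-1, -1, dim)
            mask_tokens = self.mask_token.expand(x.shape[0], idx_mask.shape[1], -1)
            x = x.scatter(1, index, mask_tokens)
        # Positional embeddings are added to every token before any are dropped.
        x = x + self.pos_embed
        if idx_keep is not None:
            # Keep only the selected tokens.
            index = idx_keep.unsqueeze(-1).expand(-1, -1, dim)
            x = torch.gather(x, 1, index)
        return x
```

Adding the positional embedding before dropping tokens means every kept token still carries its original position, which both MAE and I-JEPA rely on.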

This would allow us to implement MAE like this:


class MAE(Module):
    def __init__(self):
        super().__init__()
        self.encoder = MaskedVisionTransformerTIMM(...)
        self.embed = Linear(...)
        self.decoder = MaskedVisionTransformerDecoderTIMM(...)
        self.prediction_head = Linear(...)

    def forward(self, images: Tensor):
        idx_keep, idx_mask = utils.random_token_mask(...)
        x = self.encoder(images, idx_keep=idx_keep)
        x = self.embed(x)

        # Full (batch_size, sequence_length, embed_dim) token sequence; x only
        # holds the kept tokens.
        x_decode = x.new_zeros(...)
        x_decode = utils.set_at_index(x_decode, idx_keep, x)
        x_decode = self.decoder(x_decode, idx_mask=idx_mask)

        x_pred = utils.get_at_index(x_decode, idx_mask)
        x_pred = self.prediction_head(x_pred)
        return x_pred
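The index bookkeeping in `forward` can be checked in isolation. The helpers below are hypothetical pure-PyTorch stand-ins for `utils.random_token_mask`, `utils.set_at_index`, and `utils.get_at_index` (the real Lightly functions may take different arguments; for example, this stand-in mask ignores the class token). The point they illustrate: `x` only holds the kept tokens, so the decoder input has to be allocated at the full sequence length before the kept tokens are scattered back.

```python
import torch
from torch import Tensor


def random_token_mask(batch_size: int, seq_len: int, mask_ratio: float = 0.75):
    """Hypothetical stand-in for utils.random_token_mask (no class-token handling)."""
    # A random permutation per sample: the first num_keep positions are kept.
    perm = torch.rand(batch_size, seq_len).argsort(dim=1)
    num_keep = int(seq_len * (1 - mask_ratio))
    return perm[:, :num_keep], perm[:, num_keep:]


def get_at_index(tokens: Tensor, index: Tensor) -> Tensor:
    # Select tokens[b, index[b, i]] for every sample b.
    return torch.gather(tokens, 1, index.unsqueeze(-1).expand(-1, -1, tokens.shape[-1]))


def set_at_index(tokens: Tensor, index: Tensor, value: Tensor) -> Tensor:
    # Return a copy of tokens with tokens[b, index[b, i]] = value[b, i].
    return tokens.scatter(1, index.unsqueeze(-1).expand(-1, -1, tokens.shape[-1]), value)


batch_size, seq_len, dim = 2, 16, 4
idx_keep, idx_mask = random_token_mask(batch_size, seq_len)
x = torch.randn(batch_size, idx_keep.shape[1], dim)  # stand-in for encoded + embedded tokens

# The decoder input must span the full sequence, not just the kept tokens.
x_decode = x.new_zeros(batch_size, seq_len, dim)
x_decode = set_at_index(x_decode, idx_keep, x)
x_masked = get_at_index(x_decode, idx_mask)  # positions the decoder has to predict
```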

And I-JEPA like this:

class IJEPA(Module):
    def __init__(self):
        super().__init__()
        self.encoder = MaskedVisionTransformerTIMM(...)
        self.embed = Linear(...)
        self.decoder = MaskedVisionTransformerDecoderTIMM(...)
        self.prediction_head = Linear(...)
        self.target_encoder = copy.deepcopy(self.encoder)

    def forward_context(self, images: Tensor, mask_enc: Tensor, mask_pred: Tensor):
        x = self.encoder(images, idx_keep=mask_enc)
        x = self.embed(x)

        x_decode = x.new_zeros(...)
        x_decode = utils.set_at_index(x_decode, mask_enc, x)
        x_decode = x_decode.repeat_interleave(len(mask_pred), dim=0)
        x_decode = self.decoder(x_decode, idx_mask=mask_pred, idx_keep=mask_enc | mask_pred)
        x_decode = ... # select only mask_pred tokens
        x_pred = self.prediction_head(x_decode)
        return x_pred

    @torch.no_grad()
    def forward_target(self, images: Tensor, mask_enc: Tensor, mask_pred: Tensor):
        x = self.target_encoder(images)
        x = x.repeat_interleave(len(mask_pred), dim=0)
        x = utils.get_at_index(x, mask_pred)
        return x
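One detail that makes the I-JEPA masking tricky: the context tokens are duplicated once per target block, and that duplication has to follow the same sample order as the stacked target blocks. Assuming the targets are concatenated block-major (for example `torch.cat([gather(x, m) for m in masks], dim=0)`, where `gather` is a hypothetical per-block selection), only tiling with `repeat` lines up with them. The toy check below tracks sample ids only.

```python
import torch

batch_size, num_blocks = 2, 3
sample_id = torch.arange(batch_size)  # track which sample each row belongs to

# Assumed block-major target layout:
# [block 0 of all samples, block 1 of all samples, block 2 of all samples]
target_rows = torch.cat([sample_id for _ in range(num_blocks)])

# Two ways to duplicate the context once per target block.
tiled = sample_id.repeat(num_blocks)  # [0, 1, 0, 1, 0, 1]
interleaved = sample_id.repeat_interleave(num_blocks)  # [0, 0, 0, 1, 1, 1]

print(torch.equal(tiled, target_rows))  # True
print(torch.equal(interleaved, target_rows))  # False
```

For a `(batch, tokens, dim)` tensor the tiled variant is `x.repeat(num_blocks, 1, 1)`. If the masks are instead laid out sample-major, `repeat_interleave` is the right choice; what matters is that `forward_context` and `forward_target` use the same order.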

By sharing the MaskedVisionTransformerDecoder class, we can deduplicate a lot of the code around positional embeddings, transformer blocks, and masking. And by moving the embed and prediction_head layers out of the decoder, the decoder class becomes more modular and easier to reuse. We cannot reuse MaskedVisionTransformer directly because it expects images rather than tokens as input.

I am not yet 100% sure that this is possible. Especially for I-JEPA there is some funky masking logic that might be hard to generalize. Could be worth a try though.

Shrinidhibhat87 commented 1 week ago

@guarin I would be open to working on this feature once I understand the requirements more specifically. I really like the open-source initiative at Lightly and would like to contribute more to this repo.

So here are my questions:
1) You want a more general abstract class for the decoder which can then be used for MAE and I-JEPA accordingly. Am I correct to understand this?
2) Considering the fact that I-JEPA does more with its masking logic, we can simply override the abstract forward method for this accordingly. But then again, we won't simply have one object that works for both MAE and I-JEPA. Is this understanding of mine correct?

Based on the repository, I feel like more can be done with restructuring:
1) Firstly, since the software design is more aligned with the factory-based pattern, it would be beneficial to place the abstract methods separately rather than put them all in the modules subdirectory.
2) It would also be beneficial if we had a separate subdirectory for torchvision and timm models. That way there is a clear separation rather than relying on the suffixes. This would also help when scaling up the repo and Lightly in general.

I can then start working on a PR based on your response. Since this is still in an ideation phase, I would start with a draft PR before moving onto a full blown PR with reviews.

Please do tell me what you think.

guarin commented 1 week ago

Hi!

You want a more general abstract class for the decoder which can then be used for MAE and I-JEPA accordingly. Am I correct to understand this?

Yes, the idea was to see whether we can write a single MaskedVisionTransformerDecoder class and reuse it for both the MAEDecoder and the IJEPAPredictor. The relevant files are here:

We would then also move the prediction part out of the MaskedVisionTransformer and have individual MAEPrediction and IJEPAPrediction heads.

Considering the fact that I-JEPA does more with its masking logic, we can simply override the abstract forward method for this accordingly. But then again, we won't simply have one object that works for both MAE and I-JEPA. Is this understanding of mine correct?

Yes exactly, that is the concern. But TBH this was just an idea and I didn't really have time to investigate this further. Maybe there is a nice way to do it or maybe it is better if we just have two separate implementations.

Firstly, since the software design is more aligned with the factory-based pattern, it would be beneficial to place the abstract methods separately rather than put them all in the modules subdirectory.

Could you explain this further? I don't fully understand what you mean.

It would also be beneficial if we had a separate subdirectory for torchvision and timm models. That way there is a clear separation rather than relying on the suffixes. This would also help when scaling up the repo and Lightly in general.

We plan to refactor the package quite a bit and will take this into account as well.

I can then start working on a PR based on your response. Since this is still in an ideation phase, I would start with a draft PR before moving onto a full blown PR with reviews.

If you could just look into whether it is possible to have a shared MaskedVisionTransformerDecoder implementation for MAE and IJEPA, that would already be awesome! It will certainly take a bit more time to flesh out the details. There is also a high probability that it doesn't work well and we have to keep the separate implementations.

Shrinidhibhat87 commented 6 days ago

Hey @guarin, thanks for the swift response.

I will look into the feasibility of the idea and detail my findings here. Could you assign this issue to me?

Could you explain this further? I don't fully understand what you mean.

Sorry for not explaining this in more detail. I was under the impression that the team closely follows the factory-based pattern. In that case, I think it is best to separate the abstract base classes into their own directory rather than keeping everything in the modules subdirectory. This is simply for better readability.

Some simple articles I used when reading about this:
https://medium.com/@rafalb/harnessing-python-design-patterns-for-machine-learning-a-dive-into-five-paradigms-c696d4970a37
https://medium.com/data-and-beyond/design-patterns-in-python-for-machine-learning-and-data-engineer-factory-pattern-78bcb209c2a6#:~:text=The%20factory%20is%20the%20class,in%20the%20file%20fmachine.py.

To elaborate, the MaskedVisionTransformer class, which is an abstract class, should live in a base_modules (or similarly named) directory.

Models subdirectory
├── src/
│   ├── modules/
│   │   ├── center.py
│   │   ├── ijepa.py
│   ├── base_modules/
│   │   ├── masked_vision_transformer.py
│   │   ├── ...
│   ├── timm_models/
│   │   ├── ...
│   └── torchvision_models/
│       ├── ...

Additionally, one could create decorators that register subclasses with their base modules as an additional layer of safety, but I think this can be a future feature. Another good source: https://medium.com/@geoffreykoh/implementing-the-factory-pattern-via-dynamic-registry-and-python-decorators-479fc1537bbe
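A decorator-based registry like the one described in that article could look roughly like this. All names here (DECODER_REGISTRY, register_decoder, build_decoder, and the toy classes) are hypothetical and not part of Lightly.

```python
from typing import Any, Callable, Dict

# Hypothetical registry mapping names to decoder classes.
DECODER_REGISTRY: Dict[str, type] = {}


def register_decoder(name: str) -> Callable[[type], type]:
    """Class decorator that adds the decorated class to DECODER_REGISTRY."""

    def decorator(cls: type) -> type:
        if name in DECODER_REGISTRY:
            raise KeyError(f"decoder {name!r} is already registered")
        DECODER_REGISTRY[name] = cls
        return cls  # the class itself is returned unchanged

    return decorator


def build_decoder(name: str, **kwargs: Any) -> Any:
    """Instantiate a registered decoder by name."""
    if name not in DECODER_REGISTRY:
        raise KeyError(f"unknown decoder {name!r}, available: {sorted(DECODER_REGISTRY)}")
    return DECODER_REGISTRY[name](**kwargs)


@register_decoder("mae")
class ToyMAEDecoder:
    def __init__(self, depth: int = 8) -> None:
        self.depth = depth


@register_decoder("ijepa")
class ToyIJEPAPredictor:
    def __init__(self, depth: int = 6) -> None:
        self.depth = depth
```

The cost is an extra level of indirection: a reader has to consult the registry to find which class a string name maps to.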

guarin commented 6 days ago

Thanks for the extra info! We don't strictly follow the factory-based pattern. We have some plans to update the package structure, but they are still under discussion. Our goal is to make the package as easy as possible to use for research. This means that it should be straightforward to understand and adapt. We'll most likely try to reduce the usage of advanced design patterns to keep the code as simple as possible.

Shrinidhibhat87 commented 6 days ago

The findings from my initial analysis:

It is worth considering an abstract class because the I-JEPA predictor and the MAE decoder have several conceptual and functional similarities.

That said, they are not identical. I-JEPA strictly calls this block a predictor rather than a decoder, which is the term MAE uses.

Quoting from the I-JEPA paper (https://arxiv.org/pdf/2301.08243):

Our encoder/predictor architecture is reminiscent of the generative masked autoencoders (MAE) [36] method. However, one key difference is that the I-JEPA method is non-generative and the predictions are made in representation space. [Chapter 3]

This shows that the authors themselves acknowledge the similarity between their encoder/predictor architecture and MAE. Another point to note is that both the I-JEPA predictor and the MAE decoder are shallow, lightweight networks.

@guarin, you have already outlined what the abstract class could look like. But based on the discussion in this thread, it is clear that although we would have an abstract class for the decoder/predictor, we would still need individual classes for the I-JEPA predictor and the MAE decoder.

The base class for both the predictor and the decoder could roughly look like this (not final):

from abc import ABC, abstractmethod
from typing import Callable

import torch
from torch import nn
from timm.models.vision_transformer import Block

class MaskedVisionTransformerDecoder(nn.Module, ABC):
    def __init__(
        self,
        num_patches: int,
        embed_dim: int,
        depth: int,
        num_heads: int,
        mlp_ratio: float,
        drop_path_rate: float,
        proj_drop_rate: float,
        attn_drop_rate: float,
        norm_layer: Callable[..., nn.Module],
    ):
        super().__init__()

        self.embed = nn.Linear(embed_dim, embed_dim, bias=True)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, embed_dim), requires_grad=False
        )
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=True,
                    drop_path=drop_path_rate,
                    proj_drop=proj_drop_rate,
                    attn_drop=attn_drop_rate,
                    norm_layer=norm_layer,
                )
                for _ in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Abstract forward method, to be implemented by subclasses."""

    def apply_transformer_blocks(self, x: torch.Tensor) -> torch.Tensor:
        for blk in self.blocks:
            x = blk(x)
        return self.norm(x)

# Subclasses implement the predictor- or decoder-specific logic
class IJEPAPredictorTIMM(MaskedVisionTransformerDecoder):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.embed(x)
        x = self.apply_transformer_blocks(x)
        # IJEPAPredictor-specific logic
        return x

class MAEDecoderTIMM(MaskedVisionTransformerDecoder):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.embed(x)
        x = x + self.pos_embed  # Add positional embeddings (out of place)
        x = self.apply_transformer_blocks(x)
        # MAEDecoder-specific logic for prediction
        return x
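In the sketch above every subclass re-implements forward and repeats the embed/blocks/norm calls. One variation, assuming the method-specific steps can be isolated in a preprocess hook like the one in the interface proposed at the top of the thread, keeps forward in the base class and makes only the hook abstract. `DecoderBase` and `ToyMAEDecoder` are hypothetical names, and `nn.Identity` stands in for the shared transformer blocks to keep the example small. Combining `nn.Module` with `ABC` also makes instantiating the base class fail with a `TypeError`.

```python
from abc import ABC, abstractmethod

import torch
from torch import Tensor, nn


class DecoderBase(nn.Module, ABC):
    """Hypothetical base class: owns forward, subclasses only implement preprocess."""

    def __init__(self, seq_len: int, embed_dim: int) -> None:
        super().__init__()
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        # Identity stands in for the shared transformer blocks.
        self.blocks = nn.Identity()
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        # Shared pipeline: method-specific preprocessing, then blocks and norm.
        x = self.preprocess(x)
        return self.norm(self.blocks(x))

    @abstractmethod
    def preprocess(self, x: Tensor) -> Tensor:
        """Method-specific step, e.g. adding positional embeddings or masking."""


class ToyMAEDecoder(DecoderBase):
    def preprocess(self, x: Tensor) -> Tensor:
        return x + self.pos_embed
```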

Pros and Cons: