lightly-ai / lightly

A python library for self-supervised learning on images.
https://docs.lightly.ai/self-supervised-learning/
MIT License

Adding DINO to lightly #698

Closed Atharva-Phatak closed 2 years ago

Atharva-Phatak commented 2 years ago

First of all, kudos on creating this amazing library ❤️.

I think all of us in the self-supervised learning community have heard of DINO. For the past couple of weeks, I have been trying to port Facebook's original DINO implementation to PyTorch Lightning (PL). I have implemented it, at least for my use case, which is a Kaggle competition. I initially looked at lightly for an implementation but did not find one, so I borrowed and adapted code from the original implementation and converted it to Lightning.

Honestly, it was a tedious task. I was wondering if you guys would be interested in adding DINO to lightly.

Here's how I think we should structure the implementation:

1) Augmentations: DINO relies heavily on a multi-crop strategy. Since lightly already implements augmentations through a collate_fn, the DINO augmentations can be implemented the same way.
2) Model forward pass and model heads: the forward pass is unusual because we have to handle both the global crops and the local (multi-crop) crops, so this should be implemented as an nn.Module like the other heads in lightly.
3) Loss function: I am not sure how this should be implemented in lightly, although FB has a custom class for it.
4) Utility functions: FB uses some tricks for stable training and results, so these need to be included as well.
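To make items 2) and 3) concrete, here is a minimal sketch of the multi-crop DINO loss as we understand the published recipe: the teacher only sees the global crops, its outputs are centered and sharpened with a low temperature, and the student is trained with cross-entropy against every teacher crop except the one it shares. `dino_loss` is a hypothetical helper (not lightly's or FB's API), and the temperatures (0.1 for the student, 0.04 for the teacher) and center momentum of 0.9 are assumed defaults.

```python
import torch
import torch.nn.functional as F


def dino_loss(student_outputs, teacher_outputs, center,
              student_temp=0.1, teacher_temp=0.04, center_momentum=0.9):
    """Sketch of the DINO cross-entropy loss over multiple crops.

    student_outputs: list of (batch, dim) tensors, one per crop; the first
        len(teacher_outputs) entries are assumed to be the same global crops
        the teacher saw, in the same order, followed by the local crops.
    teacher_outputs: list of (batch, dim) tensors, one per global crop.
    center: (1, dim) running center of the teacher outputs.
    Returns (loss, new_center).
    """
    # Teacher: subtract the center, sharpen with a low temperature, stop gradients.
    teacher_probs = [
        F.softmax((t - center) / teacher_temp, dim=-1).detach()
        for t in teacher_outputs
    ]
    student_log_probs = [F.log_softmax(s / student_temp, dim=-1) for s in student_outputs]

    total, n_terms = 0.0, 0
    for i, q in enumerate(teacher_probs):
        for j, log_p in enumerate(student_log_probs):
            if i == j:
                continue  # skip the pair where student and teacher see the same global crop
            total = total + torch.sum(-q * log_p, dim=-1).mean()
            n_terms += 1
    loss = total / n_terms

    # Update the center with an exponential moving average of the teacher batch mean.
    with torch.no_grad():
        batch_center = torch.cat(teacher_outputs).mean(dim=0, keepdim=True)
        new_center = center * center_momentum + batch_center * (1 - center_momentum)
    return loss, new_center
```

With 2 global and 4 local crops this averages 2 × 6 − 2 = 10 cross-entropy terms. Keeping the center outside the function (rather than as module state) is only a choice to keep the sketch stateless.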

I used Lightning to implement all of this, and so far, at least for my use case, I was able to train vit-base-16 given my hardware constraints.

It goes without saying that I would personally like to work on the PR :heart:

Please let me know.

guarin commented 2 years ago

Hi @Atharva-Phatak! Thanks for reaching out!

We recently implemented DINO but have not announced it yet. All the code should be available in the latest release (1.2.7), and you can find the docs on how to use it here: https://docs.lightly.ai/examples/dino.html

We implemented the following parts:

I hope this helps you! Let us know if you have any questions or feedback :)

Atharva-Phatak commented 2 years ago

So cool. Amazing work, guys! The code looks so clean.

I was thinking maybe we could add utilities to visualize attention maps for ViT models trained with SSL? Just like the DINO authors did, and like the cool video they released.

In my opinion, adding the visualization would be the cherry on the cake.

What do you think? I am very motivated to contribute to lightly!

guarin commented 2 years ago

We thought about making a tutorial on how to use DINO with a transformer and visualize the results. Let us know if you would be interested in it. Contributions are of course always very welcome!

Atharva-Phatak commented 2 years ago

Yes, I am interested. I can make a tutorial that includes the visualization; that will be fun to do. Could you please give me some instructions on how to proceed?

guarin commented 2 years ago

Hi, that sounds great!

As an outline I would propose the following two steps:

  1. Show how to train DINO with a transformer (you can take the code from the examples)
  2. Visualize the transformer self-attention maps for some images

For the visualization, I would probably refactor the original code into one or two easy-to-use functions. You could, of course, also use another tool for the visualization.
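A minimal sketch of one such helper, assuming the attention tensor of the last ViT block has shape (batch, heads, 1 + num_patches, 1 + num_patches) with the CLS token at index 0. As far as we know, the ViT in the original DINO repo exposes this tensor through `get_last_selfattention`; here a random tensor stands in so the sketch runs without downloading a model. `cls_attention_maps` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def cls_attention_maps(attentions, image_size, patch_size):
    """Turn last-block ViT attention into one heatmap per head.

    attentions: (batch, heads, 1 + num_patches, 1 + num_patches), CLS token first.
    image_size: (height, width) of the input, assumed divisible by patch_size.
    Returns a (batch, heads, height, width) tensor.
    """
    height, width = image_size
    grid_h, grid_w = height // patch_size, width // patch_size
    batch, heads = attentions.shape[:2]
    # Attention from the CLS query (row 0) to every patch key (columns 1:).
    cls_attn = attentions[:, :, 0, 1:]
    # Patches are ordered row by row, so reshape back to the patch grid.
    maps = cls_attn.reshape(batch, heads, grid_h, grid_w)
    # Nearest-neighbour upsampling back to pixel resolution for overlaying on the image.
    return F.interpolate(maps, scale_factor=patch_size, mode="nearest")
```

Each output map can then be overlaid on the input image, one panel per attention head.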

I guess a Jupyter notebook or Google Colab would be the easiest way to share the code, so you don't have to build our docs from scratch. Let me know what you think :)

Atharva-Phatak commented 2 years ago

That sounds good; I should have that up and running quickly. Any datasets you would recommend? Maybe CIFAR-10?

Also, a quick question: in the Lightning implementation of DINO, you seem to update the momentum (teacher) network in the training step, but shouldn't that be done in the on_train_batch_end hook in Lightning? Even the original implementation does it after optimizer.step(). Also, they have a training strategy where they cancel the gradients of the last layer; any plans on implementing it?

Please let me know.

guarin commented 2 years ago

Any datasets you would recommend? Maybe CIFAR-10?

CIFAR-10 looks good. We can always change the dataset if it does not work well. Imagenette would be another option, as it has larger images than CIFAR-10 but far fewer images than ImageNet.

Edit: I am actually not sure whether CIFAR-10 works with the ViT backbone, as it expects 224x224 images and 16x16 patches. So we probably have to go with Imagenette.

Edit 2: Never mind, you can set the image size to 32 when loading the backbone: torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=False, img_size=[32])
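The arithmetic behind the patch-size concern: a ViT splits each image into non-overlapping square patches, so a square image yields (image_size / patch_size)^2 tokens. The hypothetical helper below shows that a 32x32 CIFAR-10 image with 16x16 patches gives only a 2x2 grid, while 8x8 patches (as in the `dino_vits8` variant from the same hub repo) would give a 4x4 grid. Whether 4 tokens are enough to learn useful features is left open here.

```python
def patch_grid(image_size, patch_size):
    """Patches per side and total patch count for a square image."""
    per_side = image_size // patch_size
    return per_side, per_side * per_side


for image_size, patch_size in [(224, 16), (32, 16), (32, 8)]:
    per_side, total = patch_grid(image_size, patch_size)
    print(f"{image_size}x{image_size} image, {patch_size}x{patch_size} patches: "
          f"{per_side}x{per_side} = {total} patches")
```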

Also quick question, in the Lightning implementation of DINO you guys seem to update the momentum network in the training step, but shouldn't it be done in the on_train_batch_end hook in Lightning? Because even in the original implementation they do it after optimizer.step()

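This should not change anything, as the update still runs between two optimizer steps. Whether it happens right before a training step or right after one does not matter: in both cases, the teacher used in the next forward pass is computed from the same student weights.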
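A toy check of this argument, assuming a constant momentum and a teacher initialized as a copy of the student (as in DINO). A one-parameter "student" is trained while a "teacher" tracks it with an exponential moving average (EMA), once with the teacher updated at the start of each step and once right after optimizer.step(). The function `run` and its toy loss are hypothetical and only mimic the structure of the real training loop.

```python
def run(num_steps, update_teacher_first, momentum=0.9, lr=0.1):
    """Toy student/teacher loop; returns the teacher value used in each step's forward pass."""
    target = 3.0
    student = 0.0
    teacher = student  # teacher starts as a copy of the student

    def ema(teacher, student):
        return momentum * teacher + (1.0 - momentum) * student

    teachers_used = []
    for _ in range(num_steps):
        if update_teacher_first:
            teacher = ema(teacher, student)  # variant A: update at the start of the step
        teachers_used.append(teacher)
        # Toy loss (student - target)^2 + 0.5 * (student - teacher)^2, so the
        # student's gradient depends on the teacher, as it does in DINO.
        grad = 2.0 * (student - target) + (student - teacher)
        student -= lr * grad  # stands in for optimizer.step()
        if not update_teacher_first:
            teacher = ema(teacher, student)  # variant B: update after optimizer.step()
    return teachers_used


print(run(50, True) == run(50, False))
```

The two variants only differ at the edges: updating at the start of a step skips the EMA update after the very last step, and with a per-step momentum schedule the schedule index shifts by one.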

Also, they have a training strategy where they cancel the gradients of the last layer; any plans on implementing it?

Adding it makes training a bit more stable. We did not add it for simplicity, but maybe it would make sense to add it as an extra method on the DINOHead 🤔:

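from torch import nn

class DINOHead(nn.Module):
    ...  # existing layers, including self.last_layer
    def cancel_gradient_last_layer(self):
        # Setting .grad to None makes optimizer.step() skip these parameters
        for param in self.last_layer.parameters():
            param.grad = None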

and then we could simply call it in the training loop, after loss.backward() and before optimizer.step(), during the first epoch.
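A minimal, self-contained sketch of that idea on a toy head. `ToyHead`, the placeholder loss, and the `freeze_last_layer_epochs` parameter are hypothetical; this is not lightly's DINOHead. The last layer's gradients are dropped after backward() and before optimizer.step() during the first epoch, so that layer stays unchanged for the epoch while the rest of the head keeps training.

```python
import torch
from torch import nn


class ToyHead(nn.Module):
    """Hypothetical stand-in for a projection head with a separate last layer."""

    def __init__(self):
        super().__init__()
        self.mlp = nn.Linear(8, 4)
        self.last_layer = nn.Linear(4, 16)

    def forward(self, x):
        return self.last_layer(self.mlp(x))

    def cancel_gradient_last_layer(self):
        # Parameters whose .grad is None are skipped by optimizer.step()
        for param in self.last_layer.parameters():
            param.grad = None


def train(num_epochs=2, steps_per_epoch=3, freeze_last_layer_epochs=1):
    torch.manual_seed(0)
    model = ToyHead()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    initial_last = model.last_layer.weight.detach().clone()
    per_epoch_last = []  # last-layer weights at the end of each epoch
    for epoch in range(num_epochs):
        for _ in range(steps_per_epoch):
            x = torch.randn(5, 8)
            loss = model(x).pow(2).mean()  # placeholder loss, not the DINO loss
            optimizer.zero_grad()
            loss.backward()
            if epoch < freeze_last_layer_epochs:
                model.cancel_gradient_last_layer()  # after backward(), before step()
            optimizer.step()
        per_epoch_last.append(model.last_layer.weight.detach().clone())
    return initial_last, per_epoch_last
```

Note that with SGD momentum or weight decay the skipped parameters are also excluded from those updates, since the optimizer ignores them entirely for that step.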

Atharva-Phatak commented 2 years ago

@guarin Maybe I can create a small PR ❤️ to add cancel_gradient_last_layer to the DINOHead? Then I will create an example that uses everything?

guarin commented 2 years ago

Yes that would be great!

Atharva-Phatak commented 2 years ago

Awesome I will create a PR :)

philippmwirth commented 2 years ago

I think we can close this one 🙂