RobvanGastel / dinov2-finetune

Testing adaptation of the DINOv2 encoder for vision tasks with Low-Rank Adaptation (LoRA)
MIT License

Learning Method #3

Closed by Jmipar-k 2 months ago

Jmipar-k commented 2 months ago

Hello!

Thank you for the great work.

I am trying to adapt this to a multi-class classification task.

Was the segmentation finetuning done in a supervised manner?

If not, would it be possible to load the self-supervised ViT weights (the backbone weights from the DINOv2 repo), freeze them, and fine-tune LoRA and a classifier head (both trainable) in a supervised manner, with images and class labels?

I am trying to fine tune on my custom dataset.

Thank you.

RobvanGastel commented 2 months ago

Hi! The segmentation fine-tuning was done in a supervised manner. While fine-tuning, I kept the encoder ViT weights unchanged and only updated the decoder head and the LoRA weights.

For multiclass classification you could write a custom dataloader and change the 1x1 convolution decoder to a linear layer with shape (Embedding dim, # classes). I hope I answered your question.

Jmipar-k commented 2 months ago

Were there no problems even though you started from self-supervised pretrained weights (for the ViT backbone) and fine-tuned in a supervised manner?

Thank you for the fast reply!

Jmipar-k commented 2 months ago

Oh, and if I set the --use_fpn option to false, will the classifier head (decoder) be used?

RobvanGastel commented 2 months ago

I had no problems starting from these weights; it made fine-tuning faster. For the second question, I assume you want to do classification and output a single value indicating the class in the image, not segmentation, i.e. classifying which pixels belong to which class. I think if you change the LinearClassifier to be something like this,

import torch
import torch.nn as nn

class MulticlassClassifier(nn.Module):
    def __init__(
        self,
        channels: int,
        patch_h: int = 35,
        patch_w: int = 35,
        n_classes: int = 1000,
    ):
        """The classifier decoder

        Args:
            channels (int): Number of input channels
            patch_h (int, optional): The height patch size, essentially
                image height // patch size of the encoder. Defaults to 35.
            patch_w (int, optional): The width patch size, essentially
                image width // patch size of the encoder. Defaults to 35.
            n_classes (int, optional): Number of output classes. Defaults to 1000.
        """
        super().__init__()
        self.width = patch_w
        self.height = patch_h
        self.channels = channels
        self.classifier = nn.Linear(channels * patch_h * patch_w, n_classes)

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        embeddings = embeddings.reshape(-1, self.height * self.width * self.channels)
        return self.classifier(embeddings)

It should work for classification tasks.

Something else you could try is, instead of looking at all the patches, to use only the class token of the ViT for classification. In this line: https://github.com/RobvanGastel/dinov2-finetune/blob/e1815206165520b359e502a69fafa9c864303c41/dino_finetune/model/dino_v2.py#L125

you would take the key x_norm_clstoken from the dictionary instead. The dimensions would change, so it would require some more changes to make them match.
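As a rough illustration of the class-token idea, here is a minimal sketch. The class name ClsTokenClassifier is hypothetical and not part of the repository, the feature dictionary is a random stand-in for the output of encoder.forward_features(x), and 384 is only an example embedding size (that of ViT-S/14). It assumes the class token has shape (batch, channels).

```python
import torch
import torch.nn as nn


class ClsTokenClassifier(nn.Module):
    """Hypothetical head that classifies from the class token only."""

    def __init__(self, channels: int, n_classes: int = 1000):
        super().__init__()
        # One embedding per image, so the input size is just `channels`
        self.classifier = nn.Linear(channels, n_classes)

    def forward(self, cls_token: torch.Tensor) -> torch.Tensor:
        # cls_token: (batch, channels) -> logits: (batch, n_classes)
        return self.classifier(cls_token)


# Random stand-in for the dictionary returned by the encoder (assumption)
feature = {"x_norm_clstoken": torch.randn(2, 384)}

head = ClsTokenClassifier(channels=384, n_classes=5)
logits = head(feature["x_norm_clstoken"])
print(logits.shape)
```

Compared with flattening all patch tokens, the linear layer here no longer depends on patch_h and patch_w, so the head stays the same size if the input resolution changes.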

Jmipar-k commented 2 months ago

Thank you for your help.

I have successfully started training.

I have several additional questions.

def forward(self, x: torch.Tensor) -> torch.Tensor:

    # If the FPN decoder is used, we take the n last layers for
    # our decoder to get a better segmentation result.
    if self.use_fpn:
        # Potentially even better to take different depths
        feature = self.encoder.get_intermediate_layers(
            x, n=self.inter_layers, reshape=True
        )
        logits = self.decoder(feature)

    else:
        feature = self.encoder.forward_features(x)
        # get the patch embeddings - so we exclude the CLS token
        patch_embeddings = feature["x_norm_patchtokens"]
        logits = self.decoder(patch_embeddings)

    # logits = F.interpolate(
    #     logits,
    #     size=x.shape[2:],
    #     mode="bilinear",
    #     align_corners=False,
    # )
    return logits
  1. From this partial code (the forward method of DINOV2EncoderLoRA), I do not understand why the interpolation is needed, so I removed that part. Would that be okay?

  2. Can you give me advice on training recipes? I would like to know which hyperparameters you used and, roughly, why you chose them.

Thank you so much again for the detailed replies!!

RobvanGastel commented 2 months ago

No problem, I am happy to help 😄.

  1. Yes, this was only to upscale the output dimensions for segmentation, e.g. the decoder output of shape (classes, H/8, W/8) to full resolution (classes, H, W).

  2. I did not do much hyperparameter tuning for training: I take the maximum batch size my GPU can handle and then tune the learning rate. Another parameter you could tune is the LoRA rank r; keeping it low, around 3-6, helped training converge faster for me. A higher r made fine-tuning take longer, and I did not run it long enough to see whether performance improved over lower values of r.
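To make point 1 concrete, the sketch below shows what the interpolation step does for segmentation. The tensor sizes are made up for illustration: a coarse grid of per-class logits is upscaled to a hypothetical full image resolution so that every pixel gets a class. For image-level classification there is no spatial grid to upscale, which is why the step can be dropped.

```python
import torch
import torch.nn.functional as F

# Dummy decoder output: batch of 1, 3 classes, on a coarse 8x8 grid
logits = torch.randn(1, 3, 8, 8)

# Upscale to a hypothetical full image resolution of 64x64
full_res = F.interpolate(
    logits, size=(64, 64), mode="bilinear", align_corners=False
)

# Per-pixel class prediction, as needed for segmentation
mask = full_res.argmax(dim=1)

print(full_res.shape, mask.shape)
```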

Maybe this repository is helpful to you in tuning your model, https://github.com/google-research/tuning_playbook.
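On the LoRA rank r from point 2, a quick back-of-the-envelope calculation shows how the number of trainable parameters grows with r. This is a sketch under the standard LoRA formulation, where one adapted layer of size d_out x d_in gets a matrix A of shape (r, d_in) and a matrix B of shape (d_out, r). The helper name lora_param_count and the layer size 384 are assumptions for illustration only.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters of one LoRA pair: A is (r, d_in), B is (d_out, r)."""
    return r * d_in + d_out * r


d = 384  # example layer width (assumption)
full = d * d  # parameters of the frozen weight matrix itself

for r in (3, 6, 32):
    n = lora_param_count(d, d, r)
    print(f"r={r}: {n} trainable parameters ({n / full:.1%} of the full matrix)")
```

The count grows linearly with r, so a low rank keeps the number of trainable weights per adapted layer very small compared with the frozen matrix.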

Jmipar-k commented 2 months ago

Thank you for the help, I am glad that you are happy to help.

I've finished fine-tuning on my initial dataset (for test experiments) and the results are pretty good!

I could not find inference (evaluation) code for my test dataset, so I am currently performing inference by building the original ViT encoder with LoRA parameters, then loading the pretrained weights for the ViT and for LoRA separately.

This requires loading both sets of weights onto my GPU for inference, and also in my future training framework.

I am aware that this is not efficient in terms of memory or latency.

I think I have seen that the two sets of parameters can be combined into one model (by matrix multiplication, I assume).

I am curious whether you have run into the same concerns, or found a solution.

RobvanGastel commented 2 months ago

This is an interesting idea. I have not tried combining the LoRA parameters with the ViT encoder layers, but I think you're right that in some scenarios, something like fusing might be possible. If my understanding is correct, Hugging Face diffusers does something like this here:

https://github.com/huggingface/diffusers/blob/bbcf2a8589f93acd401bd9e6367add6412eabc04/src/diffusers/models/lora.py#L119

Let me know if you managed to get it working!

Jmipar-k commented 2 months ago

Ok

I will update you if I get further knowledge!

Big thanks for everything!!

Oguzhanercan commented 1 month ago

@Jmipar-k were you able to fuse the LoRA weights with the ViT layers? If so, could you share the preparation and inference code for that?

Jmipar-k commented 1 month ago

Hi @Oguzhanercan, I actually haven't fused the weights yet, since latency is not really critical in my environment.

I have recently read the LoRA paper and actually merging the weights is possible.

In theory, it can be done with this equation:

W = W_0 + BA, where W_0 is the pre-trained (initial) weight matrix and B and A are the low-rank LoRA matrices.

I am not sure how best to do this in Python code, but the code at the link @RobvanGastel provided above does it like this:

def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
    if self.lora_linear_layer is None:
        return

    dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device

    w_orig = self.regular_linear_layer.weight.data.float()
    w_up = self.lora_linear_layer.up.weight.data.float()
    w_down = self.lora_linear_layer.down.weight.data.float()

    if self.lora_linear_layer.network_alpha is not None:
        w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank

    fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])

This is the snippet of the weight fusing function based on the equation I wrote above.
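For a single linear layer, the merge can also be checked numerically. The following is a minimal, self-contained sketch and not code from this repository or from diffusers: the layer sizes and the names base, merged, A and B are made up for the example, and the alpha/rank scaling that the snippet above applies is left out for simplicity (an implementation that scales the LoRA branch would need to scale B @ A the same way before adding it).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, d_out, r = 16, 16, 4

base = nn.Linear(d_in, d_out)        # holds W_0 (and a bias)
A = torch.randn(r, d_in) * 0.01      # LoRA "down" matrix, shape (r, d_in)
B = torch.randn(d_out, r) * 0.01     # LoRA "up" matrix, shape (d_out, r)

x = torch.randn(2, d_in)

with torch.no_grad():
    # Unmerged forward pass: frozen layer output plus the low-rank update
    y_unmerged = base(x) + x @ A.T @ B.T

    # Merged layer: W = W_0 + BA, bias copied unchanged
    merged = nn.Linear(d_in, d_out)
    merged.weight.copy_(base.weight + B @ A)
    merged.bias.copy_(base.bias)
    y_merged = merged(x)

# Both paths should agree up to floating-point error
print(torch.allclose(y_unmerged, y_merged, atol=1e-5))
```

After merging, only the single merged layer is needed at inference time, so the separate LoRA matrices add no extra latency or memory.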

I hope this helps you.