joanaapa / Foundation-Medical

MIT License

Linear classifier fine-tuning #2

aleemsidra closed this issue 1 month ago

aleemsidra commented 1 month ago

Thank you for releasing the code. I am trying to navigate through the code to understand how embeddings are passed from the foundation model (specifically DinoV2) to the linear layer. I have the following questions about that:

If my understanding is correct, a linear layer is first added to DinoV2 in DINOWRAPPER by calling the Classifier:

https://github.com/joanaapa/Foundation-Medical/blob/a7aa194f0643583f29fc9d32f211e2815400709b/defaults/wrappers.py#L461

After that, Ymodel_DINO is called, where the embeddings are first passed to a linear projector and then to the linear layer (termed the classifier in the paper) for classification, as follows:

https://github.com/joanaapa/Foundation-Medical/blob/a7aa194f0643583f29fc9d32f211e2815400709b/defaults/models.py#L469 https://github.com/joanaapa/Foundation-Medical/blob/a7aa194f0643583f29fc9d32f211e2815400709b/defaults/models.py#L470

Are these linear_proj outputs then passed to the final linear layer that was added in DINOWRAPPER for classification? Could you please pinpoint where you fine-tune the linear layer for the final classification?

Moreover, why were the embeddings not extracted from all tokens, using:

found_embeds = self.foundation_model(x).mean(dim=1)

Why was the cls token removed instead, as in:

found_embeds = self.foundation_model(x)[:,1:,:]

Questions about the paper: "In all cases, integrating a classifier model on top of the foundation model involves extracting patch representations from the foundation models, adding a projection layer, bypassing the patch embedding layer of the stacked classifier, and transmitting the patch representations to the classifier for the fine-tuning process."

Here, "classifier model" means the linear layer for the final classification: https://github.com/joanaapa/Foundation-Medical/blob/a7aa194f0643583f29fc9d32f211e2815400709b/defaults/models.py#L57

The projection layer added before it refers to this one:

linear_proj = self.linear_projector(found_embeds)

What does "bypassing the patch embedding layer of the stacked classifier" mean?

joanaapa commented 1 month ago

Hello,

I understand your confusion. The DinoWrapper is only called when the Classifier is another full model (i.e. DeiT). If no classification model is defined and the setup is only the foundation model plus a linear layer, the FoundationWrapper is used (see https://github.com/joanaapa/Foundation-Medical/blob/a7aa194f0643583f29fc9d32f211e2815400709b/classification.py#L121-L124). In this case, as you pointed out, there is no need for the projection layer.

The cls token is removed when using the foundation model together with another full classifier. For the final classification, the cls token of the stacked model is used. The rationale behind this is that we consider the foundation model to be a feature extractor.
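A minimal sketch of the shape difference behind the question and this answer, assuming `foundation_model(x)` returns the full token sequence with the cls token at index 0 (as the `[:,1:,:]` slicing in the thread implies). The sizes are hypothetical and a random tensor stands in for real features.

```python
import torch

# Hypothetical sizes: 2 images, 1 cls token + 256 patch tokens, 768-dim features.
B, N, D = 2, 256, 768
tokens = torch.randn(B, 1 + N, D)  # stand-in for foundation_model(x); index 0 is the cls token

# Alternative raised in the question: average all tokens (cls included) per image.
pooled = tokens.mean(dim=1)        # shape (B, D): the token sequence is collapsed

# Stacked setting: drop the foundation cls token, keep the patch sequence as features.
patch_feats = tokens[:, 1:, :]     # shape (B, N, D): one feature vector per patch

print(pooled.shape, patch_feats.shape)
```

Mean pooling leaves one vector per image, whereas slicing keeps one feature per patch, which is the kind of input a stacked transformer classifier with its own cls token can consume.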

Regarding this:

bypassing the patch embedding layer of the stacked classifier

Again, we are referring to the setting where a DeiT model is used as the classifier. Since we use the foundation model as a feature extractor, we need to remove the patch embedding layer of the DeiT. This does not apply to the scenarios where only the linear layer is used.
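The stacked setting described here can be sketched as follows. This is a hypothetical, simplified module (`StackedTokenClassifier` is not a class from the repository), with `torch.nn.TransformerEncoder` standing in for the DeiT blocks: the classifier has no patch embedding layer, the foundation model's patch tokens are projected to the classifier width, a fresh cls token and positional embedding are added, and the classifier's own cls token produces the logits. All widths, depths and patch counts are placeholder values.

```python
import torch
import torch.nn as nn


class StackedTokenClassifier(nn.Module):
    """Hypothetical DeiT-like classifier that consumes precomputed patch tokens.

    There is no patch-embedding layer: foundation-model patch tokens are
    projected and fed straight into the transformer blocks.
    """

    def __init__(self, found_dim, embed_dim, num_patches, num_classes, depth=2, heads=4):
        super().__init__()
        self.linear_projector = nn.Linear(found_dim, embed_dim)  # projection layer
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        layer = nn.TransformerEncoderLayer(embed_dim, heads, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, num_layers=depth)
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, patch_tokens):
        # patch_tokens: (B, N, found_dim), foundation cls token already removed
        x = self.linear_projector(patch_tokens)           # (B, N, embed_dim)
        cls = self.cls_token.expand(x.shape[0], -1, -1)   # (B, 1, embed_dim)
        x = torch.cat([cls, x], dim=1) + self.pos_embed   # (B, N + 1, embed_dim)
        x = self.norm(self.blocks(x))
        return self.head(x[:, 0])                         # classify from the stacked model's cls token


# Usage with a random stand-in for the foundation model output (1 cls + 16 patches).
foundation_tokens = torch.randn(2, 1 + 16, 32)
model = StackedTokenClassifier(found_dim=32, embed_dim=64, num_patches=16, num_classes=5)
logits = model(foundation_tokens[:, 1:, :])
print(logits.shape)
```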

Hope that was clear,
Joana

aleemsidra commented 1 month ago

Thank you for the clarification. So for linear-layer classification, embeddings are first extracted from the model and then passed to the linear layer for classification, as done here:

https://github.com/joanaapa/Foundation-Medical/blob/a7aa194f0643583f29fc9d32f211e2815400709b/defaults/wrappers.py#L570C17-L570C27 https://github.com/joanaapa/Foundation-Medical/blob/a7aa194f0643583f29fc9d32f211e2815400709b/defaults/models.py#L51

Are both the patch embeddings and the cls_token used in the linear-layer classification?

joanaapa commented 1 month ago

For the classification, only the cls_token is used.
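A minimal sketch of the linear-layer setting, assuming the foundation model returns the token sequence with the cls token at index 0. `LinearHeadOnFoundation` and `ToyBackbone` are hypothetical names; the toy backbone only exists to make the example runnable. Only the cls token reaches the linear layer, and no projection layer is involved.

```python
import torch
import torch.nn as nn


class LinearHeadOnFoundation(nn.Module):
    """Hypothetical foundation model + linear layer that classifies from the cls token only."""

    def __init__(self, foundation_model, embed_dim, num_classes):
        super().__init__()
        self.foundation_model = foundation_model
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        tokens = self.foundation_model(x)  # assumed shape (B, 1 + N, embed_dim)
        cls_token = tokens[:, 0]           # (B, embed_dim); patch tokens are ignored
        return self.head(cls_token)


class ToyBackbone(nn.Module):
    """Hypothetical stand-in: cls token is all ones, patch tokens are random."""

    def forward(self, x):
        b = x.shape[0]
        cls = torch.ones(b, 1, 8)
        patches = torch.randn(b, 4, 8)
        return torch.cat([cls, patches], dim=1)


model = LinearHeadOnFoundation(ToyBackbone(), embed_dim=8, num_classes=3)
logits = model(torch.randn(2, 3, 32, 32))
print(logits.shape)
```

Because the random patch tokens never reach `head`, the logits depend only on the cls token.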

aleemsidra commented 1 month ago

Thanks for your response. One last question: are the results reported in the paper obtained with the large or the base version of the foundation models?

joanaapa commented 1 month ago

It is the base version.