facebookresearch / dinov2

PyTorch code and models for the DINOv2 self-supervised learning method.
Apache License 2.0

Can overfitting lead to high-norm patches? #419

Open amundra15 opened 1 month ago

amundra15 commented 1 month ago

I want to fine-tune vitb14 on domain-specific data and, as a proof of concept, I am starting with a fairly small dataset. The resulting patch features show high-norm artefacts similar to the ones discussed in "Vision Transformers Need Registers".

What confuses me is that the paper reports such artefacts only for larger, more representative models, not for vitb14. This makes me wonder whether the artefacts I am seeing with vitb14 are a sign of overfitting.

Any thoughts on this?

heyoeyo commented 1 month ago

There do already seem to be high-norm artifacts in vit-b (more info in issue #373), though they present a bit differently than in the larger models. Specifically for vit-b, there's always (weirdly) a bunch of high-norm tokens in the top-left patches.

I'm not sure about finetuning on small datasets, but the vit-b model was also used within Depth-Anything, which would've involved training on a large dataset for a different task, and it still shows similar artifacts. I'd guess that the artifacts you're seeing may just be the original ones, especially if they're concentrated in the top-left patches, and not directly related to finetuning.

amundra15 commented 1 month ago

Thanks for your response, @heyoeyo.

In my case, the artefacts appear along the left and top edges of the image (and not just the top-left corner). What is also interesting is that I am getting low norm values for the artefacts, but high values for the first principal component.

Input RGB: input(1)

Norm of last layer patch tokens: our_norm_7000iter

PCA(n=1) of last layer patch tokens: our_fgbg

The values are overlaid in red. (Ignore the mismatch in image orientation.)

I am not sure how to explain the low norm values for the artefacts.
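
For readers who want to reproduce the two maps above, here is a minimal sketch of how a per-token norm map and a PCA(n=1) map could be computed from an already extracted `(N, D)` array of patch tokens. The function name `norm_and_pc1_maps` and the toy data are assumptions made for illustration, not code from this thread. The toy data also shows one generic way the two maps can disagree: PCA centres the tokens first, so a token close to zero can still sit far from the mean token. This is not a diagnosis of the results above.

```python
import numpy as np

def norm_and_pc1_maps(tokens, grid_hw):
    """Return (norm map, first-principal-component map) for one image.

    tokens:  (N, D) array of patch tokens (hypothetical input, already extracted)
    grid_hw: (H, W) patch grid with H * W == N
    """
    h, w = grid_hw
    assert tokens.shape[0] == h * w, "token count must match the patch grid"
    # L2 norm of every patch token
    norms = np.linalg.norm(tokens, axis=1)
    # PCA with one component: centre, then project onto the top right-singular vector
    centred = tokens - tokens.mean(axis=0, keepdims=True)
    _, _, vt = np.linalg.svd(centred, full_matrices=False)
    pc1 = centred @ vt[0]
    return norms.reshape(h, w), pc1.reshape(h, w)

# Toy data: every token shares a large offset, except token 0, which is near zero.
rng = np.random.default_rng(0)
h, w, d = 4, 4, 16
noise = rng.normal(scale=0.5, size=(h * w, d))
tokens = 3.0 + noise
tokens[0] = 0.1 * noise[0]  # hypothetical low-norm outlier

norm_map, pc1_map = norm_and_pc1_maps(tokens, (h, w))
print("lowest-norm patch index:", int(norm_map.argmin()))
print("largest |PC1| patch index:", int(np.abs(pc1_map).argmax()))
```

The sign of the first principal component is arbitrary, so the absolute value is the quantity worth comparing between runs.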

heyoeyo commented 1 month ago

That seems very surprising! It's probably worth double checking if the original (not fine tuned) vit-b model produces similar artifacts (if you haven't already).

Another thing worth checking is whether the artifacts appear on earlier layers, and what that pattern looks like. In all cases I've seen, they aren't present on the earlier layers, but tend to appear and stay consistent on later layers, with the final layer being somewhat different from all others. Not that this will explain your results specifically, but if you see a similar pattern it may be more of a hint that it's the same phenomenon at least, even though you're getting low-norm tokens.
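
A rough way to run the layer-by-layer check described above is to reduce each layer to a single number, such as the ratio of the largest token norm to the median token norm. The sketch below assumes the per-layer patch tokens have already been collected into a list of `(N, D)` arrays. In the DINOv2 repository this could probably be done with the model's `get_intermediate_layers` method, but check its signature before relying on it. The helper name `outlier_ratio_per_layer` and the synthetic layers are hypothetical.

```python
import numpy as np

def outlier_ratio_per_layer(layer_tokens):
    """For each layer, return max token norm divided by median token norm.

    layer_tokens: list of (N, D) arrays, one per layer (hypothetical input).
    A ratio near 1 means no token stands out; a large ratio flags outliers.
    """
    ratios = []
    for tokens in layer_tokens:
        norms = np.linalg.norm(tokens, axis=1)
        ratios.append(float(norms.max() / np.median(norms)))
    return ratios

# Synthetic stand-in for real activations: an outlier token that only
# shows up in the last two layers.
rng = np.random.default_rng(0)
n_tokens, dim, n_layers = 64, 32, 6
layers = []
for i in range(n_layers):
    t = rng.normal(size=(n_tokens, dim))
    if i >= 4:
        t[0] *= 8.0  # hypothetical high-norm token
    layers.append(t)

ratios = outlier_ratio_per_layer(layers)
for i, r in enumerate(ratios):
    print(f"layer {i}: max/median norm = {r:.2f}")
```

Plotting the full norm map per layer gives more detail, but a single ratio per layer is often enough to see at which depth the outliers first appear.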

amundra15 commented 1 month ago

The original model also produces similar low-norm artefacts (though less prominently).

Norm of last layer patch tokens from official vitb14: dino_norm

PCA(n=1) of last layer patch tokens from official vitb14: dino_fgbg(1)

It's interesting to note that the original model shows regions of high as well as low norms. The fine-tuning is exacerbating the low-norm problem already present in the top-left region. Is this phenomenon studied and documented somewhere?

I will also add the visualizations from the other layers once I have them.

heyoeyo commented 1 month ago

Those results from the original model look a bit more similar to what I've seen, though it's strange that it seems inverted and that there are other tokens (not just the top-left) that seem out-of-place. It actually resembles the result from the larger models... Are these norm results showing the patch tokens as-is, or do they also include the final layer norm step? If the layer norm is included, I wonder if that might explain the inversion of low/high norms at least?
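
To make the layer norm question concrete, here is a small sketch of how a final layer norm could, in principle, turn a high-norm token into a low-norm one. Layer norm standardizes each token independently, so the token's original magnitude is discarded; the output norm then depends on how the learned scale (`gamma`) weights the channels that dominate the token. The assumption that `gamma` is small on the outlier channels is hypothetical and has not been checked against the actual DINOv2 weights.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-6):
    """Standard layer norm over the last axis, applied to each token independently."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta

rng = np.random.default_rng(0)
n, d = 8, 64
tokens = rng.normal(size=(n, d))
tokens[0, :2] += 100.0  # hypothetical artefact token: huge values in two channels

# Hypothetical learned parameters: small scale on the two outlier channels.
gamma = np.ones(d)
gamma[:2] = 0.1
beta = np.zeros(d)

pre_norms = np.linalg.norm(tokens, axis=1)
post_norms = np.linalg.norm(layer_norm(tokens, gamma, beta), axis=1)
print("pre-LN norms: ", pre_norms.round(1))
print("post-LN norms:", post_norms.round(1))
```

In this toy setup token 0 has the largest norm before the layer norm and the smallest after it, which is the kind of inversion being asked about.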

As for visualizing the other layers, there's some code here which can at least give a qualitative result.

amundra15 commented 3 weeks ago

@heyoeyo you are right regarding the final layer norm. Upon commenting it out, I get high norms as expected:

Norm of last layer patch tokens from official vitb14: ori_vitb14_patchnorm

Norm of last layer patch tokens from our fine-tuned vitb14: ours_vitb14_featurenorm

The values are now similar to the ones discussed in the registers paper. However, I still observe artefacts along the entirety of the top and left edges.

I have a couple of queries regarding the impact on performance:

  1. Does this artefact somehow affect the cls token performance as well?
  2. What happens if I mask/threshold the patch token outputs for the artefacts? Can this ensure better downstream performance compared to using it as is?

heyoeyo commented 3 weeks ago

Does this artefact somehow affect the cls token performance as well?

I'd guess this depends a lot on how you're using the cls token. If you've trained another model to use the cls token (and included the vitb model in this training), then I'd imagine it's ok. The cls token has the chance to 'attend' to these weird high norm tokens throughout the model, so even if they include global info (as the registers paper suggests), end-to-end training involving the cls token should be able to account for this (to some extent) I think. On the other hand, if you're attaching a separate model/other classifier onto the vitb cls token without further training of the vitb model, it may perform more poorly since there is nothing guiding the model to place the most relevant info into the cls token specifically.

What happens if I mask/threshold the patch token outputs for the artefacts? Can this ensure better downstream performance compared to using it as is?

I think this again depends on how the model is used. If the downstream processing is trained in conjunction with the vit output, then it's likely to outperform any hand-picked mask/thresholding settings (given how hard it would be to do this with these odd patterns). The Depth-Anything models have a substantial amount of post-processing after the vit encoding and seem to perform fine at least. Of course, there's always a chance that it could be even better without these weird patterns, but I think the only way to know is to try different models (especially the ones with registers in this case).
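
For anyone who wants to try the hand-picked thresholding option anyway, a minimal sketch could look like the following. It flags tokens whose pre-layer-norm L2 norm is far above the median, using a median absolute deviation (MAD) rule, and then pools only the remaining tokens. The function name `mask_high_norm_tokens`, the cutoff `k=5.0`, and the toy data are all assumptions, and nothing in this thread shows that masking improves downstream results.

```python
import numpy as np

def mask_high_norm_tokens(tokens, k=5.0):
    """Flag tokens whose L2 norm exceeds median + k * (scaled MAD).

    tokens: (N, D) array of patch tokens (hypothetical input).
    Returns (keep, threshold); keep is a boolean array, True for normal tokens.
    """
    norms = np.linalg.norm(tokens, axis=1)
    med = np.median(norms)
    # 1.4826 rescales the MAD to match the standard deviation for Gaussian data
    mad = 1.4826 * np.median(np.abs(norms - med))
    threshold = med + k * mad
    return norms <= threshold, float(threshold)

# Toy data: 49 patch tokens, three of them scaled up to mimic artefacts.
rng = np.random.default_rng(0)
tokens = rng.normal(size=(49, 32))
outlier_idx = [0, 1, 7]  # hypothetical artefact positions (top row, left column)
tokens[outlier_idx] *= 10.0

keep, threshold = mask_high_norm_tokens(tokens)
pooled = tokens[keep].mean(axis=0)  # average only the tokens that were kept
print("kept", int(keep.sum()), "of", keep.size, "tokens; threshold =", round(threshold, 2))
```

A fixed rule like this has no way of adapting to the task, which is the reason given above for preferring a downstream head trained together with the backbone.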

amundra15 commented 5 days ago

A short update: the issue is also related to training instability (somewhat expected). After modifying our (custom) loss function, training is more stable. With that change, the high-norm patches no longer appear in the feature space and downstream performance is better.