facebookresearch / dinov2

PyTorch code and models for the DINOv2 self-supervised learning method.
Apache License 2.0

Cannot identify high-norm tokens #373

Open riccardorenzulli opened 7 months ago

riccardorenzulli commented 7 months ago

Hello,

I'm having trouble identifying high-norm tokens, as described in the "Vision Transformers Need Registers" paper. I've seen that this is also mentioned in https://github.com/facebookresearch/dinov2/issues/293.

I used the L2 norm and the code from https://github.com/facebookresearch/dinov2/pull/306. To get the embedding vectors of the last layer, I use
x_layers = model.get_intermediate_layers(img, [len(model.blocks)-1]).
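
For reference, here's roughly how I then compute the per-token L2 norms from that output (just a sketch; model and img stand for the loaded backbone and a preprocessed input batch):

```python
import torch

# Sketch of the norm computation; model/img are assumed to be the loaded
# backbone and a preprocessed, batched image tensor
with torch.no_grad():
    x_layers = model.get_intermediate_layers(img, [len(model.blocks) - 1])  # norm defaults to True
patch_tokens = x_layers[0]         # (batch, num_patches, dim); cls/register tokens already stripped
norms = patch_tokens.norm(dim=-1)  # per-token L2 norm
print(norms.min().item(), norms.max().item())
```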

I tried with ViT-G/14 on the full ImageNet validation set, with and without registers; however, as you can see in the images below, the norms for the model without registers do not exceed the ~150 threshold mentioned in the paper.

(images: patch-token norm results, with and without registers)

Did anyone succeed in reproducing the results of the main paper and identifying these high-norm tokens?

AndreaBrg commented 7 months ago

I'm having the same issue; @patricklabatut, any idea?

heyoeyo commented 7 months ago

For what it's worth, I've seen this 'high norm' pattern occur with the dinov2-based image encoder used in the depth-anything model. It happens on the vit-l, vit-b and to some extent even on the vit-s model. A similar pattern appears using the 'ViT-L/14 distilled' backbone (from the dinov2 listing), but it's only visible on internal blocks.

Here are the norms of the different output blocks for vit-l (the depth-anything version) running on a picture of a turtle:

vit-l block norms
![blocknorms_depth_anything_vitl14_turtle_504](https://github.com/facebookresearch/dinov2/assets/32405350/ca86327c-2741-4bd9-a061-4af0aa7a418c)

Here are the reported min/max norms for the last few blocks:

| Block | Min | Max |
| ----- | ----- | ------ |
| 14 | 4.68 | 6.44 |
| 15 | 5.87 | 9.12 |
| 16 | 6.98 | 11.04 |
| 17 | 9.17 | 175.08 |
| 18 | 12.43 | 320.31 |
| 19 | 16.54 | 509.3 |
| 20 | 22.94 | 517.46 |
| 21 | 33.25 | 532.22 |
| 22 | 50.88 | 569.05 |
| 23 | 99.46 | 389.04 |

Here are some more examples:

vit-l block norms at half resolution ![blocknorms_depth_anything_vitl14_turtle_252](https://github.com/facebookresearch/dinov2/assets/32405350/12b8a517-1c13-42b0-b1dd-7b6075ad1165)
vit-b block norms ![blocknorms_depth_anything_vitb14_turtle_504](https://github.com/facebookresearch/dinov2/assets/32405350/3d56ca8a-be2e-4387-96d3-4a4a3a633a30)
vit-s block norms ![blocknorms_depth_anything_vits14_turtle_504](https://github.com/facebookresearch/dinov2/assets/32405350/852c2f41-c8ff-4794-bac8-3711f2bcbf72)
Original ViT-L/14 distilled block norms
![blocknorms_dinov2_vitl_orig_turtle_504](https://github.com/facebookresearch/dinov2/assets/32405350/e2c71a41-fce5-4dd1-99db-4bd6a5703941)

And the last few block min/max norms for comparison:

| Block | Min | Max |
| ----- | ----- | ------ |
| 14 | 4.59 | 7.11 |
| 15 | 5.83 | 9.58 |
| 16 | 7.13 | 10.4 |
| 17 | 9.51 | 188.77 |
| 18 | 12.61 | 342.56 |
| 19 | 16.85 | 538.95 |
| 20 | 23.87 | 549.13 |
| 21 | 30.23 | 563.75 |
| 22 | 39.58 | 618.84 |
| 23 | 66.03 | 121.01 |
beit-large-512 block norms ![blocknorms_dpt_beit_large_512_turtle_512](https://github.com/facebookresearch/dinov2/assets/32405350/2ba7cb30-7157-4972-8de7-cb7d5acee4e2)
input image (downscaled for display) ![turtle](https://github.com/facebookresearch/dinov2/assets/32405350/31eb9ae6-016d-497f-9131-ac91a9a03a91)

Some notes:

- Obviously it's not a conclusive result (I've only tried this on a few images), but it does seem similar to the effect described in the 'register' paper.

heyoeyo commented 7 months ago

As a quick follow-up, I've tried this with the original dinov2 model & weights and got the same results. The original weights always have smaller norms on their final output (compared to the depth-anything weights), but vit-b & vit-l both show high norms internally. Results from vit-g have high norms even on the final output.

Here is an animation of the vit-g block norms (first-to-last), showing qualitatively similar results to the paper:
(animation: vitg_blocknorm_anim)

The 'with registers' versions of the models don't completely get rid of high norms in the later layers, but they do get rid of outliers.

For anyone wanting to try this, here's some code that uses the dinov2 repo/models and prints out the min & max norms for each block. Just make sure to set an image path and model name at the top of the script (use any of the pretrained backbone names from the repo listing):

Code for printing block norms

```python
import cv2
import torch
import numpy as np

from dinov2.layers.block import Block

# Setup
model_name = "dinov2_vitl14"  # compare with "dinov2_vitl14_reg"
image_path = "path/to/image.jpg"
device, dtype = "cuda", torch.float32
img_size_wh = (518, 518)
img_mean, img_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# Load & prepare image
orig_img_bgr = cv2.imread(image_path)
img_rgb = cv2.cvtColor(orig_img_bgr, cv2.COLOR_BGR2RGB)
img_rgb = cv2.resize(img_rgb, dsize=img_size_wh)
img_rgb = (np.float32(img_rgb / 255.0) - img_mean) / img_std
img_rgb = np.transpose(img_rgb, (2, 0, 1))
img_tensor = torch.from_numpy(img_rgb).unsqueeze(0).to(device, dtype)

# Load model
model = torch.hub.load("facebookresearch/dinov2", model=model_name)
model.to(device, dtype)
model.eval()

# Capture transformer block outputs
captures = []
hook_func = lambda m, inp, out: captures.append(out)
for m in model.modules():
    if isinstance(m, Block):
        m.register_forward_hook(hook_func)
with torch.no_grad():
    model(img_tensor)

# Figure out how many global tokens we'll need to remove
# (assuming we only get norms of image-patch tokens)
has_cls_token = model.cls_token is not None
num_global_tokens = model.num_register_tokens + int(has_cls_token)

# Print out norm info
print(f"Block norms (min & max) for {model_name}")
for idx, output in enumerate(captures):
    patch_tokens = output[:, num_global_tokens:, :]  # Remove cls & reg tokens
    norms = patch_tokens.norm(dim=2).cpu().float().numpy()
    min_str = str(round(norms.min())).rjust(3)
    max_str = str(round(norms.max())).rjust(3)
    print(f"B{idx}:".rjust(4), f"[{min_str}, {max_str}]")
```

And here's some code that can be added to the end of the code above for generating the visualizations (it pops up a window, so you need to be running the code locally).

Code for visualizations

```python
# Figure out patch sizing, for converting back to image-like shape
input_hw = img_tensor.shape[2:]
patch_size_hw = model.patch_embed.patch_size
patch_grid_hw = [x // p for x, p in zip(input_hw, patch_size_hw)]

# For displaying as an image
for idx, output in enumerate(captures):

    # Get tokens into image-like shape
    patch_tokens = output[:, num_global_tokens:, :]
    imglike_tokens = torch.transpose(patch_tokens, 1, 2)
    imglike_tokens = torch.unflatten(imglike_tokens, 2, patch_grid_hw).squeeze().float()
    imglike_norms = imglike_tokens.norm(dim=0)

    # Make image easier to view
    min_norm, max_norm = imglike_norms.min(), imglike_norms.max()
    norm_disp = ((imglike_norms - min_norm) / (max_norm - min_norm)).cpu().float().numpy()
    norm_disp = cv2.resize(norm_disp, dsize=None, fx=8, fy=8, interpolation=cv2.INTER_NEAREST_EXACT)
    cmap_disp = cv2.applyColorMap(np.uint8(255 * norm_disp), cv2.COLORMAP_VIRIDIS)
    cv2.imshow("Block norms", cmap_disp)
    cv2.waitKey(250)

cv2.destroyAllWindows()
```
AndreaBrg commented 7 months ago

@heyoeyo Thanks for the thorough explanations. I'll take a look.

riccardorenzulli commented 7 months ago

Thank you very much @heyoeyo for your help and insights. We discovered that the problem in our code was the norm argument, which defaults to True in x_layers = model.get_intermediate_layers(img, [len(model.blocks)-1]). By passing norm=False and collecting the embeddings for all layers, we get the same results as yours.
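
For anyone else hitting this, here's a minimal sketch of the corrected call (assuming model and img are set up as in the snippets above):

```python
import torch

# Same call as before, but with norm=False so the final LayerNorm is not
# applied to the block outputs before measuring the token norms
with torch.no_grad():
    x_layers = model.get_intermediate_layers(img, [len(model.blocks) - 1], norm=False)
patch_tokens = x_layers[0]         # (batch, num_patches, dim)
norms = patch_tokens.norm(dim=-1)  # per-token L2 norm, now matching the hook-based results
print(f"min={norms.min().item():.1f}, max={norms.max().item():.1f}")
```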

As you pointed out, the norms in the last layers of the model without registers are not that high, while for the model with registers the norms become high but without outliers. I was surprised about this, especially given Figures 7 and 15 of the paper.

heyoeyo commented 7 months ago

> I was surprised about this, especially given Figures 7 and 15 of the paper.

Agreed! The output layer of the vitl-reg model has norms in the 150-400 range for the few images I've tried, as opposed to the <50 range reported by the paper.

I also find figure 3 vs 7 & 15 to be confusing, as fig. 3 suggests a non-register high-norm range of ~200-600 (consistent with what I've seen), whereas fig 7 & 15 show a 100-200 range for the high-norm tokens. Though I may be misinterpreting the plots.