dgcnz / edge


Adapt DINOv2 as a detectron2 Backbone #4

Open dgcnz opened 1 month ago

dgcnz commented 1 month ago

Relevant DinoV2 snippet

class DinoVisionTransformer(nn.Module):
    ...
    def forward_features(self, x, masks=None):
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }
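To make the token layout explicit, here is a framework-agnostic sketch that uses NumPy arrays as stand-ins for torch tensors. The helper name split_dinov2_tokens and its grid_hw argument are hypothetical. The slicing mirrors forward_features above: index 0 is the cls token, the next num_register_tokens entries are register tokens (assumed to be 0 for the plain dinov2_vitb14 model), and the remaining entries are the patch tokens.

```python
import numpy as np

def split_dinov2_tokens(x_norm, num_register_tokens, grid_hw):
    """Hypothetical helper: split [B, 1 + R + H*W, C] tokens like forward_features."""
    h, w = grid_hw
    expected = 1 + num_register_tokens + h * w
    if x_norm.shape[1] != expected:
        raise ValueError(f"expected {expected} tokens, got {x_norm.shape[1]}")
    return {
        "x_norm_clstoken": x_norm[:, 0],
        "x_norm_regtokens": x_norm[:, 1 : num_register_tokens + 1],
        "x_norm_patchtokens": x_norm[:, num_register_tokens + 1 :],
    }

# Example: ViT-B (C=768) without register tokens on a 37x37 patch grid.
tokens = np.random.randn(2, 1 + 0 + 37 * 37, 768).astype(np.float32)
out = split_dinov2_tokens(tokens, num_register_tokens=0, grid_hw=(37, 37))
```

Only x_norm_patchtokens (or its pre-norm counterpart) maps onto a spatial feature map; the cls and register tokens have no grid position.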

Notes:

Relevant detectron2 snippet

class ViT(Backbone):
    ...
    def forward(self, x):
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + get_abs_pos(
                self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
            )

        for blk in self.blocks:
            x = blk(x)

        outputs = {self._out_features[0]: x.permute(0, 3, 1, 2)}
        return outputs
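detectron2's ViT keeps its tokens channels-last as [B, H, W, C] throughout the blocks and only permutes to [B, C, H, W] when building the output dict, whereas DINOv2 returns flattened [B, N, C] patch tokens. A minimal NumPy sketch of the reshape that bridges the two, assuming the patch tokens are in row-major grid order; to_d2_outputs and the "last_feat" key are illustrative names only.

```python
import numpy as np

def to_d2_outputs(patch_tokens, grid_hw, out_feature="last_feat"):
    """Hypothetical adapter: [B, H*W, C] patch tokens -> {out_feature: [B, C, H, W]}."""
    b, n, c = patch_tokens.shape
    h, w = grid_hw
    assert n == h * w, "token count must match the patch grid"
    nhwc = patch_tokens.reshape(b, h, w, c)  # same layout detectron2's ViT blocks use
    return {out_feature: nhwc.transpose(0, 3, 1, 2)}  # like x.permute(0, 3, 1, 2)

patch_tokens = np.arange(2 * 6 * 4, dtype=np.float32).reshape(2, 6, 4)
feats = to_d2_outputs(patch_tokens, grid_hw=(2, 3))
```

In torch the equivalent is tokens.unflatten(1, (h, w)).permute(0, 3, 1, 2), the same pattern used for z2 in the pre-norm comparison further down.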
dgcnz commented 1 month ago

Officially, DINOv2 is used as a segmentation backbone via the get_intermediate_layers function, which defaults to norm=True and return_class_token=False (i.e. no cls token):

dinov2_vitb14.get_intermediate_layers(x, n=1, norm=True, reshape=True)

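This returns a single-element tuple containing one tensor of shape [1, 768, 37, 37]: batch size 1, ViT-B embedding dim 768, and a 37x37 patch grid (518 / 14 = 37 for a 518x518 input).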
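The spatial size follows directly from the input size and the patch size. A small sketch (the helper name is hypothetical) that computes the expected output shape of one layer from get_intermediate_layers(..., reshape=True), assuming ViT-B defaults and that, as in the reference DINOv2 PatchEmbed, image sides must be exact multiples of the patch size:

```python
def dinov2_feature_shape(batch, height, width, embed_dim=768, patch_size=14):
    """Hypothetical helper: [B, C, H/p, W/p] shape of one reshaped intermediate layer."""
    # Assumption: DINOv2's PatchEmbed rejects sides that are not multiples of the patch size.
    if height % patch_size or width % patch_size:
        raise ValueError(f"{height}x{width} is not a multiple of patch size {patch_size}")
    return (batch, embed_dim, height // patch_size, width // patch_size)

print(dinov2_feature_shape(1, 518, 518))
```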

dgcnz commented 1 month ago

In both the ViT and EVA-02 backbones for detrex/detectron2, the cls token and the final layernorm are not used, so the first thing to try is to likewise skip DINOv2's final layernorm and take the pre-norm patch tokens.
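Because LayerNorm normalizes each token independently, dropping the cls token before or after the final norm yields the same patch features; the real choice is only whether to apply the norm at all. A NumPy sketch of both variants (helper names hypothetical); eps=1e-6 is assumed to match DINOv2's norm layer.

```python
import numpy as np

def layer_norm(x, weight, bias, eps=1e-6):
    # Per-token LayerNorm over the channel (last) axis, biased variance like torch.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * weight + bias

def patch_features(x_prenorm, num_register_tokens=0, norm=None):
    """Hypothetical helper: drop cls/register tokens, optionally apply (weight, bias) norm."""
    patches = x_prenorm[:, 1 + num_register_tokens :]
    if norm is None:
        return patches  # detectron2 ViT / EVA-02 style: no final norm, no cls token
    return layer_norm(patches, *norm)

rng = np.random.default_rng(0)
x_prenorm = rng.normal(size=(2, 1 + 16, 8))
weight, bias = rng.normal(size=8), rng.normal(size=8)
normed = patch_features(x_prenorm, norm=(weight, bias))
```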

dgcnz commented 1 month ago

Pre-norm equivalents

import timm
import torch

x = torch.randn(1, 3, 518, 518)  # 518 = 37 * 14
dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14').eval()
dinov2_timm = timm.create_model('vit_base_patch14_dinov2.lvd142m', pretrained=True, features_only=True, out_indices=(-1,)).eval()

# Pre-norm patch-token maps, each of shape [1, 768, 37, 37]
z1 = dinov2.get_intermediate_layers(x, norm=False, reshape=True, return_class_token=False)[0]
z2 = dinov2.forward_features(x)["x_prenorm"][:, 1:, :].unflatten(1, (37, 37)).permute(0, 3, 1, 2)
z3 = dinov2_timm.forward(x)[0]
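To check this equivalence numerically, one option is to compare every pair of outputs after converting them with .detach().numpy(). The helper below is a hypothetical sketch; no particular tolerance is implied, and small floating-point differences between the hub and timm implementations would not be surprising.

```python
import itertools
import numpy as np

def max_abs_diffs(features):
    """Hypothetical helper: max absolute difference for every pair of named feature maps."""
    diffs = {}
    for (name_a, a), (name_b, b) in itertools.combinations(features.items(), 2):
        if a.shape != b.shape:
            raise ValueError(f"{name_a} {a.shape} vs {name_b} {b.shape}")
        diffs[(name_a, name_b)] = float(np.max(np.abs(a - b)))
    return diffs

# Synthetic stand-ins for z1, z2, z3:
base = np.ones((1, 4, 2, 2))
d = max_abs_diffs({"z1": base, "z2": base.copy(), "z3": base + 1e-3})
```

With the real tensors this would be called as max_abs_diffs({"z1": z1.detach().numpy(), "z2": z2.detach().numpy(), "z3": z3.detach().numpy()}).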
dgcnz commented 1 month ago

This is also equivalent to detrex's TimmBackbone:

from detrex.modeling.backbone import TimmBackbone

backbone = TimmBackbone(
    model_name="vit_base_patch14_dinov2.lvd142m",  # name in timm
    features_only=True,
    pretrained=True,
    in_channels=3,
    out_indices=(-1, ),
)
backbone(x)["p-1"]
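A detectron2 Backbone also reports output_shape(): a dict mapping each output feature name to a ShapeSpec with its channels and stride. For a single-scale ViT that is the embedding dim (768 for ViT-B) and the patch size (14). The sketch below uses a local stand-in dataclass with the same field names as detectron2.layers.ShapeSpec so it runs without detectron2; the helper name is hypothetical and the "p-1" key is taken from the TimmBackbone example above.

```python
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass(frozen=True)
class ShapeSpec:
    """Stand-in for detectron2.layers.ShapeSpec (same field names)."""
    channels: Optional[int] = None
    height: Optional[int] = None
    width: Optional[int] = None
    stride: Optional[int] = None

def single_scale_output_shape(out_feature: str, embed_dim: int, patch_size: int) -> Dict[str, ShapeSpec]:
    # Hypothetical helper: a plain ViT has one output map whose stride equals the patch size.
    return {out_feature: ShapeSpec(channels=embed_dim, stride=patch_size)}

shapes = single_scale_output_shape("p-1", embed_dim=768, patch_size=14)
```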
dgcnz commented 1 month ago

Regarding the cls token, I started a discussion on the detectron2 repo here.

dgcnz commented 1 month ago

More silly issues:

dgcnz commented 1 month ago

Found a couple of options for interpolating the patch embedding from patch size 14 to 16:

  1. EVA's manual interpolation: https://github.com/baaivision/EVA/blob/master/EVA-01/eva/interpolate_patch_14to16.py
  2. timm's patch_embed interpolation on loading: https://github.com/huggingface/pytorch-image-models/blob/3196d6b131dd89ac0bf343efb039025fdb895efa/timm/models/vision_transformer.py#L1127-L1133

From a small comparison, there doesn't seem to be much difference between the two approaches (see: interpolations).
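For intuition, the simplest form of option 1 is to resize each 14x14 patch-embedding kernel to 16x16 with an image interpolation. The sketch below is a swapped-in simplification: bilinear interpolation with half-pixel centers (similar in spirit to torch's align_corners=False), written in NumPy. It is not EVA's script (which, as far as I know, uses bicubic interpolation in torch) and not timm's resample_patch_embed (which follows a FlexiViT-style pseudo-inverse resize). Function names are hypothetical.

```python
import numpy as np

def bilinear_matrix(in_size, out_size):
    """Hypothetical helper: [out_size, in_size] bilinear resize matrix, half-pixel centers."""
    m = np.zeros((out_size, in_size))
    scale = in_size / out_size
    for i in range(out_size):
        src = min(max((i + 0.5) * scale - 0.5, 0.0), in_size - 1)
        lo = int(np.floor(src))
        hi = min(lo + 1, in_size - 1)
        frac = src - lo
        m[i, lo] += 1.0 - frac
        m[i, hi] += frac
    return m  # every row sums to 1

def resize_patch_kernel(kernel, new_size):
    """Hypothetical helper: resize a [C_out, C_in, k, k] kernel to new_size x new_size."""
    m = bilinear_matrix(kernel.shape[-1], new_size)
    # Separable resize: apply m along the kernel's height axis and its width axis.
    return np.einsum("ik,oakl,jl->oaij", m, kernel, m)

kernel14 = np.random.default_rng(0).normal(size=(768, 3, 14, 14))
kernel16 = resize_patch_kernel(kernel14, 16)
```

One side effect worth keeping in mind: a 16x16 patch covers (16/14)^2 ≈ 1.31 times as many pixels, so for a roughly constant image the resized kernel's response grows by about that factor unless the weights are rescaled. This kind of mismatch is what pseudo-inverse resampling aims to avoid.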

dgcnz commented 1 month ago

The load_pretrained method, with its interpolation, can also be called implicitly by passing patch_size=16 to timm.create_model:

dinov2 = timm.create_model('vit_base_patch14_dinov2.lvd142m', pretrained=True, patch_size=16, features_only=True, out_indices=(-1, )).eval()
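Switching to patch_size=16 changes the token grid for a fixed input size. The numbers below come from standard convolution arithmetic (no padding, kernel = stride = patch size). Whether timm accepts 518x518 inputs with patch_size=16, pads them, or expects a different img_size depends on its PatchEmbed settings, which this sketch does not model. Helper names are hypothetical.

```python
def patch_grid(size, patch_size):
    """Hypothetical helper: tokens per side from a Conv2d patch embedding (kernel = stride)."""
    return (size - patch_size) // patch_size + 1

def size_for_grid(grid, patch_size):
    """Hypothetical helper: input side that yields `grid` tokens with no leftover pixels."""
    return grid * patch_size

print(patch_grid(518, 14), patch_grid(518, 16), size_for_grid(37, 16))
```

So at 518x518 a patch-16 model would see a 32x32 grid (the last 6 pixels per side do not fill a full patch), while a 592x592 input would keep the original 37x37 grid.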
dgcnz commented 1 month ago

Tracking progress at: https://github.com/dgcnz/edge/blob/main/notebooks/dinov2_interpolation.ipynb