facebookresearch / dinov2

PyTorch code and models for the DINOv2 self-supervised learning method.
Apache License 2.0

Semantic segmentation #25

Closed anshkumar closed 1 year ago

anshkumar commented 1 year ago

I'm not able to find the code for semantic segmentation. In the paper it's written that:

 a linear layer is trained to predict class logits from the patch tokens. It is used to produce a low-resolution logit map (e.g. 32x32 for a model with patch size 16), which is then upsampled to full resolution (512x512) to obtain a segmentation map.

Does this mean a Linear layer with 32*32 = 1024 output classes needs to be trained? What about n_last_blocks_list = [1, 4] and n_last_blocks = max(n_last_blocks_list)? Do those need to be changed to n_last_blocks_list = [1, 1] and n_last_blocks = max(n_last_blocks_list)?

Is there any sample code for semantic segmentation?

woctezuma commented 1 year ago

I don't think 32x32 is linked to the number of classes; it is the resolution of the low-resolution logit map.

The dataset seems to be ADE20K, which should have 3688 classes.

ccharest93 commented 1 year ago

32 x 16 = 512, so starting from a 512x512 cropped image you would end up with [batch size, # of patches, # of classes], i.e. [1, 32x32, # of classes], where # of classes are the classes you fine-tune on.

I don't think you want to touch the intermediate layers; just train a head that learns the mapping between the output of the transformer stack and the segmentation labels.

anshkumar commented 1 year ago

In the Linear class, I can do the following:

nn.Linear(in_dim, 32*32)

Where do I get the # of classes in the output dims?

woctezuma commented 1 year ago

You probably would prefer 32*32*N_cls to predict a 32x32 logit map for N_cls classes.

See for instance how it is written for SegFormer: (H/4)*(W/4)*N_cls


To upsample the map and take the argmax, you may refer to 🤗's doc about Semantic segmentation.
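For instance, here is a minimal sketch of that recipe (my assumptions: a ViT-B/14 backbone with 768-dim features, a 448x448 input so the patch grid is 32x32, and a placeholder N_CLS; this is not the authors' evaluation code):

import torch
import torch.nn as nn
import torch.nn.functional as F

N_CLS = 150  # placeholder, e.g. the ADE20K benchmark class count

backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14').eval()
head = nn.Linear(768, N_CLS)  # one logit vector per 768-dim patch token

x = torch.randn(1, 3, 448, 448)  # dummy image batch
with torch.no_grad():
    tokens = backbone.forward_features(x)['x_norm_patchtokens']  # (1, 1024, 768)
logits = head(tokens)                                             # (1, 1024, N_CLS)
logits = logits.permute(0, 2, 1).reshape(1, N_CLS, 32, 32)        # low-res logit map
logits = F.interpolate(logits, size=(448, 448), mode='bilinear', align_corners=False)
seg_map = logits.argmax(dim=1)                                    # (1, 448, 448) class indices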

Take everything I write with a grain of salt though.

Alexankharin commented 1 year ago

The simplest example of a semantic segmentation task head I've done using patch features:

import torch

class LinearClassifierToken(torch.nn.Module):
    def __init__(self, n_tokens, in_channels, nc=1, tokenW=32, tokenH=32):
        super(LinearClassifierToken, self).__init__()
        self.in_channels = in_channels
        self.W = tokenW
        self.H = tokenH
        self.nc = nc
        # 1x1 conv == per-patch linear layer over the channel dimension
        self.conv = torch.nn.Conv2d(in_channels, nc, (1, 1))

    def forward(self, x):
        # x: (B, n_tokens, in_channels) -> (B, nc, H, W)
        return self.conv(x.reshape(-1, self.H, self.W, self.in_channels).permute(0, 3, 1, 2))

# 448x448 input with patch size 14 -> 32x32 = 1024 tokens; ViT-B features are 768-dim.
# nc=1 assumes a single-channel mask to match BCEWithLogitsLoss below.
classlayer = LinearClassifierToken(1024, 768, nc=1).cuda()
optimizer = torch.optim.Adam(classlayer.parameters())

lossfn = torch.nn.BCEWithLogitsLoss()
dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14').cuda()
dataloader = ...  # add dataloader here
for data in dataloader:
    images, masks = data
    target = torch.nn.functional.interpolate(masks, (32, 32)).cuda()
    # ImageNet mean/std normalization
    imagesnorm = (images.cuda() - torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()) / torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()
    with torch.no_grad():
        features = dinov2_vitb14.forward_features(torch.nn.functional.interpolate(imagesnorm, (448, 448)))['x_norm_patchtokens']
    optimizer.zero_grad()
    preds = classlayer(features)
    loss = lossfn(preds, target)
    loss.backward()
    print(loss.item())
    optimizer.step()
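At inference time, one possible way to turn the low-resolution logit map into a full-resolution mask (a sketch, assuming the single-channel setup trained above and reusing imagesnorm from the loop; the 0.5 threshold is arbitrary):

classlayer.eval()
with torch.no_grad():
    feats = dinov2_vitb14.forward_features(
        torch.nn.functional.interpolate(imagesnorm, (448, 448)))['x_norm_patchtokens']
    logits = classlayer(feats)                            # (B, 1, 32, 32)
    logits = torch.nn.functional.interpolate(
        logits, size=(448, 448), mode='bilinear', align_corners=False)
    mask = (logits.sigmoid() > 0.5).squeeze(1)            # (B, 448, 448) boolean mask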
anshkumar commented 1 year ago

@Alexankharin But this requires using a conv operation, while the paper specifically mentions using a dense linear layer. The only way I can think of doing it is as follows:

features = feature_model(images)
outputs = linear_classifiers(features)
out_min = outputs.min(dim=-1)[0].reshape((-1,1))
out_max = outputs.max(dim=-1)[0].reshape((-1,1))
outputs = (outputs - out_min) / (out_max - out_min)
outputs = outputs.reshape((-1, 32, 32)).view((-1, 1, 32,32))
outputs = F.interpolate(outputs, size=(img_h, img_w), mode='bilinear', align_corners=False)
outputs = outputs.squeeze(1)
outputs = outputs * num_classes
outputs = outputs.to(torch.int)
Alexankharin commented 1 year ago

Perhaps I understood the paper wrong, but I thought it mentioned linear classification over features patch-wise. If so, a 1x1 convolution on unrolled patches is mathematically equivalent to linear classification over patch features.

@Alexankharin But this requires using a conv operation, while the paper specifically mentions using a dense linear layer. [...]
woctezuma commented 1 year ago

The fact that the layer is linear does not really matter; it is just a way to say that DINOv2's frozen features are really good, so you can train a simple head and get good results. 😄

If Alexankharin's simple code gives good results, then it is fine. Plus the explanation is probably correct.

TimDarcet commented 1 year ago

The fact that the layer is linear does not really matter; it is just a way to say that DINOv2's frozen features are really good, so you can train a simple head and get good results. 😄

Exactly. Any valid head should be fine. Linear is the easiest to train, but a larger one will get better results.

Probably I understood paper wrong, but thought it was mentioned linear classification over features patch-wise.

Spot on too. The linear head is applied separately to each patch token, i.e. it is also a 1x1 convolution.

patricklabatut commented 1 year ago

Closing as answered (and keeping track in #55).

pranavraja99 commented 1 year ago

Can confirm that a 1x1 convolution on unrolled patches is mathematically equivalent to a linear layer. No information from neighboring patches is considered when encoding each patch, and there are no edge effects due to the 1x1 kernel, so there is no need for padding. The number of parameters and their inputs and outputs are exactly the same.
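A quick numerical check of this equivalence (a sketch; the shapes are illustrative, with 1024 = 32x32 tokens):

import torch

B, N, D, C = 2, 1024, 768, 10  # batch, tokens, embed dim, classes (illustrative)
linear = torch.nn.Linear(D, C)
conv = torch.nn.Conv2d(D, C, kernel_size=1)
# copy the linear weights into the 1x1 conv
conv.weight.data = linear.weight.data.view(C, D, 1, 1).clone()
conv.bias.data = linear.bias.data.clone()

tokens = torch.randn(B, N, D)
out_linear = linear(tokens)                                     # (B, N, C)
out_conv = conv(tokens.permute(0, 2, 1).reshape(B, D, 32, 32))  # (B, C, 32, 32)
out_conv = out_conv.flatten(2).permute(0, 2, 1)                 # back to (B, N, C)
print(torch.allclose(out_linear, out_conv, atol=1e-5))          # True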

arkadaz commented 1 year ago

I borrowed from Alexankharin and the U-Net concept to decode it:

import torch
import torch.nn as nn

class LinearClassifierToken(nn.Module):
    def __init__(self, in_channels, num_chanel=2, tokenW=32, tokenH=32):
        super(LinearClassifierToken, self).__init__()
        self.in_channels = in_channels
        self.W = tokenW
        self.H = tokenH
        self.nc = num_chanel
        self.conv = torch.nn.Conv2d(in_channels, num_chanel, (1, 1))

    def forward(self, x):
        # (B, n_tokens, in_channels) -> (B, num_chanel, H, W)
        return self.conv(x.reshape(-1, self.H, self.W, self.in_channels).permute(0, 3, 1, 2))

class DinoV2(nn.Module):
    def __init__(self, num_class=1) -> None:
        super().__init__()
        self.dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
        for param in self.dinov2.parameters():
            param.requires_grad = False  # frozen backbone
        n = 512
        # 448 branch (unused in forward below; in_channels=768 would match ViT-B features)
        self.classlayer_448 = LinearClassifierToken(in_channels=768, num_chanel=n, tokenW=32, tokenH=32)
        # 224 branch: ViT-S/14 features are 384-dim; a 224x224 input gives a 16x16 token grid
        self.classlayer_224 = LinearClassifierToken(in_channels=384, num_chanel=n, tokenW=16, tokenH=16)
        self.selu = nn.SELU()
        # 32x32 -> 28x28 (k=7, p=1), then four 2x upsamples back to 448x448
        self.to_448 = nn.Sequential(
            nn.Conv2d(n, n, kernel_size=7, stride=1, padding=1, bias=False),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(n, n // 2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(n // 2),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(n // 2, n // 4, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(n // 4),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(n // 4, n // 8, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(n // 8),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(n // 8, n // 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=True)
        )
        # 16x16 -> 14x14 (k=5, p=1), then four 2x upsamples back to 224x224
        self.to_224 = nn.Sequential(
            nn.Conv2d(n, n, kernel_size=5, stride=1, padding=1, bias=False),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(n, n // 2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(n // 2),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(n // 2, n // 4, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(n // 4),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(n // 4, n // 8, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(n // 8),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(n // 8, n // 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=True)
        )
        self.conv2class = nn.Conv2d(n // 16, num_class, kernel_size=3, stride=1, padding=1, bias=True)

    def forward(self, x):
        with torch.no_grad():
            features = self.dinov2.forward_features(x)['x_norm_patchtokens']
        x = self.selu(self.classlayer_224(features))
        x = self.to_224(x)
        x = self.conv2class(x)
        return x
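A hypothetical usage sketch for this decoder (the 224 branch expects 224x224 inputs, since 224/14 = 16 patches per side; the values here are illustrative):

model = DinoV2(num_class=1).cuda()
images = torch.randn(2, 3, 224, 224).cuda()
logits = model(images)  # (2, 1, 224, 224) full-resolution logits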
YScheung commented 1 year ago

The simplest example of a semantic segmentation task head I've done using patch features: [...]

Hi, I would like to know if this means ground-truth segmentation labels of the images are needed? If so, is it possible to perform unsupervised semantic segmentation with DINOv2? Many thanks

arkadaz commented 1 year ago

Hi, I would like to know if this means ground-truth segmentation labels of the images are needed? If so, is it possible to perform unsupervised semantic segmentation with DINOv2? Many thanks

Yes, you need to have the labels. For unsupervised segmentation I recommend SAM (Segment Anything).

itsprakhar commented 1 year ago

In the Linear class, I can do the following:

nn.Linear(in_dim, 32*32)

Where do I get the # of classes in the output dims?

You can use a conv layer instead, with the number of classes as the number of out channels. Cheers!

Here is a sample code:

import torch
import torch.nn as nn

class SegmentationModel(nn.Module):
    def __init__(self, mask_dim=64, num_classes=3):
        super().__init__()

        self.mask_dim = mask_dim
        self.num_classes = num_classes
        # Load the DINO model (ViT-S/14, 384-dim features)
        self.dino = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')

        # Freeze DINO layers
        for param in self.dino.parameters():
            param.requires_grad = False

        # 1x1 conv head: number of classes as the number of out channels
        self.segmentation_conv = nn.Sequential(
            nn.Conv2d(384, self.num_classes, kernel_size=1),
        )

    def forward(self, x):
        batch_size = x.shape[0]
        with torch.no_grad():
            x = self.dino.forward_features(x)
            x = x['x_norm_patchtokens']
            x = x.permute(0, 2, 1)
            x = x.reshape(batch_size, 384, self.mask_dim, self.mask_dim)
        x = self.segmentation_conv(x)  # (B, num_classes, mask_dim, mask_dim)
        # sigmoid suits binary/multi-label masks; for mutually exclusive classes,
        # return raw logits and use CrossEntropyLoss instead
        x = torch.sigmoid(x)
        return x
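A hypothetical usage sketch (mask_dim=64 implies 64*14 = 896-pixel inputs for patch size 14; the values here are illustrative):

model = SegmentationModel(mask_dim=64, num_classes=3).cuda()
images = torch.randn(2, 3, 896, 896).cuda()
probs = model(images)        # (2, 3, 64, 64) per-patch class probabilities
masks = probs.argmax(dim=1)  # (2, 64, 64) low-res segmentation map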
DuongTSon commented 1 year ago

Thanks all for the guidelines and the DINOV2 team for the release of the pre-trained model.

I have managed to train a semantic segmentation model in my domain, and it has achieved exceptional performance, quite close to human capability. The surprise is that the training dataset was just a dozen masked labels.

Do we have the explanation for this high performance in few-shot learning? Just curious!

ccharest93 commented 1 year ago

That is a good question! If you let me speculate for a bit:

- Semantic segmentation could be a combination of two tasks: depth approximation and object classification. When doing semantic segmentation, depth approximation could provide good masks (good object-boundary estimation), and object classification could provide a means to differentiate between the masks produced.

Now if we think of DINOv2, its pretext task (a combination of DINO and iBOT) forces it to estimate the same object embedding when given two overlapping crops of an image. From this pretext, object classification could be learned through the loss (the centering/sharpening that prevents mode collapse, see the DINO paper), which forces different objects to have different classifications. The depth approximation might come from the need to pick a common focus within an image: given slightly different crops, to minimize the loss the model needs to come up with a policy for selecting which object is in focus (for example, always taking the object in the foreground), and learning this type of policy would lead to depth approximation.

Again, this is all speculation; models are often black boxes with emergent properties. But I think it is still interesting to discuss why we believe these properties emerge, to guide further design. Let me know what you think!

NielsRogge commented 1 year ago

Hi folks,

Inspired by this thread, I created a tutorial for people regarding training a linear classifier on top of a frozen DINOv2 for semantic segmentation: https://github.com/NielsRogge/Transformers-Tutorials/blob/master/DINOv2/Train_a_linear_classifier_on_top_of_DINOv2_for_semantic_segmentation.ipynb.

DINOv2 is now available in HF Transformers as well :) https://huggingface.co/docs/transformers/main/model_doc/dinov2

PeterKim1 commented 11 months ago

Thanks all for the guidelines and the DINOV2 team for the release of the pre-trained model.

I have managed to train a semantic segmentation model in my domain, and it has achieved exceptional performance, quite close to human capability. The surprise is that the training dataset was just a dozen masked labels.

Do we have the explanation for this high performance in few-shot learning? Just curious!

Hi @DuongTSon, I want to use DINOv2 in my domain, but the performance is very low. Maybe it's my code's fault.

Could you please share your code?

DuongTSon commented 11 months ago

@PeterKim1 Hi, I cannot share the code since it was a project at my company. However, you can take a look at this repository, https://github.com/itsprakhar/Downstream-Dinov2, which covers the basic structure of DINO-based models. Some experiences I have gained when using DINOv2 are below:

Hope it can help you!

mzschwartz88 commented 9 months ago

Hi folks,

Inspired by this thread, I created a tutorial for people regarding training a linear classifier on top of a frozen DINOv2 for semantic segmentation: https://github.com/NielsRogge/Transformers-Tutorials/blob/master/DINOv2/Train_a_linear_classifier_on_top_of_DINOv2_for_semantic_segmentation.ipynb.

DINOv2 is now available in HF Transformers as well :) https://huggingface.co/docs/transformers/main/model_doc/dinov2

Thank you so much for this tutorial! Having a ton of fun using it to experiment in the medical domain. I'm a tinkerer, not an actual programmer, so apologies if this is a bad question...

Have you considered replicating this tutorial with the new models that include registers? I can't seem to find them on HF, and I haven't yet been able to get them working properly by loading the model from Torch Hub. These seem like they could be quite promising for semantic segmentation. Thanks!

NielsRogge commented 8 months ago

Hi, they haven't been added to HF yet: https://github.com/huggingface/transformers/issues/27379.

However, this should be really easy given the tiny differences in https://github.com/facebookresearch/dinov2/pull/282/files.

tcourat commented 7 months ago

Perhaps I understood the paper wrong, but I thought it mentioned linear classification over features patch-wise. If so, a 1x1 convolution on unrolled patches is mathematically equivalent to linear classification over patch features.

@Alexankharin But this requires using a conv operation, while the paper specifically mentions using a dense linear layer. [...]

You could directly apply a linear layer to a tensor of shape (B, HW, D) instead of reshaping to (B, D, H, W) and using the 1x1 conv "trick". PyTorch allows a linear layer to take tensors with more than 2 dimensions, provided that the last one corresponds to in_features; it then applies the linear layer to each D-dimensional vector independently.

For clarification:

- B = batch dimension
- D = embedding dimension (e.g. 1024 for DINOv2 ViT-L)
- H = feature map height (image height // patch size, e.g. 32 for a model with patch size 16 and image size 512x512)
- W = feature map width (image width // patch size)

Hence H*W is the total number of tokens.
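For example, a minimal sketch with the shapes above (the values are illustrative):

import torch

B, H, W, D, N_cls = 2, 32, 32, 1024, 150  # shapes as defined above
head = torch.nn.Linear(D, N_cls)
tokens = torch.randn(B, H * W, D)  # patch tokens straight from the backbone
logits = head(tokens)              # (B, H*W, N_cls), no reshape or 1x1 conv needed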

antmedellin commented 3 months ago

Have you considered replicating this tutorial with the new models that include registers? [...]

I created a sample notebook here that uses Torch Hub rather than Hugging Face to create a custom semantic segmentation model. As a result, you can use the models with registers.

mzschwartz88 commented 3 months ago

Awesome, will check it out - thanks!!