isl-org / lang-seg

Language-Driven Semantic Segmentation

How to get pixel-level embeddings #33

Closed loris2222 closed 2 years ago

loris2222 commented 2 years ago

Hi, I am trying to use your model for research purposes in Explainable AI. After struggling for longer than I'd like to admit, I finally managed to get it up and running; however, I can't find an easy way to get the pixel-level embeddings out of your framework, since the interfaces are quite convoluted.

Right now I've been able to do so with evaluator._modules['module'].net.get_image_features(image), starting from your notebook. I had to write get_image_features myself as a modified version of forward that stops at the image features, so I don't think this is the best way to go about it.
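Concretely, the access path looks roughly like this (just a sketch: evaluator and image are the objects set up in your notebook, and get_image_features is the method I added myself):

import torch

with torch.no_grad():
    # Unwrap the underlying network from the evaluator wrapper used in the notebook.
    net = evaluator._modules['module'].net
    # image: a single preprocessed tensor of shape (1, 3, H, W).
    pixel_embeddings = net.get_image_features(image)  # per-pixel embedding map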

Do you have any suggestions on how to proceed? Maybe some general guidance on how to approach this?

Thank you in advance!

Boyiliee commented 2 years ago

Hi @loris2222 ,

Thanks for your interest in LSeg!

Yes, LSeg was built one year ago, so you might need to install older versions of the tools. I guess the fastest solution is to revise the code.

Hope this helps!

Best, Boyi

yhyang-myron commented 1 year ago

Hi @loris2222, have you found a way to get pixel-level embeddings with newer versions of the tools?

loris2222 commented 1 year ago

Hi @yhyang-myron, yes, I was able to get pixel-level embeddings, but the solution is quite hacky and I am not ready to share the full code. I am not sure what you mean by 'newer version tools', though.

By the way, to get it working, I had to add a function to lseg_net.py that returns the model output right after the embeddings are computed. This, however, seems to work only for batch size = 1, since it must be run through the evaluator. Anyway, here is the code for that function:

def get_image_features(self, x):
    """Return L2-normalized per-pixel embeddings of shape (B, H', W', C)."""
    if self.channels_last:
        x = x.contiguous(memory_format=torch.channels_last)

    # Backbone: multi-scale ViT features.
    layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)

    layer_1_rn = self.scratch.layer1_rn(layer_1)
    layer_2_rn = self.scratch.layer2_rn(layer_2)
    layer_3_rn = self.scratch.layer3_rn(layer_3)
    layer_4_rn = self.scratch.layer4_rn(layer_4)

    # Decoder: progressively fuse and upsample the feature maps.
    path_4 = self.scratch.refinenet4(layer_4_rn)
    path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
    path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
    path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

    # Project to the joint image/text embedding dimension (self.out_c).
    image_features = self.scratch.head1(path_1)
    imshape = image_features.shape

    # (B, C, H', W') -> (B*H'*W', C), then back to (B, H', W', C).
    image_features = image_features.permute(0, 2, 3, 1).reshape(-1, self.out_c)
    image_features = image_features.view(imshape[0], imshape[2], imshape[3], -1)

    # L2-normalize so each pixel embedding can be compared against text embeddings.
    image_features = torch.nn.functional.normalize(image_features, p=2.0, dim=-1)

    return image_features
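
For completeness, here is a rough sketch of how these per-pixel embeddings can then be scored against CLIP text embeddings, which is essentially what the rest of LSeg's forward does. The clip.tokenize / clip_pretrained.encode_text calls mirror the pattern in lseg_net.py, but the exact attribute names may differ between versions, so treat this as an illustration rather than a drop-in snippet:

import clip
import torch

with torch.no_grad():
    # Unwrap the network from the evaluator as before (batch size 1).
    net = evaluator._modules['module'].net
    pixel_embeddings = net.get_image_features(image).float()  # (1, H', W', C), L2-normalized

    # Encode some example labels with the CLIP text encoder and normalize them too.
    labels = ['dog', 'grass', 'other']
    tokens = clip.tokenize(labels).to(image.device)
    text_features = net.clip_pretrained.encode_text(tokens).float()
    text_features = torch.nn.functional.normalize(text_features, p=2.0, dim=-1)

    # Cosine similarity between every pixel embedding and every label embedding.
    similarity = pixel_embeddings @ text_features.t()  # (1, H', W', len(labels))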

yhyang-myron commented 1 year ago

@loris2222 Thank you!