lucidrains / x-clip

A concise but complete implementation of CLIP with various experimental improvements from recent papers
MIT License

Using different encoders in CLIP #12

Closed: ethancohen123 closed this issue 1 year ago

ethancohen123 commented 1 year ago

Hi, I am wondering whether it is possible to use different encoders in CLIP, for example a ResNet instead of a ViT for the images. And is it possible to replace the text encoder with a features encoder? If I have a feature vector for a given image and I want to use x-clip, how should I do that? I wrote a code example that doesn't seem to work; here is what I did:

import torch
from x_clip import CLIP
import torch.nn as nn
from torchvision import models

class Image_Encoder(torch.nn.Module):
    #output size is (bs,512)
    def __init__(self):
        super(Image_Encoder, self).__init__()
        self.model_pre = models.resnet18(pretrained=False)
        self.base=nn.Sequential(*list(self.model_pre.children()))
        self.base[0]=nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.resnet=self.base[:-1]

    def forward(self, x):
        # (bs, 512, 1, 1) -> (bs, 512); unlike squeeze(), flatten(1) keeps the batch axis when bs == 1
        out=self.resnet(x).flatten(1)
        return out

class features_encoder(torch.nn.Module):
    #output size is (bs,512)
    def __init__(self):
        super(features_encoder, self).__init__()
        self.model =nn.Linear(2048,512)

    def forward(self, x):
        out=self.model(x)
        return out

images_encoder=Image_Encoder()
features_encoder=features_encoder()

clip = CLIP(
    image_encoder = images_encoder,
    text_encoder = features_encoder,
    dim_image = 512,
    dim_text = 512,
    dim_latent = 512
)

features= torch.randn(4,2048)
images = torch.randn(4, 3, 256, 256)

loss = clip(features, images, return_loss = True)
loss.backward()

But I got the following error: forward() takes 2 positional arguments but 3 were given

Thanks

MicPie commented 1 year ago

Hi Ethan, let's move our entire discussion about this over here. :-)

I guess this error comes from the line where the text_encoder is called (https://github.com/lucidrains/x-clip/blob/main/x_clip/x_clip.py#L250). You could solve it by adding a mask argument, def forward(self, x, mask):, or by accepting **kwargs, def forward(self, x, **kwargs):
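A minimal sketch of the signature fix, assuming x-clip hands the text encoder a mask as an extra argument (FeaturesEncoder is a hypothetical name, not part of x-clip). Giving mask a default of None accepts the call with or without a mask, whether the mask arrives positionally or by keyword. A forward that only adds **kwargs absorbs the keyword form alone, so a mask passed positionally would still raise the same "takes 2 positional arguments but 3 were given" TypeError.

```python
import torch
import torch.nn as nn


class FeaturesEncoder(nn.Module):
    """Hypothetical features encoder that tolerates an optional mask argument."""

    def __init__(self, dim_in=2048, dim_out=512):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x, mask=None):
        # mask is accepted only for call-signature compatibility and is ignored
        return self.proj(x)


enc = FeaturesEncoder()
x = torch.randn(4, 2048)
mask = torch.ones(4, 2048, dtype=torch.bool)

out_plain = enc(x)            # no mask
out_pos = enc(x, mask)        # mask passed positionally
out_kw = enc(x, mask=mask)    # mask passed by keyword
print(out_plain.shape)  # torch.Size([4, 512])
```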

lucidrains commented 1 year ago

@MicPie Hi Michael 👋 😆

@ethancohen123 Ethan, try the following script:

import torch
from x_clip import CLIP
import torch.nn as nn
from torchvision import models
from einops import rearrange

class Image_Encoder(torch.nn.Module):
    # output shape: (bs, h*w, 512), i.e. (bs, 64, 512) for 256x256 inputs
    def __init__(self):
        super(Image_Encoder, self).__init__()
        self.model_pre = models.resnet18(pretrained=False)
        self.base=nn.Sequential(*list(self.model_pre.children()))
        self.base[0]=nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.resnet=self.base[:-2]  # drop avgpool and fc to keep the spatial feature map

    def forward(self, x):
        out=self.resnet(x)
        out=rearrange(out, 'b d h w -> b (h w) d')
        return out

class features_encoder(torch.nn.Module):
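    # output shape: (bs, 2048, 512), one 512-d token per scalar input feature
    def __init__(self):
        super(features_encoder, self).__init__()
        self.model = nn.Linear(1, 512)  # applied to the last axis: (bs, 2048, 1) -> (bs, 2048, 512)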

    def forward(self, x):
        out=self.model(x)
        return out

images_encoder=Image_Encoder()
features_encoder=features_encoder()

clip = CLIP(
    image_encoder = images_encoder,
    text_encoder = features_encoder,
    text_encode_without_mask = True,
    dim_image = 512,
    dim_text = 512,
    dim_latent = 512
)

features = torch.randn(4, 2048, 1)  # (bs, 2048, 1): each scalar feature becomes one token
images = torch.randn(4, 3, 256, 256)

loss = clip(features, images, return_loss = True)
loss.backward()
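In the script above, rearrange(out, 'b d h w -> b (h w) d') turns the ResNet-18 feature map into a sequence of spatial tokens. Below is a sketch of the same reshape in plain PyTorch (feature_map_to_tokens is a hypothetical helper), together with the shape arithmetic for a 256x256 input: ResNet-18 downsamples by 2 in conv1, maxpool, layer2, layer3 and layer4, for an overall stride of 32.

```python
import torch


def feature_map_to_tokens(fmap: torch.Tensor) -> torch.Tensor:
    # Plain-PyTorch equivalent of rearrange(fmap, 'b d h w -> b (h w) d'):
    # flatten the spatial grid row-major (h outer, w inner) into one axis,
    # then move the channel axis last.
    return fmap.flatten(2).transpose(1, 2)


# overall stride 2**5 = 32, so a 256x256 image gives an 8x8 grid of 512-channel features
side = 256 // 32
fmap = torch.randn(4, 512, side, side)
tokens = feature_map_to_tokens(fmap)
print(tokens.shape)  # torch.Size([4, 64, 512])
```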

MicPie commented 1 year ago

@lucidrains Hi Phil! 👋 Nice to read you again after a while. :-)

Ah, text_encode_without_mask = True is much nicer. Thank you for the help!

ethancohen123 commented 1 year ago

Works just fine, thank you! If I understand your code correctly, the encoders need an additional (token) dimension in their output embeddings because of this code (line 619 in x_clip.py). Is that it?

if self.use_all_token_embeds:
    text_embeds = enc_text[:, 1:] if self.text_has_cls_token else enc_text
    image_embeds = enc_image[:, 1:] if self.visual_has_cls_token else enc_image
else:
    text_embeds = enc_text[:, 0]
    image_embeds = enc_image[:, 0]

Thanks!
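To see why the token axis matters with the branch quoted above, here is a sketch that mirrors it for a single modality (select_embeds is a hypothetical helper, not x-clip's API). With use_all_token_embeds off, enc[:, 0] picks the first token of a (batch, tokens, dim) output; applied to a pooled (batch, dim) output it instead picks the first feature, leaving one scalar per sample.

```python
import torch


def select_embeds(enc, use_all_token_embeds=False, has_cls_token=False):
    # mirrors the quoted branch for one modality
    if use_all_token_embeds:
        # FILIP-style: keep every token, dropping a leading CLS token if present
        return enc[:, 1:] if has_cls_token else enc
    # otherwise take the first token as the global embedding
    return enc[:, 0]


tokens = torch.randn(4, 64, 512)  # (batch, tokens, dim)
pooled = torch.randn(4, 512)      # (batch, dim), no token axis

print(select_embeds(tokens).shape)  # torch.Size([4, 512])
print(select_embeds(pooled).shape)  # torch.Size([4]), the feature dimension is gone
```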

lucidrains commented 1 year ago

@ethancohen123 Oh, apologies! Your original approach should work too, as long as you aren't using all of the token embeddings (the FILIP scheme): https://github.com/lucidrains/x-clip/commit/223e1a0286c5678048e6d052b167b56a6a4bb371
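Another option for a pooled encoder like the original Image_Encoder is to give its output a singleton token axis, so it has shape (batch, 1, dim). This is a minimal sketch that assumes x-clip accepts any (batch, tokens, dim) encoder output; WithTokenAxis is a hypothetical wrapper, not part of x-clip. Its forward takes only x, so on the text side it would need text_encode_without_mask = True (or a mask-tolerant forward).

```python
import torch
import torch.nn as nn


class WithTokenAxis(nn.Module):
    """Hypothetical wrapper: lifts a (batch, dim) encoder output to (batch, 1, dim)."""

    def __init__(self, encoder: nn.Module):
        super().__init__()
        self.encoder = encoder

    def forward(self, x):
        out = self.encoder(x)
        # only 2-D outputs get a token axis; token sequences pass through unchanged
        return out.unsqueeze(1) if out.ndim == 2 else out


pooled_encoder = WithTokenAxis(nn.Linear(2048, 512))
out = pooled_encoder(torch.randn(4, 2048))
print(out.shape)  # torch.Size([4, 1, 512])
```

With a single token, enc[:, 0] recovers the pooled vector, and the all-token path would see a sequence of length 1. Whether that is meaningful for FILIP-style token-wise similarity is an assumption the thread does not confirm.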

ethancohen123 commented 1 year ago

Works just fine, thanks :) However, setting visual SSL to True (use_visual_ssl = True) now raises the following error:
EinopsError: Error while processing rearrange-reduction pattern "b n d -> (b n) d". Input tensor shape: torch.Size([2, 512]). Additional info: {}. Expected 3 dimensions, got 2

Sorry about the trouble, haha.
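The traceback shows a rearrange with pattern "b n d -> (b n) d" being applied to a [2, 512] tensor, presumably the pooled output of the original Image_Encoder. A plain-PyTorch sketch of that flatten (merge_batch_and_tokens is a hypothetical name) makes the three-axis requirement explicit:

```python
import torch


def merge_batch_and_tokens(x: torch.Tensor) -> torch.Tensor:
    # Plain-PyTorch stand-in for einops' rearrange(x, 'b n d -> (b n) d').
    # Like the einops pattern, it requires exactly three axes (batch, tokens, dim).
    if x.ndim != 3:
        raise ValueError(f"expected 3 dimensions, got {x.ndim}")
    b, n, d = x.shape
    return x.reshape(b * n, d)  # row j * n + i holds token i of sample j


print(merge_batch_and_tokens(torch.randn(2, 64, 512)).shape)  # torch.Size([128, 512])

try:
    merge_batch_and_tokens(torch.randn(2, 512))  # pooled output, no token axis
except ValueError as err:
    print(err)  # expected 3 dimensions, got 2
```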

lucidrains commented 1 year ago

@ethancohen123 no problem, this is helpful! https://github.com/lucidrains/x-clip/commit/e38b5107675efd19e92aa4db9da5d027ab130a97

ethancohen123 commented 1 year ago

Works just fine, thanks!