ethancohen123 closed this issue 1 year ago.
Hi Ethan, let's move our entire discussion about this over here. :-)
I guess this error comes from this line involving the text_encoder:
https://github.com/lucidrains/x-clip/blob/main/x_clip/x_clip.py#L250
You could solve it by letting your encoder's forward accept a mask argument:
```python
def forward(self, x, mask = None):
```
or by using **kwargs:
```python
def forward(self, x, **kwargs):
```
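The error `forward() takes 2 positional arguments but 3 were given` means the encoder's `forward` only accepts `x`, but something called it with a second positional argument (the mask). The sketch below reproduces that and shows that a `mask = None` default accepts the mask whether it arrives positionally or as a keyword. A bare `**kwargs` signature only absorbs *keyword* arguments, so it would not help if the mask were passed positionally. The class names `FeatureEncoder` and `NoMaskEncoder` are hypothetical.

```python
import torch
import torch.nn as nn

class FeatureEncoder(nn.Module):
    """Hypothetical encoder whose forward accepts an optional, unused mask."""
    def __init__(self, dim = 512):
        super().__init__()
        self.proj = nn.Linear(1, dim)

    def forward(self, x, mask = None):
        # the mask is accepted so a caller can pass it, but this encoder ignores it
        return self.proj(x)

class NoMaskEncoder(nn.Module):
    """Hypothetical encoder whose forward only takes x (reproduces the error)."""
    def forward(self, x):
        return x

x = torch.randn(2, 16, 1)
mask = torch.ones(2, 16, dtype = torch.bool)

enc = FeatureEncoder()
out_pos = enc(x, mask)          # mask passed positionally
out_kw = enc(x, mask = mask)    # mask passed as a keyword
out_none = enc(x)               # no mask at all

try:
    NoMaskEncoder()(x, mask)    # one positional argument too many
    raised = False
except TypeError as err:        # message contains "takes 2 positional arguments but 3 were given"
    raised = True
    print(err)
```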
@MicPie Hi Michael :wave: :laughing:
@ethancohen123 Ethan, try the following script:
```python
import torch
import torch.nn as nn
from torchvision import models
from einops import rearrange
from x_clip import CLIP

class Image_Encoder(nn.Module):
    # output size is (bs, 64, 512) for 256x256 inputs: an 8x8 grid of 512-dim ResNet features
    def __init__(self):
        super().__init__()
        self.model_pre = models.resnet18(pretrained=False)
        self.base = nn.Sequential(*list(self.model_pre.children()))
        # first conv layer; change the input channels here if your images are not RGB
        self.base[0] = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # drop the final avgpool and fc layers to keep the spatial feature map
        self.resnet = self.base[:-2]

    def forward(self, x):
        out = self.resnet(x)                           # (bs, 512, h, w)
        out = rearrange(out, 'b d h w -> b (h w) d')   # (bs, h*w, 512): one token per position
        return out

class Features_Encoder(nn.Module):
    # output size is (bs, n, 512): each scalar feature is projected to a 512-dim token
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(1, 512)

    def forward(self, x):
        out = self.model(x)
        return out

images_encoder = Image_Encoder()
features_encoder = Features_Encoder()

clip = CLIP(
    image_encoder = images_encoder,
    text_encoder = features_encoder,
    text_encode_without_mask = True,   # the text encoder is then called without a mask argument
    dim_image = 512,
    dim_text = 512,
    dim_latent = 512
)

features = torch.randn(4, 2048, 1)     # (bs, number of features, 1)
images = torch.randn(4, 3, 256, 256)

loss = clip(features, images, return_loss = True)
loss.backward()
```
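If your features start out as one flat vector per image, shape (bs, 2048), they need a trailing axis to match the (bs, 2048, 1) layout used above, and each encoder should return a 3-D (batch, tokens, dim) tensor. Below is a small sanity check you could run before building CLIP. `check_tokens` is a hypothetical helper, and the data is random.

```python
import torch
import torch.nn as nn

def check_tokens(encoder, x, dim):
    """Hypothetical sanity check: encoder output should be (batch, tokens, dim)."""
    out = encoder(x)
    assert out.ndim == 3 and out.shape[0] == x.shape[0] and out.shape[-1] == dim, out.shape
    return out.shape

feature_vectors = torch.randn(4, 2048)     # one feature vector per image (random data)
features = feature_vectors.unsqueeze(-1)   # (4, 2048, 1): the layout used in the script

print(check_tokens(nn.Linear(1, 512), features, 512))   # torch.Size([4, 2048, 512])
```

Passing the flat (4, 2048) vectors through `nn.Linear(2048, 512)` would give a 2-D (4, 512) output, which this check rejects.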
@lucidrains Hi Phil! 👋 Nice to read you again after a while. :-)
Ah, `text_encode_without_mask = True` is much nicer, thank you for the help!
Works just fine, thank you! If I understand your code correctly, the embedding from the network needs an additional dimension because of the following code (line 619 in x_clip.py). Is that it?
```python
if self.use_all_token_embeds:
    text_embeds = enc_text[:, 1:] if self.text_has_cls_token else enc_text
    image_embeds = enc_image[:, 1:] if self.visual_has_cls_token else enc_image
else:
    text_embeds = enc_text[:, 0]
    image_embeds = enc_image[:, 0]
```
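The quoted code indexes the token axis (`[:, 0]` or `[:, 1:]`), so it expects each encoder to return (batch, tokens, dim). The sketch below mirrors that selection for one modality with a hypothetical `select_embeds` helper; it is not x-clip's code. It also shows a consequence that follows from the excerpt: with a per-feature `nn.Linear(1, 512)` encoder and `use_all_token_embeds = False`, the pooled embedding is token 0, which depends only on the first feature. One alternative, shown last, projects the whole feature vector into a single token.

```python
import torch
import torch.nn as nn

def select_embeds(enc, use_all_token_embeds, has_cls_token):
    """Hypothetical helper mirroring the quoted excerpt for one modality."""
    if use_all_token_embeds:
        return enc[:, 1:] if has_cls_token else enc   # (b, n - 1, d) or (b, n, d)
    return enc[:, 0]                                   # (b, d): first token only

torch.manual_seed(0)
per_feature = nn.Linear(1, 512)          # like the features encoder in the script above
features = torch.randn(4, 2048, 1)

pooled = select_embeds(per_feature(features), use_all_token_embeds = False, has_cls_token = True)
print(pooled.shape)                      # torch.Size([4, 512])

# changing every feature except the first leaves token 0 unchanged
changed = features.clone()
changed[:, 1:] += 1.0
pooled_changed = select_embeds(per_feature(changed), False, True)

# one alternative: map the whole vector to a single (bs, 1, 512) token
whole_vector = nn.Linear(2048, 512)
single_token = whole_vector(features.squeeze(-1)).unsqueeze(1)
```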
Thanks!
@ethancohen123 Oh, apologies: your original way should work too, provided you aren't using all of the token embeddings (the scheme from FILIP)! https://github.com/lucidrains/x-clip/commit/223e1a0286c5678048e6d052b167b56a6a4bb371
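For context, the FILIP scheme compares token embeddings pairwise instead of one pooled vector per sample. This is why the all-token path needs the full (batch, tokens, dim) outputs. Below is a minimal sketch of the text-to-image direction only, assuming L2-normalized tokens and ignoring padding masks. `filip_text_to_image_sim` is a hypothetical name, and this is not x-clip's implementation.

```python
import torch
import torch.nn.functional as F

def filip_text_to_image_sim(text_tokens, image_tokens):
    """Hypothetical sketch: text_tokens (bt, n, d), image_tokens (bi, m, d) -> (bt, bi)."""
    t = F.normalize(text_tokens, dim = -1)
    i = F.normalize(image_tokens, dim = -1)
    # cosine similarity of every text token with every image token, for every pair of samples
    sim = torch.einsum('xnd,ymd->xynm', t, i)        # (bt, bi, n, m)
    # each text token keeps its best-matching image token, then average over text tokens
    return sim.amax(dim = -1).mean(dim = -1)         # (bt, bi)

sim = filip_text_to_image_sim(torch.randn(4, 10, 512), torch.randn(4, 64, 512))
print(sim.shape)   # torch.Size([4, 4])
```

With a single token per sample, this reduces to the ordinary CLIP cosine-similarity matrix.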
Works just fine, thanks :)
However, with visual SSL set to True I now get the following error:
EinopsError: Error while processing rearrange-reduction pattern "b n d -> (b n) d".
Input tensor shape: torch.Size([2, 512]). Additional info: {}.
Expected 3 dimensions, got 2
Sorry about the trouble, haha.
@ethancohen123 no problem, this is helpful! https://github.com/lucidrains/x-clip/commit/e38b5107675efd19e92aa4db9da5d027ab130a97
Works just fine, thanks!
Hi, I am wondering whether it is possible to use different encoders in CLIP, for example a ResNet instead of a ViT for the images. Is it also possible to replace the text encoder with a feature encoder? If I have a feature vector for a given image and I want to use x-clip, how should I do that? I have written a code example that doesn't seem to work. Here is what I did:
but I got the following error: forward() takes 2 positional arguments but 3 were given
Thanks