salesforce / BLIP

PyTorch code for BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation
BSD 3-Clause "New" or "Revised" License

Image-text Retrieval demo not working #125

Closed shivangibithel closed 1 year ago

shivangibithel commented 1 year ago

Hi @LiJunnan1992

The BLIP work is amazing. I was trying out the Image-Text retrieval demo and it gave me the following error. Can you please tell me what is wrong here? The code is from demo.ipynb:

from models.blip_itm import blip_itm

image_size = 384
image = load_demo_image(image_size=image_size, device=device)

model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'

model = blip_itm(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
model = model.to(device='cpu')

caption = 'a woman sitting on the beach with a dog'

print('text: %s' %caption)

itm_output = model(image,caption,match_head='itm')
itm_score = torch.nn.functional.softmax(itm_output,dim=1)[:,1]
print('The image and text is matched with a probability of %.4f'%itm_score)

itc_score = model(image,caption,match_head='itc')
print('The image feature and text feature has a cosine similarity of %.4f'%itc_score)

Error:

RuntimeError                              Traceback (most recent call last)
Cell In [7], line 1
----> 1 itm_output = model(image,caption,match_head='itm')
      2 itm_score = torch.nn.functional.softmax(itm_output,dim=1)[:,1]
      3 print('The image and text is matched with a probability of %.4f'%itm_score)

File ~/miniconda3/envs/blip/lib/python3.9/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File /DATA/shivangib/BLIP/models/blip_itm.py:43, in BLIP_ITM.forward(self, image, caption, match_head)
     41 def forward(self, image, caption, match_head='itm'):
---> 43     image_embeds = self.visual_encoder(image)
     44     image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
     46     text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
     47                           return_tensors="pt").to(image.device)

File ~/miniconda3/envs/blip/lib/python3.9/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File /DATA/shivangib/sigir/BLIP/models/vit.py:182, in VisionTransformer.forward(self, x, register_blk)
    180 def forward(self, x, register_blk=-1):
    181     B = x.shape[0]
--> 182     x = self.patch_embed(x)
    184     cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
    185     x = torch.cat((cls_tokens, x), dim=1)

File ~/miniconda3/envs/blip/lib/python3.9/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/blip/lib/python3.9/site-packages/timm/models/layers/patch_embed.py:35, in PatchEmbed.forward(self, x)
     32 B, C, H, W = x.shape
     33 assert H == self.img_size[0] and W == self.img_size[1], \
     34     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
---> 35 x = self.proj(x)
     36 if self.flatten:
     37     x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC

File ~/miniconda3/envs/blip/lib/python3.9/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/blip/lib/python3.9/site-packages/torch/nn/modules/conv.py:463, in Conv2d.forward(self, input)
    462 def forward(self, input: Tensor) -> Tensor:
--> 463     return self._conv_forward(input, self.weight, self.bias)

File ~/miniconda3/envs/blip/lib/python3.9/site-packages/torch/nn/modules/conv.py:459, in Conv2d._conv_forward(self, input, weight, bias)
    455 if self.padding_mode != 'zeros':
    456     return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
    457                     weight, bias, self.stride,
    458                     _pair(0), self.dilation, self.groups)
--> 459 return F.conv2d(input, weight, bias, self.stride,
    460                 self.padding, self.dilation, self.groups)

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
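
For context, the traceback points to a device mismatch: load_demo_image(image_size=image_size, device=device) puts the image tensor on device (CUDA in the demo notebook), while model.to(device='cpu') leaves the weights on the CPU. Below is a minimal sketch of keeping both on the same device; it assumes load_demo_image from the earlier cells of demo.ipynb is in scope and is not an official fix from the maintainers.

import torch
from models.blip_itm import blip_itm

# Pick one device and keep the image and the model on it together.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

image_size = 384
# load_demo_image is defined in an earlier cell of demo.ipynb
image = load_demo_image(image_size=image_size, device=device)

model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
model = blip_itm(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
model = model.to(device)  # same device as the image, instead of forcing CPU

caption = 'a woman sitting on the beach with a dog'
itm_output = model(image, caption, match_head='itm')
itm_score = torch.nn.functional.softmax(itm_output, dim=1)[:, 1]
print('The image and text is matched with a probability of %.4f' % itm_score)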

LiJunnan1992 commented 1 year ago

Hi, thanks for your interest. Could you check out our LAVIS library for image-text retrieval? https://github.com/salesforce/LAVIS

We have not finetuned a captioning model on Flickr, but you can find the code in LAVIS to do so yourself.
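
For readers following the LAVIS pointer, here is a rough sketch of image-text matching through LAVIS, based on the examples published in that repository; the model name blip_image_text_matching, the load_model_and_preprocess helper, and the local file demo.jpg are assumptions to verify against the current LAVIS docs.

import torch
from PIL import Image
from lavis.models import load_model_and_preprocess

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the BLIP image-text matching model together with its image/text preprocessors.
model, vis_processors, text_processors = load_model_and_preprocess(
    name='blip_image_text_matching', model_type='base', is_eval=True, device=device)

raw_image = Image.open('demo.jpg').convert('RGB')  # any local test image
caption = 'a woman sitting on the beach with a dog'

img = vis_processors['eval'](raw_image).unsqueeze(0).to(device)
txt = text_processors['eval'](caption)

# ITM head: probability that the image and caption match.
itm_output = model({'image': img, 'text_input': txt}, match_head='itm')
itm_score = torch.nn.functional.softmax(itm_output, dim=1)[:, 1]
print('The image and text are matched with a probability of %.4f' % itm_score)

# ITC head: cosine similarity between the image and text features.
itc_score = model({'image': img, 'text_input': txt}, match_head='itc')
print('The image and text features have a cosine similarity of %.4f' % itc_score)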