LAION-AI / CLAP

Contrastive Language-Audio Pretraining
https://arxiv.org/abs/2211.06687
Creative Commons Zero v1.0 Universal

Fix bug in hook.py for single text prompt #105

Closed: amanteur closed this 1 year ago

amanteur commented 1 year ago

Referring to #85.

Changes the return value of the tokenizer method in hook.py.

When running model.get_text_embedding with a list containing a single text prompt, it yields an error.
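A self-contained reproduction looks roughly like this (the checkpoint settings, audio path, and prompt below are placeholders, following the README-style usage):

```python
import torch
import laion_clap

# Load a pretrained CLAP model (placeholder settings; any standard checkpoint reproduces this).
model = laion_clap.CLAP_Module(enable_fusion=False)
model.load_ckpt()

audio_file = "example.wav"   # placeholder audio path
text = "a dog barking"       # single text prompt

with torch.no_grad():
    audio_embed = model.get_audio_embedding_from_filelist(x=[audio_file], use_tensor=True)
    # Fails when the list holds exactly one prompt:
    text_embed = model.get_text_embedding([text], use_tensor=True)
```

Running this produces: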

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[3], line 5
      3 with torch.no_grad():
      4     audio_embed = model.get_audio_embedding_from_filelist(x = [audio_file], use_tensor=True)
----> 5     text_embed = model.get_text_embedding([text], use_tensor=True)  
      7 torch.nn.functional.cosine_similarity(
      8     audio_embed,
      9     text_embed
     10 )

File ~/venv/lib/python3.9/site-packages/laion_clap/hook.py:214, in CLAP_Module.get_text_embedding(self, x, tokenizer, use_tensor)
    212 else:
    213     text_input = self.tokenizer(x)
--> 214 text_embed = self.model.get_text_embedding(text_input)
    215 if not use_tensor:
    216     text_embed = text_embed.detach().cpu().numpy()

File ~/venv/lib/python3.9/site-packages/laion_clap/clap_module/model.py:715, in CLAP.get_text_embedding(self, data)
    713 for k in data:
    714     data[k] = data[k].to(device)
--> 715 text_embeds = self.encode_text(data, device=device)
    716 text_embeds = F.normalize(text_embeds, dim=-1)
    718 return text_embeds

File ~/venv/lib/python3.9/site-packages/laion_clap/clap_module/model.py:630, in CLAP.encode_text(self, text, device)
    628     x = self.text_projection(x)
    629 elif self.text_branch_type == "roberta":
--> 630     x = self.text_branch(
    631         input_ids=text["input_ids"].to(device=device, non_blocking=True),
    632         attention_mask=text["attention_mask"].to(
    633             device=device, non_blocking=True
    634         ),
    635     )["pooler_output"]
    636     x = self.text_projection(x)
    637 elif self.text_branch_type == "bart":

File ~/venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/venv/lib/python3.9/site-packages/transformers/models/roberta/modeling_roberta.py:806, in RobertaModel.forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
    803 else:
    804     raise ValueError("You have to specify either input_ids or inputs_embeds")
--> 806 batch_size, seq_length = input_shape
    807 device = input_ids.device if input_ids is not None else inputs_embeds.device
    809 # past_key_values_length

ValueError: not enough values to unpack (expected 2, got 1)

It happens because in hook.py the tokenizer result is squeezed, so an input of shape [1, number_of_tokens] collapses to [number_of_tokens], and the RoBERTa forward can no longer unpack it into (batch_size, seq_length).
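A minimal sketch of the idea behind the fix, assuming the tokenizer helper in hook.py looks roughly like this (the tokenizer name and arguments are illustrative, not a verbatim copy of the repo):

```python
from transformers import RobertaTokenizer

tokenize = RobertaTokenizer.from_pretrained("roberta-base")

def tokenizer(text):
    # Tokenize a list of prompts into a padded batch of shape [batch, num_tokens].
    result = tokenize(
        text,
        padding="max_length",
        truncation=True,
        max_length=77,
        return_tensors="pt",
    )
    # Before: squeezing dropped the batch dimension for a single prompt,
    # turning [1, num_tokens] into [num_tokens] and breaking the
    # `batch_size, seq_length = input_shape` unpacking inside RoBERTa.
    # return {k: v.squeeze(0) for k, v in result.items()}

    # After: keep the batch dimension so shapes stay [batch, num_tokens]
    # even when only one prompt is passed.
    return {k: v for k, v in result.items()}

# Single prompt keeps its batch dimension:
single = tokenizer(["a dog barking"])
print(single["input_ids"].shape)  # torch.Size([1, 77])
```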

RetroCirce commented 1 year ago

Thank you for your implementation!