huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

`last_hidden_state` has a different shape than `hidden_states[-1]` in the output of `SeamlessM4Tv2SpeechEncoder` if adapter layers are present #31946

Open anferico opened 1 month ago

anferico commented 1 month ago

Who can help?

@sanchit-gandhi @ylacombe

Reproduction

  1. Create an instance of SeamlessM4Tv2SpeechEncoder with 1 or more adapter layer(s) having stride > 1, for example by doing:

    from transformers import AutoModel
    
    speech_encoder = AutoModel.from_pretrained("facebook/seamless-m4t-v2-large").speech_encoder
  2. Encode a sample audio and pass it through the speech encoder:

    import torch
    from transformers import AutoProcessor
    
    audio_processor = AutoProcessor.from_pretrained("meetween/seamless-m4t-v2-large-speech-encoder")
    
    audio, sr = ...  # load an audio clip somehow as a torch.Tensor, together with its sampling rate
    inputs = audio_processor(
        audios=audio.squeeze().float().cpu(),
        sampling_rate=sr,
        return_tensors="pt",
    )
    # pass the processor outputs through the speech encoder created in step 1
    audio_features = speech_encoder(**inputs, output_attentions=True, output_hidden_states=True)
  3. Access the resulting output and notice that the shape of last_hidden_state differs from the shape of hidden_states[-1]:
    assert audio_features.last_hidden_state.shape != audio_features.hidden_states[-1].shape
  4. Similarly, notice that the shape of last_hidden_state is not compatible with the shape of attentions[-1]:
    batch_size, seq_len_1, emb_size = audio_features.last_hidden_state.shape
    batch_size, num_heads, seq_len_2, _ = audio_features.attentions[-1].shape
    assert seq_len_1 != seq_len_2

Expected behavior

assert audio_features.last_hidden_state.shape == audio_features.hidden_states[-1].shape
assert seq_len_1 == seq_len_2

Why this is a problem (in my view)

  1. Misleading names: last_hidden_state is different from hidden_states[-1]
  2. Consider the following use case: a pre-trained SeamlessM4Tv2SpeechEncoder is used as the speech encoder of an ASR architecture made of the speech encoder followed by a custom text decoder. If we train this model with batch size > 1, the speech encoder is fed padded audio sequences. As a result, when feeding the encoded audio sequences (the output of the speech encoder) to the custom text decoder, we have to build a proper attention_mask so that padded positions are treated as such. Normally, we would take speech_encoder_output.attentions, convert them into an attention_mask by checking which elements are > 0 (i.e. which positions in the sequence receive an attention weight > 0), and then apply that attention_mask to speech_encoder_output.last_hidden_state. However, this cannot be done because, as mentioned above, seq_len_1 != seq_len_2.

Because of 2), the only way to apply an attention_mask to speech_encoder_output.last_hidden_state is to work out its correct shape manually: look at how many convolutional layers speech_encoder.adapter (an instance of SeamlessM4Tv2ConformerAdapter) contains and what their padding, dilation, kernel_size and stride parameters are, then compute the output length (seq_len_1) from the input length (seq_len_2), applying the following formula once per layer:

len_out = math.floor((len_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)
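A minimal sketch of this manual workaround in plain Python (no transformers or PyTorch dependency): the formula is applied once per adapter layer to each sequence's unpadded length, and the resulting lengths are turned into a 0/1 mask over the downsampled sequence. The helper names (`conv_output_length`, `downsample_lengths`, `lengths_to_mask`) are hypothetical, and the layer parameters used at the bottom (one layer, `kernel_size=8`, `stride=8`, `padding=4`) are illustrative assumptions; the real values should be read from `speech_encoder.adapter` or the model config. It also assumes the batch is padded to its longest sequence.

```python
import math
from typing import List, Sequence, Tuple

# (kernel_size, stride, padding, dilation) of the strided conv in each adapter layer
ConvParams = Tuple[int, int, int, int]


def conv_output_length(len_in: int, kernel_size: int, stride: int,
                       padding: int = 0, dilation: int = 1) -> int:
    """Output length of a 1D convolution (the len_out formula above)."""
    return math.floor(
        (len_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
    )


def downsample_lengths(lengths: Sequence[int],
                       layers: Sequence[ConvParams]) -> List[int]:
    """Apply the formula once per adapter layer, in order, to every length."""
    out = list(lengths)
    for kernel_size, stride, padding, dilation in layers:
        out = [conv_output_length(n, kernel_size, stride, padding, dilation)
               for n in out]
    return out


def lengths_to_mask(lengths: Sequence[int], max_len: int) -> List[List[int]]:
    """One row per sequence: 1 for real positions, 0 for padded positions."""
    return [[1 if i < n else 0 for i in range(max_len)] for n in lengths]


# Illustrative values only (assumption): one adapter layer, kernel_size=8, stride=8, padding=4.
layers = [(8, 8, 4, 1)]
# Unpadded input lengths (seq_len_2 side), e.g. from inputs["attention_mask"].sum(-1)
input_lengths = [100, 37]
padded_len = max(input_lengths)

out_lengths = downsample_lengths(input_lengths, layers)  # -> [13, 5] (per-sequence seq_len_1)
out_max = downsample_lengths([padded_len], layers)[0]    # -> 13 (padded seq_len_1)
mask = lengths_to_mask(out_lengths, out_max)
```

The padded output length is computed from the padded input length rather than from max(out_lengths), so the mask has exactly as many columns as last_hidden_state has positions.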

Proposed workaround

Instead of doing this at the end of SeamlessM4Tv2SpeechEncoder.forward():

return Wav2Vec2BaseModelOutput(
    last_hidden_state=hidden_states,
    hidden_states=encoder_outputs.hidden_states,
    attentions=encoder_outputs.attentions,
)

do something like:

return SomeNewTypeOfModelOutput(
    last_hidden_state=hidden_states,
    hidden_states=encoder_outputs.hidden_states,
    attentions=encoder_outputs.attentions,
    last_adapter_state=...,
    adapter_states=...,
    adapter_attentions=...,
)
ylacombe commented 1 month ago

Hey @anferico, thanks for opening this issue and for the thorough explanations!

First, please note that the speech encoder of M4T v2 has been open-sourced, and that we added it as a separate model in transformers, for easier handling and training:

Note that the attention mask is downsampled when passed through the adapter layers:

That said, you're indeed correct that last_hidden_state differs from hidden_states[-1] when there is an adapter. I'm not entirely convinced this is an issue here, though: if I remember correctly, there are a few cases in transformers where this also happens (especially when some operations are applied after the transformer layers).

What could be interesting, though, is to return the downsampled attention mask in the output of the speech encoders of M4T v2 and W2V2-BERT, to avoid recomputing it every time. What do you think?

Of course, I'm open to discuss this if you have compelling reasons! Hope that it helps!

anferico commented 1 month ago

Thanks for looking into this @ylacombe! Regarding the speech encoder being released as a separate model, I was wondering:

Regarding your second point, that's exactly what I'm proposing. The main problem for me is that I have to manually compute the attention mask every time I run a forward pass of the speech encoder, which is not ideal. So would you be in favor of defining a new ModelOutput type (maybe one for M4Tv2 and one for W2V2-BERT) that includes the downsampled attention mask too?
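A minimal sketch of what such an output type could look like, assuming it simply extends the fields passed to Wav2Vec2BaseModelOutput in SeamlessM4Tv2SpeechEncoder.forward() with the downsampled mask. The class name `SpeechEncoderOutputWithMask`, the field name `attention_mask` and the `masked_mean` helper are hypothetical; nested Python lists stand in for tensors so the example runs without PyTorch. The point is that a mask already aligned with last_hidden_state can be applied directly, with no manual length computation.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple


@dataclass
class SpeechEncoderOutputWithMask:
    """Hypothetical output type: the fields currently returned by the speech
    encoder plus the attention mask after downsampling by the adapter layers.

    In transformers this would subclass transformers.utils.ModelOutput; a plain
    dataclass is used here so the sketch runs without any dependencies.
    """
    last_hidden_state: Any = None                    # (batch, seq_len_1, hidden)
    hidden_states: Optional[Tuple[Any, ...]] = None  # each (batch, seq_len_2, hidden)
    attentions: Optional[Tuple[Any, ...]] = None     # each (batch, heads, seq_len_2, seq_len_2)
    attention_mask: Any = None                       # (batch, seq_len_1), aligned with last_hidden_state


def masked_mean(output: SpeechEncoderOutputWithMask) -> List[List[float]]:
    """Average last_hidden_state over non-padded positions only (hypothetical helper)."""
    pooled = []
    for states, mask in zip(output.last_hidden_state, output.attention_mask):
        if len(states) != len(mask):
            raise ValueError("attention_mask must match the length of last_hidden_state")
        kept = [vec for vec, m in zip(states, mask) if m]
        pooled.append([sum(dim) / len(kept) for dim in zip(*kept)])
    return pooled


# Nested lists stand in for tensors: batch of 1, seq_len_1 = 3, hidden size 2,
# where the last position is padding.
out = SpeechEncoderOutputWithMask(
    last_hidden_state=[[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]],
    attention_mask=[[1, 1, 0]],
)
print(masked_mean(out))  # [[2.0, 3.0]]
```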

ylacombe commented 1 month ago

So would you be in favor of defining a new ModelOutput type (maybe one for M4Tv2 and one for W2V2-BERT) that includes the downsampled attention mask too?

Precisely! Don't hesitate to ping me when you do it!

The way I understand it, facebook/w2v-bert-2.0 is the pre-trained checkpoint used in SeamlessM4Tv2 to initialize the speech encoder. It indeed has no adapters, because these were added during M4T training.

Note that another big difference is that the M4T speech encoder is released under a non-commercial (NC) license, whereas the W2V-BERT license is much more permissive. That's also why we pushed the latter checkpoint rather than the former.

It's really interesting to see the difference in performance, though. Out of curiosity, how large is the gap between the two models (and for which languages)? Also, would you be interested in finding ways to bridge this gap?

anferico commented 1 month ago

Sure, will update you once I'm done. Now I see the difference between the two, thank you for the insight 👍🏼

I haven't properly measured the performance gap between w2v-bert-2.0 and SeamlessM4Tv2. What I can tell you is that on an English ASR task (LibriTTS), the architecture based on w2v-bert-2.0 simply didn't converge within a reasonable time, whereas the one based on SeamlessM4Tv2 did.

Also, would you be interested in finding ways to bridge this gap?

Definitely, but what do you mean exactly? I'm open to any sort of collaborations if that's what you mean 😀

ylacombe commented 1 month ago

What I can tell you is that on an English ASR task (LibriTTS), the architecture based on w2v-bert-2.0 simply didn't converge within a reasonable time, whereas the one based on SeamlessM4Tv2 did.

Hey @anferico, that's odd. Have you made sure to add an adapter layer? See how it's done in the blog post:

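from transformers import Wav2Vec2BertForCTC

# `processor` is assumed to be the Wav2Vec2BertProcessor created earlier in the blog post (not shown here)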
model = Wav2Vec2BertForCTC.from_pretrained(
    "facebook/w2v-bert-2.0",
    attention_dropout=0.0,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    mask_time_prob=0.0,
    layerdrop=0.0,
    ctc_loss_reduction="mean",
    add_adapter=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)
anferico commented 1 month ago

I kinda did, but mine was a custom setup (I did not use Wav2Vec2BertForCTC). In particular, I kept the speech encoder (w2v-bert/Seamless) frozen and trained only a custom adapter followed by a text decoder.

github-actions[bot] commented 2 weeks ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

anferico commented 2 weeks ago

Commenting to keep this alive until I find some time to work on it 😞