huggingface / transformers

šŸ¤— Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

WhisperForCTC #26242

Open DavraYoung opened 1 year ago

DavraYoung commented 1 year ago

Feature request

Request to add WhisperForCTC model.

Motivation

It would be cool if we had a custom WhisperForCTC with a Whisper encoder and a CTC head, just like Wav2Vec2, but since Whisper is based on mel spectrograms, I think it may bring better results.

Your contribution

Here is my implementation; I mostly copied it from Wav2Vec2ForCTC. NOTE: there is a TODO that needs to be resolved. I didn't test that part, since Whisper operates with a transposed hidden_states shape.


from typing import Optional, Tuple, Union

import torch
from torch import nn

from transformers import WhisperConfig
from transformers.modeling_outputs import CausalLMOutput
from transformers.models.whisper.modeling_whisper import (
    WhisperEncoder,
    WhisperPreTrainedModel,
)

# WhisperEncoder returns (last_hidden_state, hidden_states, attentions) as a tuple,
# so hidden states start at index 1 (unlike Wav2Vec2, where this constant is 2).
_HIDDEN_STATES_START_POSITION = 1

class ExtendedWhisperConfig(WhisperConfig):
    def __init__(
        self,
        ctc_loss_reduction: str = "mean",
        final_dropout: float = 0.0,
        ctc_zero_infinity: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.ctc_loss_reduction = ctc_loss_reduction
        self.final_dropout = final_dropout
        self.ctc_zero_infinity = ctc_zero_infinity

class WhisperEncoderForCTC(WhisperPreTrainedModel):
    config_class = ExtendedWhisperConfig

    def __init__(self, config):
        super().__init__(config)

        self.encoder = WhisperEncoder(config)
        self.dropout = nn.Dropout(config.final_dropout)
        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `WhisperEncoderForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )
        output_hidden_size = (
            config.output_hidden_size
            if hasattr(config, "add_adapter") and config.add_adapter
            else config.hidden_size
        )
        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.encoder.parameters():
            param.requires_grad = False
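
    # NOTE: this helper is not part of the original snippet. It is a hedged guess
    # for the TODO in forward(), assuming the Whisper encoder only downsamples the
    # mel frames in its second conv layer (kernel_size=3, stride=2, padding=1),
    # while conv1 keeps the sequence length unchanged.
    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        return (input_lengths - 1) // 2 + 1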

    def forward(
        self,
        input_features: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        """

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.encoder(
            input_features,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            if labels.max() >= self.config.vocab_size:
                raise ValueError(
                    f"Label values must be <= vocab_size: {self.config.vocab_size}"
                )

            # TODO: check if this is correct. input_features has shape
            # (batch, num_mel_bins, num_frames), so the default mask should cover
            # every mel frame, i.e. have shape (batch, num_frames).
            attention_mask = (
                attention_mask
                if attention_mask is not None
                else torch.ones(
                    input_features.shape[0],
                    input_features.shape[-1],
                    dtype=torch.long,
                    device=input_features.device,
                )
            )
            input_lengths = self._get_feat_extract_output_lengths(
                attention_mask.sum(-1)
            ).to(torch.long)

            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            log_probs = nn.functional.log_softmax(
                logits, dim=-1, dtype=torch.float32
            ).transpose(0, 1)
            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
DavraYoung commented 1 year ago

Regarding the accuracy and whether the model actually works: I was able to achieve decent accuracy (WER: 7% on my test dataset with only 10% of my dataset (100k audios), starting from the WhisperForConditionalGeneration encoder checkpoint).

ArthurZucker commented 1 year ago

Adding as a feature request FYI @sanchit-gandhi

sanchit-gandhi commented 1 year ago

That's very cool @DavraYoung! Did you build the CTC tokenizer yourself as well? And how did WhisperForCTC compare to WhisperForConditionalGeneration when fine-tuned on your dataset? We could run very fast training by freezing the entire encoder block and only fine-tuning the CTC head šŸ‘€
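
For reference, a minimal usage sketch of that idea with the WhisperEncoderForCTC class above (the checkpoint name and vocab size are placeholders, and loading the encoder weights from a WhisperForConditionalGeneration checkpoint is glossed over):

# Hypothetical sketch: train only the CTC head on top of a frozen Whisper encoder.
config = ExtendedWhisperConfig.from_pretrained("openai/whisper-small", vocab_size=32)  # placeholder values
model = WhisperEncoderForCTC(config)
# In practice the encoder weights would be loaded from a pre-trained Whisper
# checkpoint here; the key remapping is omitted in this sketch.
model.freeze_base_model()  # disables gradients for the whole encoder

# Only the CTC head should remain trainable:
print([name for name, p in model.named_parameters() if p.requires_grad])
# expected: ['lm_head.weight', 'lm_head.bias']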

DavraYoung commented 1 year ago

@sanchit-gandhi hi. Regarding the tokenizer: I used the wav2vec2 tokenizer with a custom vocab (Latin lowercase alphabet + '), like in the wav2vec2 fine-tuning tutorial.
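
For context, such a tokenizer can be built roughly as in the wav2vec2 fine-tuning tutorial; a sketch (the exact character set below is illustrative):

import json
from transformers import Wav2Vec2CTCTokenizer

# Illustrative vocab: Latin lowercase alphabet + apostrophe, plus the CTC special tokens.
chars = list("abcdefghijklmnopqrstuvwxyz") + ["'"]
vocab = {c: i for i, c in enumerate(chars)}
vocab["|"] = len(vocab)      # word delimiter used instead of the space character
vocab["[UNK]"] = len(vocab)
vocab["[PAD]"] = len(vocab)  # also serves as the CTC blank token

with open("vocab.json", "w") as f:
    json.dump(vocab, f)

tokenizer = Wav2Vec2CTCTokenizer(
    "vocab.json",
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)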

Regarding performance: I cannot directly compare the models right now, since I trained WhisperForConditionalGeneration on a slightly different dataset (some entries are not present), and some entries from my current validation dataset were present in the training data of WhisperForConditionalGeneration.

Regarding the actual performance on unseen data, I think WhisperForConditionalGeneration is much better than any CTC-based model, especially when given previous context/prompt, but it requires a good dataset with long audios and enough previous text context. CTC-head-based models, on the other hand, do not require a dataset with diverse lengths and can operate on smaller audios, like in my current dataset. That's why I was investigating a CTC head with the Whisper encoder.

If it's needed, I can spend some time training a whisper-ctc-medium-960h LibriSpeech English model with the Wav2Vec2-base-960h tokenizer vocab.

DavraYoung commented 1 year ago

After playing around with such a model, I found issues with degraded performance on the validation dataset.

My training setup: openai/whisper-large encoder + a new CTC head, 400 hours of short Uzbek audios, and a modified encoder with partial positional encodings.

I couldn't make the frozen version converge, so I switched to training the unfrozen version. The model was showing good results on 128 test samples. After training for 2 epochs I ran the model on the validation dataset, and unfortunately it showed worse results than Wav2Vec2ForCTC on the same 2048-clip dataset.

I modified the positional embeddings in the encoder:

inputs_embeds = inputs_embeds.permute(0, 2, 1)
embed_pos = self.embed_positions.weight
# reduce embed_pos to the same shape as inputs_embeds
embed_pos = embed_pos[: inputs_embeds.shape[1], :]

They could probably be the cause of the issue.

During training I observe an overall good loss of 0.03-0.07 (3-7%), but I repeatedly see the loss jumping to 20-40%. I checked the audios for those sample entries and they are fine. It seems like the model fails to understand certain features.

Here is the list of my models trained on the same dataset (except the phone versions). The phone versions were trained on top of the general model with random audio sample rate reduction + 10 hours of real phone data.

Model Name                               Average CER   Average WER
general-wav2vec2-1b-overfited-11         0.004         0.029
general-wav2vec2-2b-10-07-2023           0.008         0.054
general-medium-wav2vec2-conformer        0.014         0.091
mixed-wav2vec-2b-01-09-2023              0.014         0.086
general-wav2vec2-medium-14-09-2023       0.023         0.139
general-wav2vec2-small-13-09-2023        0.056         0.305
general-wav2vec2-100m-11-09-2023         0.169         0.723
general-whisper-large-ctc-18-09-2023     0.178         0.197
phone-whisper-large-ctc-18-09-2023       0.187         0.249
general-whisper-medium-ctc-18-09-2023    0.191         0.265
general-whisper-ultra-ctc-18-09-2023     0.19          0.26
sanchit-gandhi commented 1 year ago

Hey @DavraYoung! Thanks for explaining a bit more about how you're using this model. My only concern with adding this to the Transformers library is that it's a bit of an 'unofficial' implementation? That is to say, there are no official pre-trained weights available for the model.

Having pre-trained weights and reference code are general pre-requisites for adding a model to Transformers. My view is that Whisper for CTC is a nice extension of the Whisper Transformers code to a different modelling paradigm. However, it's not one that is natively supported in the original Whisper library, nor does it have official weights on the Hub. Hence, it could be a nice example that you share in a standalone repo of your own? Or showcased on the Hub with your fine-tuned weights?

Regarding CTC vs Enc-Dec, there are some nice comparisons in the ESB paper: https://arxiv.org/abs/2210.13352

Notably, CTC performs worse than Enc-Dec with punctuation and casing, and makes more frequent spelling errors. It's these aspects of robustness that make Whisper Enc-Dec such an appealing model for ASR, so I think promoting the Enc-Dec architecture is sensible here.

It's hard to comment on your training results without any information/statistics on the training data or task, but they seem to suggest that Wav2Vec2 is a more promising approach for encoder-only CTC decoding.

cjw414 commented 11 months ago

@DavraYoung Thanks for sharing this great work! Can I utilize your code to compare with other CTC-based models, like HuBERT / WavLM / XLS-R / MMS?

As @sanchit-gandhi mentioned, comparing the results with other CTC-based models will give WhisperEncoderForCTC a fairer comparison!

DavraYoung commented 11 months ago

@cjw414 Yes, no problem. I would also recommend changing the encoder positional embeddings to work properly with 'longest' padding:

# reduce embed_pos to the same shape as inputs_embeds
embed_pos = embed_pos[: inputs_embeds.shape[1], :]

https://github.com/huggingface/transformers/blob/e469be340673d1f6931eb22562efd2be7f5a5b8d/src/transformers/models/whisper/modeling_whisper.py#L902

Otherwise you will need to use WhisperFeatureExtractor with audio padded to 30s, which may impact training speed if your average audio length is short.
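
For illustration, the two padding strategies could look roughly like this (a sketch; the checkpoint name is a placeholder and raw_audios is assumed to be a list of 16 kHz 1-D arrays):

from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-large")

# Default behaviour: every example is padded/truncated to 30s of audio (3000 mel frames).
batch_30s = feature_extractor(raw_audios, sampling_rate=16000, return_tensors="pt")

# With the positional-embedding change above: pad only to the longest example in the
# batch and return an attention mask so padded frames can be ignored in the CTC loss.
batch_longest = feature_extractor(
    raw_audios,
    sampling_rate=16000,
    padding="longest",
    return_attention_mask=True,
    return_tensors="pt",
)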

cjw414 commented 11 months ago

Thanks, but in my personal experience Whisper did not function well when given short audios without padding. Maybe I did something wrong, but I might just start with padding to 30s.

Will notify you if I get some findings!

sanchit-gandhi commented 10 months ago

I also found that in my Distil-Whisper experiments padding to 30s worked better than non-padding! Likely because the model is pre-trained on such a vast amount of data that it ends up working so strongly on padded inputs

cjw414 commented 9 months ago

> I also found that in my Distil-Whisper experiments padding to 30s worked better than non-padding! Likely because the model is pre-trained on such a vast amount of data that it ends up working so strongly on padded inputs

Yup, this might probably be one of the causes of Whisper's slow inference speed. Btw, I really liked your distil-whisper!!