huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Whisper Translation on low resource languages #30592

Open RohitMidha23 opened 2 months ago

RohitMidha23 commented 2 months ago

Who can help?

@sanchit-gandhi

Reproduction

I am training Whisper on the translate task as per this blog. Since I'm training it on a low-resource language, my outputs aren't that great: the model lacks a basic understanding of the language.

My question is: would the following work?

  1. Train Whisper on the transcribe task for that language, since there is a lot more ASR data available.
  2. Subsequently fine-tune the model for the translate task.

Expected behavior

My understanding is that the encoder weights would be better after training on that language specifically, so the output should be much better. Please correct me if my understanding is wrong!

RohitMidha23 commented 2 months ago

@sanchit-gandhi any suggestions on this?

sanchit-gandhi commented 1 month ago

Hey @RohitMidha23 - super sorry for the late reply here! Just a quick preface: we typically reserve GitHub issues for bug reports and feature requests for the Transformers library (e.g. feature X is broken with model Y), and use the Hugging Face forum for questions about example scripts and use-case applications (e.g. how do I fine-tune model X for task Y). Given this, your question would be more appropriate for the forum! Just something to bear in mind so that you get help as quickly as possible, and so that the answer is maximally visible to other members of the community. We can discuss your question here for now, but I'd appreciate it if you could copy and paste the final thread over to the forum!

Thanks for your issue description! Could I ask a few follow-up questions:

  1. How much translate and transcribe data do you have (per task)?
  2. Is this language part of the Whisper pre-training dataset?

I imagine you'd get the best performance doing one of the following:

  1. Two rounds of fine-tuning: train on transcribe, then on translate (i.e. as you've proposed above)
  2. Train on the translate and transcribe datasets jointly in a single round of fine-tuning
  3. Train on translate, but supplement your small translate dataset with speech-translation data in other languages

These are all valid options for boosting your effective dataset size. As to which one will work best: that depends on how in-distribution your different datasets are, and how much of each one you have. The best thing is to try training with each approach and see which works best!

Related: @eustlb has been trying option 3 for Distil-Whisper experiments, and has gotten promising first results (see the link for details).

RohitMidha23 commented 1 month ago

Understood, will keep it in mind next time.

Here is a detailed description of our data setup:

| | In distribution | Out of distribution |
|---|---|---|
| Transcribe | 70 hours | 0 hours |
| Translate | 250 hours | ~500 hours |

Yes, the language is part of Whisper's pre-training dataset, but it is extremely low-resource.
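To make the options concrete for this setup, the sketch below simply adds up the hours from the table for each option. The mapping of options to datasets is an assumption for illustration only; in particular, it treats the ~500 out-of-distribution translate hours as the supplementary data for option 3.

```python
# Hours per task and distribution, copied from the table above.
DATA_HOURS = {
    "transcribe": {"in_dist": 70, "out_dist": 0},
    "translate": {"in_dist": 250, "out_dist": 500},
}


def hours_per_stage(option: int) -> list[int]:
    """Hours of audio seen in each fine-tuning stage for options 1-3.

    Assumed interpretation (illustrative, not from the thread):
      1 -> stage 1: in-distribution transcribe, stage 2: in-distribution translate
      2 -> a single joint stage over both in-distribution datasets
      3 -> a single translate stage, in- plus out-of-distribution data
    """
    tr, tl = DATA_HOURS["transcribe"], DATA_HOURS["translate"]
    if option == 1:
        return [tr["in_dist"], tl["in_dist"]]
    if option == 2:
        return [tr["in_dist"] + tl["in_dist"]]
    if option == 3:
        return [tl["in_dist"] + tl["out_dist"]]
    raise ValueError(f"unknown option: {option}")


for opt in (1, 2, 3):
    print(opt, hours_per_stage(opt))
```

Under these assumptions, option 3 gives the largest single translate stage, while options 1 and 2 are the only ones that use the transcribe data at all.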

sanchit-gandhi commented 1 month ago

Great - thanks for the clarification! Based on the above, I would give options 1 and 2 a go to see if they give you a good starting point. You can then supplement with additional data in linguistically related languages if you need more data.

RohitMidha23 commented 1 month ago

@sanchit-gandhi I tried out both methods 1 and 2. With method 2, we see that even when the translate token is passed, the output contains transcribed words. Is there a known reason for this?

With method 3, how much data would we need for a related language? Does the training schedule play a part here?

sanchit-gandhi commented 1 month ago

Hey @RohitMidha23 - awesome to hear! Could you share the code that you're using for 2 so I can take a look?

The most related experiment I know of for 3 is when @eustlb mixed French with Spanish for distillation efforts: https://github.com/huggingface/distil-whisper/tree/main/training#3-language-mixing

Here, we used 500 hours of Spanish data to supplement 400 hours of French, and got a 7.5% WER improvement. I would imagine a few hundred hours of closely related data should help you here as well.
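When mixing in a related language, one simple way to decide how often to draw from each source is to sample in proportion to its hours of audio. This is a hedged sketch, not necessarily what the Distil-Whisper language-mixing experiment did. The resulting values, listed in the same order as the datasets, could be passed to the `probabilities` argument of `datasets.interleave_datasets`.

```python
def proportional_probabilities(hours: dict[str, float]) -> dict[str, float]:
    """Return a sampling probability per source, proportional to its hours of audio."""
    total = sum(hours.values())
    if total <= 0:
        raise ValueError("total hours must be positive")
    return {name: h / total for name, h in hours.items()}


# The French/Spanish mix quoted above: 400 h of French plus 500 h of Spanish.
probs = proportional_probabilities({"french": 400, "spanish": 500})
print(probs)
```

Proportional sampling is the simplest choice; if the in-distribution data is much smaller than the related-language data, you may want to up-weight it by hand instead.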

RohitMidha23 commented 1 month ago

Most of the code is taken from your Hugging Face blog post, @sanchit-gandhi.

Dataset Loading

from datasets import concatenate_datasets, load_dataset
translate_dataset = load_dataset("...")
transcribe_dataset = load_dataset("...")

dataset = concatenate_datasets(
        [
            translate_dataset["train"],
            transcribe_dataset["train"],
            transcribe_dataset["test"],
        ]
    )

dataset = dataset.sort("duration")

from dataclasses import dataclass
from typing import Any, Dict, List, Union

import torch

# No changes made here:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [
            {"input_features": feature["input_features"]} for feature in features
        ]
        batch = self.processor.feature_extractor.pad(
            input_features, return_tensors="pt"
        )

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        # if a bos token was prepended in the previous tokenization step,
        # cut it here, as it's prepended again later anyway
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch
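Why -100? `torch.nn.CrossEntropyLoss` skips targets equal to its `ignore_index`, which defaults to -100, so masked padding positions contribute nothing to the loss. The plain-Python sketch below mirrors the collator's label handling on toy ids (1 standing in for BOS, 2 for EOS) so the effect is easy to see without torch; it is an illustration, not a replacement for the collator.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def pad_and_mask_labels(label_ids: list[list[int]], bos_token_id: int) -> list[list[int]]:
    """Plain-Python mirror of the collator's label handling.

    Sequences are right-padded to the batch maximum with IGNORE_INDEX (the collator
    pads with the pad token and then masks those positions to -100, which ends up
    the same), and the first column is dropped only if every sequence starts with BOS.
    """
    max_len = max(len(ids) for ids in label_ids)
    labels = [list(ids) + [IGNORE_INDEX] * (max_len - len(ids)) for ids in label_ids]
    if max_len and all(row[0] == bos_token_id for row in labels):
        labels = [row[1:] for row in labels]
    return labels


print(pad_and_mask_labels([[1, 5, 6, 2], [1, 7, 2]], bos_token_id=1))
# [[5, 6, 2], [7, 2, -100]]
```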

Note: Here I'm simply combining the mapped datasets (mapped using different task-specific tokenizers).

Processor

from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained(MODEL_ID, task="translate")
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

Model

from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained(
        MODEL_ID,
        use_safetensors=True,
    )

model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

Trainer

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=16, 
        learning_rate=1e-5,
        num_train_epochs=num_epochs,
        bf16=True,
        warmup_ratio=0.1,
        predict_with_generate=True,
        generation_max_length=444,
        save_strategy="steps",
        save_steps=500,
        logging_steps=15,
        report_to=["wandb"],
        push_to_hub=True,
        hub_strategy="all_checkpoints",
        save_total_limit=5,
    )

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset,
    data_collator=data_collator,
    tokenizer=processor.tokenizer,
)

There is one part I'm unsure about:

if TASK == "transcribe":
        model.generation_config.language = "<|hi|>"  # from : https://github.com/huggingface/transformers/pull/28687#issuecomment-1970980372
else:
        model.generation_config.language = "<|en|>"

I'm not sure how this can be specified during training, since we're mixing the datasets.

Pseudo Labeling

Another question along the same lines: if I use a model, say Gemini, to label huge amounts of data, how would I use it with the Trainer? Do I need to write the training loop myself, for example to change the loss for such samples, or can that be specified through the Trainer? The idea is that I could first pre-train or fine-tune Whisper on this data with the methods above, so that it learns the very basics of the language before moving on to the harder language in my in-distribution dataset. Thoughts?
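The thread does not settle the loss question. One common pattern, sketched here under assumptions, is to keep the pseudo-labelled samples in the same dataset and down-weight them in the loss. The per-sample weight would be a hypothetical extra dataset column, and applying it with the Trainer would mean overriding `compute_loss` in a `Seq2SeqTrainer` subclass (its signature has changed between transformers versions, so check the one you have installed). The sketch shows only the weighting arithmetic in plain Python, not Trainer code.

```python
import math

IGNORE_INDEX = -100  # label positions with this value are excluded, as in the collator


def weighted_sequence_loss(token_log_probs, labels, sample_weights):
    """Weighted mean over samples of each sample's mean token negative log-likelihood.

    token_log_probs[i][t] is the model's log-probability of the target token at
    position t of sample i; positions whose label is IGNORE_INDEX are skipped.
    sample_weights[i] scales sample i, e.g. 1.0 for human labels and a smaller
    (hypothetical) value for pseudo-labels.
    """
    total, norm = 0.0, 0.0
    for log_probs, row, weight in zip(token_log_probs, labels, sample_weights):
        nll = [-lp for lp, lab in zip(log_probs, row) if lab != IGNORE_INDEX]
        if not nll:
            continue  # nothing to score in this sample
        total += weight * sum(nll) / len(nll)
        norm += weight
    return total / norm if norm > 0 else 0.0


# One human-labelled sample (weight 1.0) and one pseudo-labelled sample
# down-weighted to 0.5 (an arbitrary illustrative value).
loss = weighted_sequence_loss(
    token_log_probs=[[math.log(0.5), math.log(0.5)], [math.log(0.25), 0.0]],
    labels=[[3, 4], [5, IGNORE_INDEX]],
    sample_weights=[1.0, 0.5],
)
print(loss)
```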

sanchit-gandhi commented 3 weeks ago

Thanks for the code snippet @RohitMidha23! I'd be particularly interested in seeing your prepare_dataset function. Specifically, we need to make sure that we switch the tokenizer's task ids depending on whether the dataset sample is for transcription or translation.

E.g. for speech transcription, we'll have a target transcription that looks like the following:

खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई

We then add the "special" task token ids that inform the Whisper model of the task it needs to perform, as well as the language with which to decode:

<|startoftranscript|><|hi|><|transcribe|><|notimestamps|>खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई<|endoftext|>

=> notice here that for speech transcription, we insert the <|transcribe|> token at index 2.

For the same audio sample, the target transcription for speech translation looks like the following:

Bihar's politics heated up over the sweetness of Kheer, Kushwaha gave clarification

Which corresponds to the following sequence with the special task tokens:

<|startoftranscript|><|hi|><|translate|><|notimestamps|>Bihar's politics heated up over the sweetness of Kheer, Kushwaha gave clarification<|endoftext|>

=> the same audio sample has different targets for speech translation and transcription, which we signal to the model by prepending the corresponding task tokens. You can read more about this in Section 2.3 of the Whisper paper.
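The target layout above can be written down as a tiny string-building sketch. It works on token strings for readability; the real WhisperTokenizer emits token ids and inserts the prefix itself once set_prefix_tokens has been called. As in both examples above, the language token is the source language (<|hi|>) for translation too; only the task token at index 2 changes.

```python
def whisper_target(text: str, language: str, task: str) -> str:
    """Render a Whisper training target, special prefix tokens included, as a string.

    Layout: <|startoftranscript|>, language token, task token, <|notimestamps|>,
    the target text, then <|endoftext|>. The task token sits at index 2.
    """
    if task not in ("transcribe", "translate"):
        raise ValueError(f"unknown task: {task}")
    prefix = ["<|startoftranscript|>", f"<|{language}|>", f"<|{task}|>", "<|notimestamps|>"]
    return "".join(prefix) + text + "<|endoftext|>"


english = "Bihar's politics heated up over the sweetness of Kheer, Kushwaha gave clarification"
print(whisper_target(english, language="hi", task="translate"))
```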

How does this look in practice? We set the tokenizer's task tokens using the set_prefix_tokens method. For each sample, we set the tokens based on whether it's a speech transcription or translation example, and then tokenize the target text accordingly.

Let's assume that your dataset has three columns: audio, text and task, where the task column denotes whether a particular sample is speech transcription, or speech translation. You can then prepare your dataset as follows:

def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from input audio array 
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

+   # set the task prefix tokens (language and task) for the tokenizer
+   task = batch["task"]
+   tokenizer.set_prefix_tokens(language="hindi", task=task)

    # encode target text to label ids 
    batch["labels"] = tokenizer(batch["text"]).input_ids
    return batch

Adding this extra task column is trivial:

translate_dataset = load_dataset("...", split="train")
transcribe_dataset = load_dataset("...", split="train")

+ translate_dataset = translate_dataset.add_column("task", len(translate_dataset) * ["translate"])
+ transcribe_dataset = transcribe_dataset.add_column("task", len(transcribe_dataset) * ["transcribe"])

dataset = concatenate_datasets([translate_dataset, transcribe_dataset])

RohitMidha23 commented 1 week ago

@sanchit-gandhi my code is fairly similar, in approach at least.

def prepare_dataset(batch, processor):
    audio = batch["audio"]

    # compute input length in seconds, for filtering
    batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]

    batch["input_features"] = processor.feature_extractor(
        audio["array"], sampling_rate=audio["sampling_rate"]
    ).input_features[0]

    batch["labels"] = processor.tokenizer(
        batch["transcription"], max_length=444, truncation=True
    ).input_ids

    batch["labels_length"] = len(batch["labels"])
    return batch
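Since input_length and labels_length are computed for filtering, here is a hedged sketch of such a filter. The 30-second limit matches Whisper's fixed input window. The value 448 is, as far as I know, the max_target_positions of the released Whisper configs, so treat it as an assumption and read it from your model config. Note that len(batch["audio"]) counts the keys of the decoded audio dict rather than measuring duration, which is why the sketch computes duration from the array length and sampling rate.

```python
MAX_INPUT_SECONDS = 30.0  # Whisper's feature extractor pads or truncates audio to 30 s
MAX_LABEL_TOKENS = 448    # assumed decoder limit; check model.config.max_target_positions


def audio_seconds(num_samples: int, sampling_rate: int) -> float:
    """Duration in seconds of a decoded audio array."""
    return num_samples / sampling_rate


def keep_example(input_length: float, labels_length: int) -> bool:
    """True if the audio fits Whisper's 30 s window and the labels fit the decoder."""
    return 0.0 < input_length < MAX_INPUT_SECONDS and labels_length <= MAX_LABEL_TOKENS


# With a datasets.Dataset this could be applied as:
# dataset = dataset.filter(keep_example, input_columns=["input_length", "labels_length"])
print(keep_example(audio_seconds(160_000, 16_000), 120))
```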

I first mapped the ASR dataset with the prepare_dataset function I had written for ASR (from the blog).

Then I mapped the translate dataset, passing a processor with the task set to translate.

Following that, I concatenated the two datasets:

dataset = concatenate_datasets(
        [
            translate_dataset["train"],
            transcribe_dataset["train"],
            transcribe_dataset["test"],
        ]
    )

dataset = dataset.sort("duration")

The tokenizer and processor I use:

tokenizer = WhisperTokenizer.from_pretrained(MODEL_ID, task="translate")
processor = WhisperProcessor.from_pretrained(MODEL_ID, task="translate")

which are passed into the trainer:

trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        tokenizer=processor.tokenizer,
    )
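Because the model only learns which task to perform from the label prefix, one quick sanity check on a concatenated dataset is to count which task token each example carries at index 2. If translate rows ended up with a transcribe prefix (or vice versa), that would be one plausible cause of the transcribed words seen with method 2. This is a hedged diagnostic sketch; it assumes the labels have been converted back to token strings, for example with tokenizer.convert_ids_to_tokens(example["labels"]).

```python
from collections import Counter

TASK_TOKENS = {"<|transcribe|>", "<|translate|>"}


def task_token_counts(label_token_seqs):
    """Count which task token each label sequence carries at index 2.

    Each element is a list of token strings whose first tokens are
    <|startoftranscript|>, the language token and the task token.
    Sequences without a recognised task token at index 2 are counted as "missing".
    """
    counts = Counter()
    for tokens in label_token_seqs:
        token = tokens[2] if len(tokens) > 2 else None
        counts[token if token in TASK_TOKENS else "missing"] += 1
    return counts


sample = [
    ["<|startoftranscript|>", "<|hi|>", "<|translate|>", "<|notimestamps|>", "Bihar"],
    ["<|startoftranscript|>", "<|hi|>", "<|transcribe|>", "<|notimestamps|>"],
    ["<|notimestamps|>"],
]
print(task_token_counts(sample))
```

For a healthy mixed dataset, the two task counts should match the sizes of the translate and transcribe splits that went into concatenate_datasets.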