huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Whisper generate returns a slice of the result if the result has more than one added token #33082

Open txya900619 opened 3 weeks ago

txya900619 commented 3 weeks ago

Who can help?

@sanchit-gandhi / @kamilakesbi

Reproduction

I added some new tokens to the tokenizer, resized the token embeddings, and then fine-tuned the model on a custom dataset.

from transformers import (
    WhisperForConditionalGeneration,
    WhisperProcessor,
)
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-large-v3",
    language="en",
    task="transcribe", 
)
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v3",
)
processor.tokenizer.add_tokens([])  # fill the list with the tokens you want to add
model.resize_token_embeddings(len(processor.tokenizer))  # make room for the new token ids

# then fine-tune the model

When I call generate on the fine-tuned model, the output is truncated whenever it contains more than one newly added token.

from tqdm import tqdm
from transformers import (
    WhisperForConditionalGeneration,
    WhisperProcessor,
)

processor = WhisperProcessor.from_pretrained(
    "",  # path where the fine-tuned model was saved
    language="en",
    task="transcribe",
)
model = WhisperForConditionalGeneration.from_pretrained(
    "",  # path where the fine-tuned model was saved
).cuda()  # the inputs below are moved to the GPU, so the model must be too

eval_dataloader = ...  # use your own dataloader

for batch in tqdm(eval_dataloader):
    generated_tokens = (
        model.generate(
            input_features=batch["input_features"].cuda(),
            max_new_tokens=255,
        )
        .cpu()
        .numpy()
    )
    # decode every batch, not only the last one
    print(processor.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))

I believe the cause is this line, which makes generate treat the new tokens as timestamp_tokens and therefore truncate the result: https://github.com/huggingface/transformers/blob/d806fa3e92289876e01ab19c9e19e9264ea1c1a1/src/transformers/models/whisper/generation_whisper.py#L1759

Expected behavior

The generate function should output the full transcription.

gante commented 3 weeks ago

(passing along to @sanchit-gandhi / @kamilakesbi , as this is a Whisper-specific question :) )

txya900619 commented 3 weeks ago

> (passing along to @sanchit-gandhi / @kamilakesbi , as this is a Whisper-specific question :) )

Okay, edited.

ArthurZucker commented 2 weeks ago

This seems related to #32378 as well! cc @ylacombe

ylacombe commented 1 week ago

Hey @txya900619, thanks for opening the issue! It'd be tremendously helpful if you could provide a fine-tuned version with some added tokens so that I can reproduce the error! Do you think it's possible?

ylacombe commented 1 week ago

Hey @txya900619,

I think your diagnosis might be right. However, another possible cause is that the tokenizer simply skips your tokens because you decode with skip_special_tokens=True.

I've described the tokenizer behaviour in a bit more detail here.

Could you check whether the added tokens appear as expected when you decode with skip_special_tokens=False?

txya900619 commented 1 week ago

> Hey @txya900619, thanks for opening the issue! It'd be tremendously helpful if you could provide a fine-tuned version with some added tokens so that I can reproduce the error! Do you think it's possible?

After some testing, I found that this problem only occurs when two new tokens are next to each other. Here is my test code:

import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    WhisperForConditionalGeneration,
    WhisperProcessor,
)

if __name__ == "__main__":
    data = load_dataset("openslr/librispeech_asr", split="test.clean", streaming=True)
    sample = next(iter(data))
    sample["text"] = sample["text"].lower()  # "concord returned to its place amidst the tents"

    processor = WhisperProcessor.from_pretrained(
        "openai/whisper-tiny",
        language="en",
        task="transcribe",
    )
    model = WhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-tiny",
    )

    # Add two new tokens that appear next to each other in the transcript
    processor.tokenizer.add_tokens(["concord", " returned"])
    model.resize_token_embeddings(len(processor.tokenizer))
    model.cuda()

    # torch.optim.AdamW replaces the deprecated transformers.AdamW
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)

    print(sample["text"])
    print(processor.tokenizer.encode(sample["text"]))

    # The inputs are identical at every step, so compute them once
    inputs = processor(
        audio=sample["audio"]["array"],
        sampling_rate=sample["audio"]["sampling_rate"],
        text=sample["text"],
        return_tensors="pt",
    )

    # Train on the single sample so the model learns to emit the new tokens
    for _ in tqdm(range(30)):
        output = model(
            input_features=inputs["input_features"].cuda(),
            labels=inputs["labels"].cuda(),
        )
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()  # clear gradients so they don't accumulate across steps

    result = model.generate(
        input_features=inputs["input_features"].cuda(),
        language="en",
    )
    print(result)
    print(processor.batch_decode(result))

I printed the generated ids and found that they match the decoded text, so the tokenizer is not the problem.

If you print seek_sequences just before https://github.com/huggingface/transformers/blob/21fac7abba2a37fae86106f87fcf9974fd1e3830/src/transformers/models/whisper/generation_whisper.py#L694, you get:

[tensor([50258, 50259, 50359, 50363, 51865, 51866, 281, 1080, 1081, 30153, 372, 264, 39283], device='cuda:0')]

If you then print current_segments after this for loop, you get:

[[{'start': tensor(-2.1200), 'end': tensor(30.0200), 'tokens': tensor([50258, 50259, 50359, 50363, 51865], device='cuda:0'), 'result': tensor([50258, 50259, 50359, 50363, 51865, 51866, 281, 1080, 1081, 30153, 372, 264, 39283, 50257], device='cuda:0')}]]

You can see that the tokens are cut off between 51865 and 51866.

eustlb commented 1 day ago

Hey @txya900619, thanks for raising this issue!

What's happening here is that _retrieve_segment identifies timestamp tokens simply as tokens with ids >= timestamp_begin (see this line). However, tokens added to the vocabulary are appended at its end, so their ids are also >= timestamp_begin and they are treated as timestamp tokens. Since two consecutive timestamp tokens mark a segment boundary, the example above is wrongly split here and only the first segment is kept.
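To make the failure mode concrete, here is a simplified pure-Python sketch of that check. It is not the transformers code: `TIMESTAMP_BEGIN = 50364` assumes whisper-tiny (no_timestamps token 50363 plus one), the original vocabulary size of 51865 is an assumption for whisper-tiny, the token ids come from the seek_sequences output above, and `split_points` is a hypothetical helper that mimics splitting at consecutive timestamp tokens.

```python
# Simplified sketch of the timestamp check; not the transformers implementation.
TIMESTAMP_BEGIN = 50364      # assumption: whisper-tiny, no_timestamps_token_id (50363) + 1
ORIGINAL_VOCAB_SIZE = 51865  # assumption: whisper-tiny vocabulary size before add_tokens

# add_tokens appends new tokens at the end of the vocabulary
added_ids = [ORIGINAL_VOCAB_SIZE + i for i in range(2)]

# seek_sequences printed in the reproduction above
seek_sequence = [50258, 50259, 50359, 50363, 51865, 51866,
                 281, 1080, 1081, 30153, 372, 264, 39283]


def split_points(tokens, is_timestamp):
    """Hypothetical helper: index just after each pair of consecutive timestamp tokens."""
    flags = [is_timestamp(t) for t in tokens]
    return [i + 1 for i in range(len(flags) - 1) if flags[i] and flags[i + 1]]


# Current behaviour: any id >= TIMESTAMP_BEGIN counts as a timestamp
points = split_points(seek_sequence, lambda t: t >= TIMESTAMP_BEGIN)
first_segment = seek_sequence[: points[0]]

print(added_ids)      # [51865, 51866]
print(points)         # [5]
print(first_segment)  # [50258, 50259, 50359, 50363, 51865]
```

The first segment matches the 'tokens' field of current_segments printed above, i.e. the sequence is cut between the two added tokens.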

I see two ways of solving this issue:

  1. Hard-code the number of timestamp tokens. I know hard-coding is usually frowned upon, but the number of timestamp tokens is fixed at 1501 anyway (0.0 to 30.0 in 20 ms steps, i.e. ids timestamp_begin through timestamp_begin + 1500), so I don't see a reason not to. Moreover, the only change needed would be to modify this line:
    - timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
    + timestamp_tokens: torch.Tensor = (seek_sequence >= timestamp_begin) & (seek_sequence < timestamp_begin + 1501)
  2. Update generation_config.json with timestamp_begin_id and timestamp_end_id. This solution requires:
    1. modifying the generation_config.json of every Whisper model;
    2. modifying _retrieve_segment: it already takes timestamp_begin_id and would also need to take timestamp_end_id;
    3. modifying _set_return_timestamp so that it returns timestamp_begin and timestamp_end from the generation_config.

I would go for solution 1. I tested it and it solves the issue! WDYT @ylacombe?
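A range check like the one in solution 1 can be sketched in plain Python. This is a standalone illustration, assuming whisper-tiny's timestamp_begin of 50364; `is_timestamp` is a hypothetical helper. It uses a half-open range, so the valid ids are timestamp_begin through timestamp_begin + 1500 (1501 tokens) and the first added id, 51865, falls outside it.

```python
TIMESTAMP_BEGIN = 50364       # assumption: whisper-tiny, no_timestamps_token_id (50363) + 1
NUM_TIMESTAMP_TOKENS = 1501   # <|0.00|> .. <|30.00|> in 0.02 s steps


def is_timestamp(token_id, begin=TIMESTAMP_BEGIN, count=NUM_TIMESTAMP_TOKENS):
    """Hypothetical helper: half-open range check, begin <= id < begin + count."""
    return begin <= token_id < begin + count


seek_sequence = [50258, 50259, 50359, 50363, 51865, 51866,
                 281, 1080, 1081, 30153, 372, 264, 39283]
flags = [is_timestamp(t) for t in seek_sequence]

print(is_timestamp(TIMESTAMP_BEGIN + 1500))       # True: <|30.00|>
print(is_timestamp(51865), is_timestamp(51866))   # False False: added tokens
print(any(flags))                                 # False: no segment split
```

The exclusive upper bound matters: an inclusive bound of timestamp_begin + 1501 would still flag the first added token as a timestamp.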

ylacombe commented 10 hours ago

Hey @eustlb, thanks for taking the time to look into it and to propose some fixes!

Hard-coding is generally something we want to avoid, as it hides the code's inner workings from users and makes debugging harder. This matters all the more because Whisper is so widely used.

What I'd recommend is a mix of both solutions, since we can't realistically modify the generation_config.json of every Whisper model. What we can do, though, is set a default value that allows users to:

  1. use this default value in most use cases;
  2. override it in their generation config for more complex use cases.

timestamp_begin is computed here. Could we do something like timestamp_end = timestamp_begin + generation_config.get("number_timestamp_tokens", DEFAULT_VALUE), and then use timestamp_end wherever it's relevant?
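A minimal sketch of the default-value idea, under stated assumptions: `number_timestamp_tokens` is the attribute name proposed here, not an existing GenerationConfig field; `DEFAULT_NUM_TIMESTAMP_TOKENS` and `timestamp_range` are hypothetical names; `types.SimpleNamespace` stands in for a real generation config; and timestamp_begin is assumed to be no_timestamps_token_id + 1. `getattr` with a default is used because it works on any attribute-style config object.

```python
from types import SimpleNamespace

# Hypothetical default: 1501 timestamp tokens (0.00 s to 30.00 s in 0.02 s steps)
DEFAULT_NUM_TIMESTAMP_TOKENS = 1501


def timestamp_range(generation_config):
    """Hypothetical helper returning (timestamp_begin, timestamp_end), end exclusive.

    Falls back to the default when the config has no `number_timestamp_tokens`.
    """
    timestamp_begin = generation_config.no_timestamps_token_id + 1
    num_tokens = getattr(generation_config, "number_timestamp_tokens",
                         DEFAULT_NUM_TIMESTAMP_TOKENS)
    return timestamp_begin, timestamp_begin + num_tokens


# Most users: no override, so the default applies
default_cfg = SimpleNamespace(no_timestamps_token_id=50363)
# Complex use cases: override the count (751 is an arbitrary illustrative value)
custom_cfg = SimpleNamespace(no_timestamps_token_id=50363, number_timestamp_tokens=751)

print(timestamp_range(default_cfg))  # (50364, 51865)
print(timestamp_range(custom_cfg))   # (50364, 51115)
```

Keeping the end exclusive lets the same value serve both as a slice bound and in a `< timestamp_end` comparison.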

One thing I'd be careful about is documenting this, probably in the return_timestamps description, and explaining the default value both there and in code comments.

How does it sound?

eustlb commented 6 hours ago

Sounds great, thanks for the elegant suggestion! I'll open a PR and reference it here; first I'm making sure the required changes are propagated correctly (e.g. it seems we also need to change WhisperTimeStampLogitsProcessor).
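A toy illustration of why logit masking would need the same end bound. This does not reproduce WhisperTimeStampLogitsProcessor; it assumes a processor that masks timestamp logits from timestamp_begin onward. `suppress_timestamps` is a hypothetical helper, and the vocabulary is shrunk to 10 ids, with timestamps assumed at ids 5 to 7 and two added tokens at ids 8 and 9.

```python
import math

TIMESTAMP_BEGIN = 5  # toy vocabulary: ids 0-4 are text, 5-7 are timestamps
TIMESTAMP_END = 8    # exclusive; ids 8-9 are newly added tokens


def suppress_timestamps(scores, begin, end=None):
    """Hypothetical helper: set logits in [begin, end) to -inf (end defaults to vocab size)."""
    end = len(scores) if end is None else end
    return [-math.inf if begin <= i < end else s for i, s in enumerate(scores)]


scores = [float(i) for i in range(10)]
open_ended = suppress_timestamps(scores, TIMESTAMP_BEGIN)
bounded = suppress_timestamps(scores, TIMESTAMP_BEGIN, TIMESTAMP_END)

print(open_ended[8:])  # [-inf, -inf]: added tokens wrongly suppressed
print(bounded[8:])     # [8.0, 9.0]: added tokens left untouched
```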