Open · txya900619 opened 3 weeks ago
(passing along to @sanchit-gandhi / @kamilakesbi , as this is a Whisper-specific question :) )
Okay, edited.
This seems related to #32378 as well! cc @ylacombe
Hey @txya900619, thanks for opening the issue! It'd be tremendously helpful if you could provide a fine-tuned version with some added tokens so that I can reproduce the error! Do you think it's possible?
Hey @txya900619,
I think you might be right about your diagnosis. However, another reason could be that the tokenizer simply skipped your tokens because of `skip_special_tokens=True`. I've detailed the tokenizer behaviour a bit here. Could you verify that with `skip_special_tokens=False` the added tokens appear as expected?
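A minimal sketch of how to check this (illustrative only; `processor` stands for your `WhisperProcessor` and `generated_ids` for the output of `model.generate`):

```python
# Decode the same generated ids with and without skipping special tokens,
# to see whether the added tokens are being dropped at decoding time.
with_specials = processor.batch_decode(generated_ids, skip_special_tokens=False)
without_specials = processor.batch_decode(generated_ids, skip_special_tokens=True)
print(with_specials)
print(without_specials)
```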
After some testing, I found that this problem only occurs when two new tokens are next to each other. Here is my testing code:
```python
from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    AdamW,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)

if __name__ == "__main__":
    data = load_dataset("openslr/librispeech_asr", split="test.clean", streaming=True)
    sample = next(iter(data))
    # text is "concord returned to its place amidst the tents"
    sample["text"] = sample["text"].lower()

    processor = WhisperProcessor.from_pretrained(
        "openai/whisper-tiny",
        language="en",
        task="transcribe",
    )
    model = WhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-tiny",
    )

    # add two new tokens and resize the embedding matrix accordingly
    processor.tokenizer.add_tokens(["concord", " returned"])
    model.resize_token_embeddings(len(processor.tokenizer))
    model.cuda()

    optimizer = AdamW(model.parameters(), lr=5e-6)

    print(sample["text"])
    print(processor.tokenizer.encode(sample["text"]))

    # overfit on the single sample so the model learns to emit the new tokens
    for i in tqdm(range(30)):
        inputs = processor(
            audio=sample["audio"]["array"], text=sample["text"], return_tensors="pt"
        )
        output = model(
            input_features=inputs["input_features"].cuda(),
            labels=inputs["labels"].cuda(),
        )
        loss = output["loss"]
        loss.backward()
        optimizer.step()

    result = model.generate(
        input_features=inputs["input_features"].cuda(),
        language="en",
    )
    print(result)
    print(processor.batch_decode(result))
```
I printed out the results and found that the ids matched the tokenizer's decoding, so it is not a tokenizer problem.

If you print `seek_sequences` before https://github.com/huggingface/transformers/blob/21fac7abba2a37fae86106f87fcf9974fd1e3830/src/transformers/models/whisper/generation_whisper.py#L694 you get:

[tensor([50258, 50259, 50359, 50363, 51865, 51866, 281, 1080, 1081, 30153, 372, 264, 39283], device='cuda:0')]

Then, after this for loop, printing `current_segments` gives:

[[{'start': tensor(-2.1200), 'end': tensor(30.0200), 'tokens': tensor([50258, 50259, 50359, 50363, 51865], device='cuda:0'), 'result': tensor([50258, 50259, 50359, 50363, 51865, 51866, 281, 1080, 1081, 30153, 372, 264, 39283, 50257], device='cuda:0')}]]

You can see that the tokens are cut off in the middle, between 51865 and 51866.
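For reference, a minimal sketch of why the sequence gets cut there. The value of `timestamp_begin` below (50364, the id of `<|0.00|>` in the multilingual checkpoints) is an assumption for illustration; the sequence is the one printed above:

```python
import torch

timestamp_begin = 50364  # id of <|0.00|> in the multilingual Whisper checkpoints (assumed)
seek_sequence = torch.tensor(
    [50258, 50259, 50359, 50363, 51865, 51866, 281, 1080, 1081, 30153, 372, 264, 39283]
)

# The segment-splitting logic essentially flags anything >= timestamp_begin as a
# timestamp. The two added tokens (ids 51865 and 51866) sit above the last real
# timestamp <|30.00|> (id 51864), so they are flagged too and the sequence is
# split right there.
timestamp_tokens = seek_sequence.ge(timestamp_begin)
print(timestamp_tokens.nonzero().flatten())  # tensor([4, 5]) -> the two added tokens
```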
Hey @txya900619, thanks for raising this issue!
What's happening here is that `_retrieve_segment` identifies timestamp tokens simply by checking that their ids are `> timestamp_begin` (see this line). However, tokens added to the vocabulary are appended at its end, so their ids also end up being `> timestamp_begin` and they are therefore treated as timestamp tokens. For this reason, the given example gets falsely split here and only the first segment is kept.
I see two ways of solving this issue:

1. Hard-code the number of timestamp tokens when identifying them:

   ```diff
   - timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
   + timestamp_tokens: torch.Tensor = (seek_sequence >= timestamp_begin) & (seek_sequence <= timestamp_begin + 1501)
   ```

2. Add `timestamp_begin_id` and `timestamp_end_id` to the generation config. This solution requires:
   - changing `_retrieve_segment`: it already takes `timestamp_begin_id`, it would also need to take `timestamp_end_id`;
   - changing `_set_return_timestamp` so that it returns `timestamp_begin` and `timestamp_end` from the generation_config.

I would rather go for solution 1. I tested it and it solves the issue! WDYT @ylacombe?
Hey @eustlb, thanks for taking the time to look into it and to propose some fixes!
Hard-coding is generally something we want to avoid, as it's the best way to hide the code's inner mechanism from users and make debugging even more complicated. This is especially true since Whisper is used quite a lot.
What I'd recommend is probably a mix of both solutions, since we can't actually modify the `generation_config.json` of all Whisper models. What we can do, though, is to impose a default value that users can override in their generation config.
Timestamp begin is computed here. We could probably do something like `timestamp_end = timestamp_begin + generation_config.get("number_timestamp_tokens", DEFAULT_VALUE)`, and then use this `timestamp_end` wherever it's relevant?
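A rough sketch of what that could look like. The attribute name `number_timestamp_tokens`, the helper names, and the default value are placeholders following the suggestion above, not an existing transformers API; it also assumes `timestamp_begin` is derived from `no_timestamps_token_id + 1`, as in the line linked above:

```python
import torch

# Placeholder default: 1501 timestamp tokens (<|0.00|> ... <|30.00|>) in the released checkpoints.
DEFAULT_NUM_TIMESTAMP_TOKENS = 1501


def get_timestamp_range(generation_config):
    """Return the inclusive [begin, end] id range of the timestamp tokens."""
    num_timestamps = getattr(
        generation_config, "number_timestamp_tokens", DEFAULT_NUM_TIMESTAMP_TOKENS
    )
    timestamp_begin = generation_config.no_timestamps_token_id + 1
    return timestamp_begin, timestamp_begin + num_timestamps - 1


def timestamp_token_mask(seek_sequence: torch.Tensor, begin: int, end: int) -> torch.Tensor:
    """Flag only ids inside the timestamp range, so tokens appended to the vocab are ignored."""
    return (seek_sequence >= begin) & (seek_sequence <= end)
```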
Something that I'd be careful about would be to document this somewhere, probably in the `return_timestamps` description, and to explain the default value both there and in comments.
How does it sound?
Sounds great! Thanks for the elegant suggestion. I am opening a PR that I'll ref here, just first making sure the required changes are propagated correctly (e.g. it seems that we need to also change `WhisperTimeStampLogitsProcessor`).
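For context, a toy illustration of why such components are affected (this is not the actual `WhisperTimeStampLogitsProcessor` code): any place that treats "everything from `timestamp_begin` to the end of the vocab" as timestamps will also pick up tokens appended to the vocabulary.

```python
import torch

vocab_size = 51867       # 51865 original ids + 2 added tokens
timestamp_begin = 50364  # <|0.00|>
timestamp_end = 51864    # <|30.00|>

scores = torch.randn(1, vocab_size)  # dummy logits for one decoding step

# Open-ended slice: intended to cover "all timestamp logits", but it also
# captures the two added tokens (ids 51865 and 51866).
open_ended = scores[:, timestamp_begin:]
# Bounded slice: covers exactly the 1501 timestamp tokens.
bounded = scores[:, timestamp_begin : timestamp_end + 1]

print(open_ended.shape, bounded.shape)  # torch.Size([1, 1503]) torch.Size([1, 1501])
```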
System Info

transformers version: 4.44.0

Who can help?

@sanchit-gandhi / @kamilakesbi

Information

Tasks

An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
I added some new tokens to the tokenizer, resized the word embeddings, and then fine-tuned on a custom dataset.
When I use the generate function on the fine-tuned model, I find that if the model's output contains more than one newly added token, the output is truncated.
I've found that this line of code treats the new tokens as timestamp_tokens, which causes the result to be truncated: https://github.com/huggingface/transformers/blob/d806fa3e92289876e01ab19c9e19e9264ea1c1a1/src/transformers/models/whisper/generation_whisper.py#L1759
Expected behavior
The generate function should output the full transcription.