aiola-lab / whisper-medusa

Whisper with Medusa heads
MIT License
800 stars 49 forks source link

Sample code not working #9

Closed milsun closed 3 months ago

milsun commented 3 months ago

model-00001-of-00002.safetensors: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.99G/4.99G [00:22<00:00, 12.4MB/s] model-00002-of-00002.safetensors: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.25G/1.25G [01:31<00:00, 13.7MB/s] Downloading shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [01:54<00:00, 57.35s/it] config.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.99k/1.99k [00:00<00:00, 10.0MB/s] model.safetensors: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6.17G/6.17G [07:33<00:00, 13.6MB/s] generation_config.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.29k/4.29k [00:00<00:00, 16.9MB/s] Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:05<00:00, 2.87s/it] generation_config.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.38k/4.38k [00:00<00:00, 5.19MB/s] preprocessor_config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 339/339 [00:00<00:00, 4.85MB/s] tokenizer_config.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 283k/283k [00:00<00:00, 1.53MB/s] vocab.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.04M/1.04M [00:00<00:00, 2.81MB/s] merges.txt: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 494k/494k [00:00<00:00, 1.19MB/s] normalizer.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 52.7k/52.7k [00:00<00:00, 54.9MB/s] added_tokens.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34.6k/34.6k [00:00<00:00, 38.1MB/s] special_tokens_map.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2.19k/2.19k [00:00<00:00, 8.77MB/s] Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. Traceback (most recent call last): File "/Users/milan/Desktop/ai/whisper-medusa/test.py", line 21, in input_features = processor(input_speech.squeeze(), return_tensors="pt", sampling_rate=SAMPLING_RATE).input_features ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/milan/miniconda3/envs/whisper-medusa/lib/python3.11/site-packages/transformers/models/whisper/processing_whisper.py", line 70, in call inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/milan/miniconda3/envs/whisper-medusa/lib/python3.11/site-packages/transformers/models/whisper/feature_extraction_whisper.py", line 270, in call input_features = padded_inputs.get("input_features").transpose(2, 0, 1) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ValueError: axes don't match array

rossudev commented 3 months ago

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. I got the same error. But that wasn't the issue. I fixed by ensuring the sample wav file was actually 16000 hz. Another error regarding ValueError: axes don't match array had to do with the mono/stereo aspect of the sample. Added an averaging step to get the sample down to mono. Works now for me.

...
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

input_speech, sr = torchaudio.load(path_to_audio)
print(input_speech.shape)
if input_speech.shape[0] > 1:  # If stereo, average the channels
    input_speech = input_speech.mean(dim=0, keepdim=True)
    input_speech = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(input_speech)

input_features = processor(input_speech.squeeze(), return_tensors="pt", sampling_rate=SAMPLING_RATE).input_features
...
AvivSham commented 3 months ago

Hi @milsun, Thank you for your interest in our work! We suspect your audio sample is stereo instead of mono. If that is the case you can average across this axis as @rossudev suggested (thanks!). If your sample is mono please share it so we can investigate it in depth.

AvivNavon commented 3 months ago

Thanks @rossudev, we added the channel average to the example usage code snippet