huggingface / distil-whisper

Distilled variant of Whisper for speech recognition. 6x faster, 50% smaller, within 1% word error rate.
MIT License

Running medium.en model on Jetson Xavier #30

Closed juansebashr closed 9 months ago

juansebashr commented 10 months ago

Hi! I was running the Colab code on a Jetson Xavier platform with CUDA 10.8 and a custom-compiled torch 1.8 (we can't update JetPack/CUDA right now due to other limitations). I managed to run the model on the GPU, but I got these transcriptions:

Downloading (…)lve/main/config.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2.26k/2.26k [00:00<00:00, 3.27MB/s]
Downloading model.safetensors: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 789M/789M [01:16<00:00, 10.2MB/s]
Downloading (…)neration_config.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.39k/1.39k [00:00<00:00, 1.12MB/s]
  0%|                                                                                                                                                                    | 0/73 [00:00<?, ?it/s]['AN OKout.. S s카� I s мен reckless gly Hretsor�� enosaurs.']
  1%|██▏                                                                                                                                                         | 1/73 [00:02<03:23,  2.83s/it][' sharing SAN. OKout..el었gram 14ire en mas.']
  3%|████▎                                                                                                                                                       | 2/73 [00:04<02:53,  2.44s/it][' difspakacally yourcompl line I s ad,if describe gче dramaticíp�ak,ict privamin O fra g That douurityys Cabor sщ.']
  4%|██████▍                                                                                                                                                     | 3/73 [00:07<02:46,  2.38s/it][' difist TaiwanSOno toward únicoio Horrorel over S Anaderavort, g G tutorial lroink maybe I drin T And rebell.']
  5%|████████▌                                                                                                                                                   | 4/73 [00:10<02:56,  2.56s/it][' Jetzt chel symbol H a certainly I age تو g Germany nickname, gượcel psychiatric apprentice Hith Yourith aand pairing Barcelona.AN.アk Next SCPel ropes presentedallyies most l s give YeacAN. You beenouöor vida enаст.odyAN.ruct academic love sell en vo.. a treball ہے re sely,�ch exam, very a consequently can l a Ehkl, ver pr,']
  7%|██████████▋                                                                                                                                                 | 5/73 [00:13<03:25,  3.02s/it][' had S neck="osakorerm "ittle accelerated going new sacrifices H,ittle основ l ble.']
  8%|████████████▊                                                                                                                                               | 6/73 [00:16<03:12,  2.87s/it]['� sentlyéd I getting,AN. OKout.. Russiaif helpful elastic HTML.']
 10%|██████████████▉                                                                                                                                             | 7/73 [00:18<02:48,  2.56s/it][' aproximadamentechspak S I a polximoramas gaks l getting S faz unings.']
 11%|█████████████████                                                                                                                                           | 8/73 [00:20<02:42,  2.50s/it][' placeosivingоia, out H Ixt clients, countries gد.']
 12%|███████████████████▏                                                                                                                                        | 9/73 [00:22<02:29,  2.34s/it][' dif that saladys unseen s Dongacist manurn sur баз getting g partlyiqu great symbol,�� aOver Everyoneor s long Jun g sulakac l s gen� I getting, dressed大家都ill s�� Ear here Worldus Police.']
 14%|█████████████████████▏                                                                                                                                     | 10/73 [00:25<02:40,  2.55s/it]['� s wond, g s can Max got accurate obswh Oس re s examples base.']
 15%|███████████████████████▎                                                                                                                                   | 11/73 [00:28<02:32,  2.46s/it]['oundings,ch Sو Aut reANП Michaelos So Leaderac ahow may 안돼 s gift pros I pr, g lidif Liebe meditation l •ät going controversac many怕 I раз']
 16%|█████████████████████████▍                                                                                                                                 | 12/73 [00:30<02:35,  2.54s/it][" 되,idi, enell over prob bitase '."]
 18%|███████████████████████████▌                                                                                                                               | 13/73 [00:33<02:27,  2.47s/it]['AN. OKout..istły en beginning,oschist secretossor onlyки saters engine I recy.']
 19%|█████████████████████████████▋                                                                                                                             | 14/73 [00:35<02:18,  2.35s/it][' con typically OKout.., for.A.']
 21%|███████████████████████████████▊                                                                                                                           | 15/73 [00:37<02:18,  2.38s/it][' currentlyadck desert opt I éléments, s Ukraine Useilanist project a oaby、 a Dark, к partyert B000 forgivefer aations degradation похож.']
 22%|█████████████████████████████████▉                                                                                                                         | 16/73 [00:40<02:23,  2.53s/it][" difistRA gRAos ', Mur 장bero, J then muror conseguirang s 15jin sdoor g theninate s 솔직if most plan."]
 23%|████████████████████████████████████                                                                                                                       | 17/73 [00:42<02:16,  2.43s/it][' Tome organizations a vara Vcess T collaborativeor isies.odyif yourch hasta near g swing s invade哦ith achieith also out then man twelve.']
 25%|██████████████████████████████████████▏                                                                                                                    | 18/73 [00:45<02:26,  2.66s/it][' C maybe imag then man creators,ink�ino surital g esc sley.']
 26%|████████████████████████████████████████▎                                                                                                                  | 19/73 [00:48<02:15,  2.50s/it][' C Direist surely aBy g really / Hcketosad.']
 27%|██████████████████████████████████████████▍                                                                                                                | 20/73 [00:50<02:05,  2.36s/it][' T ostzione��ittle rempeciallyor style tw conf,inkch people So aboutcause.']
 29%|████████████████████████████████████████████▌                                                                                                              | 21/73 [00:52<01:56,  2.24s/it][' T la下or Go really optionsor shoes,inkch people So It.']
 30%|██████████████████████████████████████████████▋                                                                                                            | 22/73 [00:53<01:47,  2.11s/it][' dif copied g 나�fore contributed, contrast somet Dire.']
 32%|████████████████████████████████████████████████▊                                                                                                          | 23/73 [00:56<01:49,  2.20s/it][' Tiritch уri overş��, �ert B000.']
 33%|██████████████████████████████████████████████████▉                                                                                                        | 24/73 [00:58<01:49,  2.23s/it][' dif уri overallyort.']
 34%|█████████████████████████████████████████████████████                                                                                                      | 25/73 [01:00<01:42,  2.14s/it]['oundings,ferel figch G about l newwordityithadyith und Cel knew, hab By Hcause genacro Clarkakor sit yeort nuclear.']
 36%|███████████████████████████████████████████████████████▏                                                                                                   | 26/73 [01:03<01:50,  2.35s/it]['row30,lyien� First culture.']
 37%|█████████████████████████████████████████████████████████▎                                                                                                 | 27/73 [01:05<01:40,  2.18s/it]['謝 Sther optionsting?']
 38%|███████████████████████████████████████████████████████████▍                                                                                               | 28/73 [01:07<01:39,  2.22s/it][' 무서ert B000 l samanке.']
 40%|█████████████████████████████████████████████████████████████▌                                                                                             | 29/73 [01:09<01:32,  2.10s/it]['謝 Sac?']
 41%|███████████████████████████████████████████████████████████████▋                                                                                           | 30/73 [01:10<01:22,  1.93s/it][' Cariке S l s gen기 h Georget, s Make lort undwordity, contrast simple culture.']
 42%|█████████████████████████████████████████████████████████████████▊                                                                                         | 31/73 [01:12<01:24,  2.00s/it][' simple Some Pascal.']
 44%|███████████████████████████████████████████████████████████████████▉                 

etc...

Does anyone have an idea of what's happening? And where should I start looking to fix it?

Thank you so much :D

juansebashr commented 10 months ago

FYI, I ran a comparison against the base OpenAI Whisper model, and that one works just fine.

sanchit-gandhi commented 10 months ago

Hey @juansebashr - could you share the end-to-end script you're using for this benchmark? At first glance, I would check that the tokenizer you're using is the correct one for this model, e.g. if you're benchmarking distil-whisper/distil-medium.en, make sure you're loading the tokenizer that corresponds to that checkpoint.
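
Something along these lines should do it (just a minimal sketch - swap in whichever checkpoint you're actually benchmarking):

from transformers import AutoProcessor

# load the processor (feature extractor + tokenizer) from the SAME checkpoint
# that you pass to AutoModelForSpeechSeq2Seq.from_pretrained
distil_processor = AutoProcessor.from_pretrained("distil-whisper/distil-medium.en")

# generate with the distilled model as usual, then decode with the matching tokenizer:
# pred_ids = distil_model.generate(**inputs)
# print(distil_processor.batch_decode(pred_ids, skip_special_tokens=True))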

juansebashr commented 10 months ago

Of course @sanchit-gandhi, here is the script. It's the same as the Colab:

# -*- coding: utf-8 -*-
"""Distil_Whisper_Benchmark.ipynb

## Benchmarking

Great, now that we've understood why Distil-Whisper should be faster in theory, let's see if it holds true in practice.

To begin with, we install `transformers`, `accelerate`, and `datasets`.

In this notebook, we use an A100 GPU that is available through a Colab Pro subscription, as this is the device we used for benchmarking in the [Distil-Whisper paper](https://huggingface.co/papers/2311.00430). Other GPUs will most likely lead to different speed-ups, but they should be in the same ballpark range:
"""

#!pip install --upgrade --quiet transformers accelerate datasets

"""In addition, we will make use of [Flash Attention 2](), as it saves
a lot of memory and speeds up large matmul operations.
"""

#!pip install --quiet flash-attn --no-build-isolation

"""To begin with, let's load the dataset that we will use for benchmarking. We'll load a small dataset consisting of 73 samples from the [LibriSpeech ASR](https://huggingface.co/datasets/librispeech_asr) validation-clean dataset. This amounts to ~9MB of data, so it's very lightweight and quick to download on device:"""

from datasets import load_dataset

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

"""We start by benchmarking [Whisper large-v2](https://huggingface.co/openai/whisper-large-v2) to get our baseline number. We'll load the model in `float16` precision and make sure that loading time takes as little time as possible by passing `low_cpu_mem_usage=True`. In addition, we want to make sure that the model is loaded in [`safetensors`](https://github.com/huggingface/safetensors) format by passing `use_safetensors=True`:"""

from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
import torch

device = "cuda:0"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "openai/whisper-base"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, use_flash_attention_2=False
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

"""Great! For the benchmark, we will only measure the generation time (encoder + decoder), so let's write a short helper function that measures this step:"""

import time

def generate_with_time(model, inputs):
    start_time = time.time()
    outputs = model.generate(**inputs)
    generation_time = time.time() - start_time
    return outputs, generation_time

"""This function will return both the decoded tokens as well as the time
it took to run the model.

We now iterate over the audio samples and sum up the generation time.
"""

from tqdm import tqdm

all_time = 0

for sample in tqdm(dataset):
  audio = sample["audio"]
  inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
  inputs = inputs.to(device=device, dtype=torch.float16)

  output, gen_time = generate_with_time(model, inputs)
  all_time += gen_time
  print(processor.batch_decode(output, skip_special_tokens=True))

print(all_time)

"""Alright! In total it took roughly 63 seconds to transcribe 73 audio samples.

Next, let's see how much time it takes with [Distil-Whisper](https://huggingface.co/distil-whisper/distil-large-v2):
"""

model_id = "distil-whisper/distil-medium.en"

distil_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, use_flash_attention_2=False
)
distil_model = distil_model.to(device)

"""We run the same benchmarking loop:"""

all_time = 0

for sample in tqdm(dataset):
  audio = sample["audio"]
  inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
  inputs = inputs.to(device=device, dtype=torch.float16)

  output, gen_time = generate_with_time(distil_model, inputs)
  all_time += gen_time
  print(processor.batch_decode(output, skip_special_tokens=True))

print(all_time)

"""Only 10 seconds - that amounts to a 6x speed-up!

## Memory

In addition to being significantly faster, Distil-Whisper also has fewer parameters. Let's have a look at how many fewer exactly.
"""

distil_model.num_parameters() / model.num_parameters() * 100

"""Distil-Whisper is 49% of the size of Whisper. Note that this ratio is much lower if we would just compare the size of the decoder:"""

distil_model.model.decoder.num_parameters() / model.model.decoder.num_parameters() * 100

"""As expected the decoder is much smaller. One might have guessed that it should even be less, around 2/32 (or 6%), but we can't forget that the decoder has a very large word embedding that requires a lot of parameters.

## Next steps

Hopefully this notebook shed some light on the motivation behind Distil-Whisper! For now, we've measured Distil-Whisper mainly on GPU, but we are actively looking into collaborations to release code showing how to effectively accelerate Distil-Whisper on CPU as well. Updates will be posted on the Distil-Whisper [repository](https://github.com/huggingface/distil-whisper).

Another key application of Distil-Whisper is *speculative decoding*. In speculative decoding, we can use Distil-Whisper as an *assistant model* to Whisper-large-v2 to reach a speed-up of 2x without **any** loss in performance. More on that in a follow-up notebook that will come out soon!
"""
sanchit-gandhi commented 10 months ago

Yes, indeed it's the tokenizer that's the issue - the checkpoint openai/whisper-base uses a different tokenizer from distil-whisper/distil-medium.en. You need to load the tokenizer for distil-whisper/distil-medium.en to decode the generated ids from the distilled model. See the diff below; the line prefixed with + is the additional line you need to add:

model_id = "distil-whisper/distil-medium.en"

distil_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, use_flash_attention_2=False
)
distil_model = distil_model.to(device)

+ processor = AutoProcessor.from_pretrained(model_id)
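+ # reloading the processor from the distil-medium.en checkpoint means the batch_decode call below uses the tokenizer that matches the distilled model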

"""We run the same benchmarking loop:"""

all_time = 0

for sample in tqdm(dataset):
  audio = sample["audio"]
  inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
  inputs = inputs.to(device=device, dtype=torch.float16)

  output, gen_time = generate_with_time(distil_model, inputs)
  all_time += gen_time
  print(processor.batch_decode(output, skip_special_tokens=True))

print(all_time)

"""Only 10 seconds - that amounts to a 6x speed-up!

## Memory

In addition to being significantly faster, Distil-Whisper also has fewer parameters. Let's have a look at how many fewer exactly.
"""

distil_model.num_parameters() / model.num_parameters() * 100

"""Distil-Whisper is 49% of the size of Whisper. Note that this ratio is much lower if we would just compare the size of the decoder:"""

distil_model.model.decoder.num_parameters() / model.model.decoder.num_parameters() * 100

"""As expected the decoder is much smaller. One might have guessed that it should even be less, around 2/32 (or 6%), but we can't forget that the decoder has a very large word embedding that requires a lot of parameters.

## Next steps

Hopefully this notebook shed some light on the motivation behind Distil-Whisper! For now, we've measured Distil-Whisper mainly on GPU, but are now actively looking into collaborating to release code how to effectively accelerate Distil-Whisper on CPU as well. Updates will be posted on the Distil-Whisper [repository](https://github.com/huggingface/distil-whisper).

Another key application of Distil-Whisper is *speculative decoding*. In speculative decoding, we can use Distil-Whisper as an *assitant model* to Whisper-large-v2 to reach a speed-up of 2x without **any** loss in performance. More on that in a follow-up notebook that will come out soon!
"""
juansebashr commented 10 months ago

Yeah! It worked like a charm, thank you! PS: It might be a good idea to modify the Colab or put a warning in the markdown cells.

sanchit-gandhi commented 9 months ago

Excellent - glad to hear that @juansebashr! Closing as complete - feel free to open a new issue if you encounter any other problems!