huggingface / distil-whisper

Distilled variant of Whisper for speech recognition. 6x faster, 50% smaller, within 1% word error rate.
MIT License

small uses more memory (but is faster) than medium (ONNX quantized) #53

kmn1024 closed this issue 6 months ago.

kmn1024 commented 6 months ago

Setup

CUDA 12.2, GTX 1080. I copied all the ONNX quantized models and the required config JSON files to their expected locations.

Code

import time
import wave

import numpy as np
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
from py3nvml.py3nvml import (
    nvmlDeviceGetHandleByIndex,
    nvmlDeviceGetMemoryInfo,
    nvmlInit,
    nvmlShutdown,
)
from transformers import WhisperFeatureExtractor, WhisperTokenizerFast, pipeline

# Initialize NVML
nvmlInit()
# Get the first GPU handle
handle = nvmlDeviceGetHandleByIndex(0)
# Get initial GPU memory usage
info = nvmlDeviceGetMemoryInfo(handle)
initial_gpu_memory = info.used
device = "cuda:0"

# Load models
model_name = '/home/Downloads/asr_models/distil-small-en/onnx_quantized' # Copied all needed files
model = ORTModelForSpeechSeq2Seq.from_pretrained(model_name, export=False, use_safetensors=True)
model.to(device)
tokenizer = WhisperTokenizerFast.from_pretrained(model_name)
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)
gpu_pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    max_new_tokens=128,
    device=0,
)

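# Measurements: real-time factor (RTF) = processing time / audio duration.
wav_file = "sample.wav"  # hypothetical path to the test recording (not given in the original)
with wave.open(wav_file, "rb") as wf:  # assumes a PCM WAV file
    wav_secs = wf.getnframes() / wf.getframerate()

rtfs = []
for _ in range(15):
    start = time.perf_counter()
    gpu_pipe(wav_file)
    rtfs.append((time.perf_counter() - start) / wav_secs)
rtfs.sort()
rtfs = rtfs[1:-1]  # drop the fastest and slowest run
print(np.mean(rtfs))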

info = nvmlDeviceGetMemoryInfo(handle)
final_gpu_memory = info.used
gpu_memory_used = (final_gpu_memory - initial_gpu_memory) / 1024 / 1024
print(f"GPU memory used by the code block: {gpu_memory_used} MiB")
nvmlShutdown()

Results

Model              RTF                   GPU memory used
distil-medium-en   0.5082902994641076    1878.0 MiB
distil-small-en    0.3903530106270162    2884.0 MiB
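The RTF column comes from the timing loop in the code: each run's wall-clock time is divided by the audio length (values below 1 mean faster than real time), the fastest and slowest of the 15 runs are dropped, and the rest are averaged. A minimal, framework-free sketch of that reduction; the function name `trimmed_mean_rtf` is hypothetical, not part of any library, and the timings in the usage line are made up:

```python
from statistics import mean


def trimmed_mean_rtf(elapsed_secs: list[float], audio_secs: float) -> float:
    """Mean real-time factor after dropping the fastest and slowest run.

    Hypothetical helper mirroring the benchmark loop; needs at least 3 runs.
    """
    if audio_secs <= 0:
        raise ValueError("audio_secs must be positive")
    if len(elapsed_secs) < 3:
        raise ValueError("need at least 3 runs to trim the min and max")
    # RTF = processing time / audio duration; < 1.0 is faster than real time.
    rtfs = sorted(t / audio_secs for t in elapsed_secs)
    # Drop one outlier from each end (e.g. a slow warm-up run), then average.
    return mean(rtfs[1:-1])


# Usage with made-up timings for a 10-second clip.
print(trimmed_mean_rtf([9.0, 4.0, 3.8, 4.2, 1.0], audio_secs=10.0))
```

Trimming both ends keeps a single warm-up run (CUDA/ONNX Runtime initialisation) from skewing the average.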
kmn1024 commented 6 months ago

Could this be because:

  1. distil-small.en defines alignment heads (https://huggingface.co/distil-whisper/distil-small.en/blob/main/generation_config.json#L2), but distil-medium.en does not (as of this writing; https://huggingface.co/distil-whisper/distil-medium.en/blob/main/generation_config.json)?
  2. distil-small.en uses 4 decoder layers (https://huggingface.co/distil-whisper/distil-small.en/blob/main/config.json#L20), while distil-medium.en uses only 2 (https://huggingface.co/distil-whisper/distil-medium.en/blob/main/config.json#L20)?
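Both hypotheses can be checked against the two JSON files linked above. A minimal sketch, assuming the files sit in a local checkpoint directory: `encoder_layers` and `decoder_layers` (in `config.json`) and `alignment_heads` (in `generation_config.json`) are the keys Whisper checkpoints use, while the helper names are hypothetical:

```python
import json
from pathlib import Path


def summarize_whisper_config(config: dict, generation_config: dict) -> dict:
    """Pull out the fields relevant to the two hypotheses."""
    return {
        "encoder_layers": config.get("encoder_layers"),
        "decoder_layers": config.get("decoder_layers"),
        # alignment_heads is only needed for word-level timestamps
        "has_alignment_heads": bool(generation_config.get("alignment_heads")),
    }


def load_and_summarize(model_dir: str) -> dict:
    """Read config.json and (if present) generation_config.json from model_dir."""
    root = Path(model_dir)
    config = json.loads((root / "config.json").read_text())
    gen_file = root / "generation_config.json"
    gen_config = json.loads(gen_file.read_text()) if gen_file.exists() else {}
    return summarize_whisper_config(config, gen_config)
```

Calling `load_and_summarize` on each model folder (for example `/home/Downloads/asr_models/distil-small-en/onnx_quantized`, where the config files were copied) puts the two checkpoints side by side.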
sanchit-gandhi commented 6 months ago

Hey @kmn1024! Thanks for opening this super interesting issue.

  1. It's unlikely to be the alignment heads, since these are only used for word-level timestamps (which the example above doesn't use).
  2. It could be related to the number of decoder layers, but distil-medium.en uses 24 encoder layers, as opposed to just 12 in distil-small.en, so I'd expect the larger memory overhead to come from the encoder side of distil-medium.en.
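The encoder-side argument can be made concrete with a back-of-envelope weight count. A standard transformer layer with hidden size d and a 4d feed-forward block has about 4d² attention weights (query, key, value, output) plus 8d² feed-forward weights, roughly 12d² per encoder layer, ignoring biases, layer norms, embeddings and the convolutional front end. The hidden sizes used below (768 for small, 1024 for medium) are assumed from the standard Whisper small/medium configurations, not stated in this thread, and the helper name is hypothetical:

```python
def approx_encoder_params(num_layers: int, d_model: int, ffn_mult: int = 4) -> int:
    """Rough weight count for a stack of transformer encoder layers.

    Per layer: 4 * d^2 for the Q/K/V/output projections plus
    2 * ffn_mult * d^2 for the two feed-forward matrices. Biases,
    layer norms, embeddings and the conv stem are ignored.
    """
    per_layer = 4 * d_model**2 + 2 * ffn_mult * d_model**2
    return num_layers * per_layer


# Layer counts from the comment above; hidden sizes are assumptions.
small = approx_encoder_params(num_layers=12, d_model=768)
medium = approx_encoder_params(num_layers=24, d_model=1024)
print(f"small encoder:  ~{small / 1e6:.0f}M weights")
print(f"medium encoder: ~{medium / 1e6:.0f}M weights")
```

Under these assumptions the distil-medium.en encoder holds roughly 3.5 times as many weights as the distil-small.en encoder.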

We decided not to release VRAM (memory) numbers in our benchmarks, since they're very dependent on hardware, CUDA version and PyTorch version. But we record some of these numbers ourselves. In my provisional benchmark, averaging over 100 samples of the LibriSpeech dataset, I got:

  1. distil-small.en: 1.95 GB
  2. distil-medium.en: 2.79 GB

So, quite convincingly, distil-small.en used less memory than distil-medium.en.

One reason for the higher memory could be that distil-small.en runs more decoding steps than distil-medium.en on your sample, possibly because of a hallucination. That would increase the memory used by the k/v cache. It could be a good idea to average over a number of different samples, e.g. as follows:

from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
from datasets import load_dataset
from evaluate import load
import torch
from tqdm import tqdm

# define our torch configuration
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "distil-whisper/distil-small.en"

# load the model + processor
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, use_safetensors=True, low_cpu_mem_usage=True
)
model = model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

# load a small dummy LibriSpeech validation split (fully downloaded, not streamed)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

# define the evaluation metric
wer_metric = load("wer")
normalizer = EnglishTextNormalizer(processor.tokenizer.english_spelling_normalizer)

def inference(batch):
    # 1. Pre-process the audio data to log-mel spectrogram inputs
    audio = [sample["array"] for sample in batch["audio"]]
    input_features = processor(audio, sampling_rate=batch["audio"][0]["sampling_rate"], return_tensors="pt").input_features
    input_features = input_features.to(device, dtype=torch_dtype)

    # 2. Auto-regressively generate the predicted token ids
    pred_ids = model.generate(input_features, max_new_tokens=128)

    # 3. Decode the token ids to the final transcription
    batch["transcription"] = processor.batch_decode(pred_ids, skip_special_tokens=True)
    batch["reference"] = batch["text"]
    return batch

dataset = dataset.map(function=inference, batched=True, batch_size=16)

all_transcriptions = []
all_references = []

# iterate over the dataset and run inference
for result in tqdm(dataset, desc="Evaluating..."):
    all_transcriptions.append(result["transcription"])
    all_references.append(result["reference"])

# normalize predictions and references
all_transcriptions = [normalizer(transcription) for transcription in all_transcriptions]
all_references = [normalizer(reference) for reference in all_references]

# compute the WER metric
wer = 100 * wer_metric.compute(predictions=all_transcriptions, references=all_references)
print(wer)
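To gauge how much extra decoding steps can cost, the k/v cache mentioned above can be estimated by hand. A sketch under stated assumptions: each decoder layer caches one key and one value tensor of width d_model for every generated token (self-attention) and for every encoder frame (cross-attention); Whisper's encoder emits 1500 frames per 30-second window; the decoder-layer counts (4 vs 2) come from the question earlier in this thread, while the hidden sizes (768 and 1024) are assumed from the standard Whisper small/medium configurations. The helper name is hypothetical:

```python
def kv_cache_bytes(
    decoder_layers: int,
    d_model: int,
    generated_tokens: int,
    encoder_frames: int = 1500,
    batch_size: int = 1,
    bytes_per_value: int = 2,
) -> int:
    """Approximate size of the decoder key/value cache in bytes.

    bytes_per_value=2 assumes fp16 activations; use 4 for fp32.
    """
    # One key and one value tensor per layer, each [batch, length, d_model].
    per_position = 2 * decoder_layers * batch_size * d_model * bytes_per_value
    self_attn = per_position * generated_tokens  # grows with every decoding step
    cross_attn = per_position * encoder_frames  # fixed once the encoder has run
    return self_attn + cross_attn


for name, layers, width in [("distil-small.en", 4, 768), ("distil-medium.en", 2, 1024)]:
    mib = kv_cache_bytes(layers, width, generated_tokens=128) / 2**20
    print(f"{name}: ~{mib:.1f} MiB at 128 tokens")
```

Only the self-attention term depends on the number of generated tokens, so a hallucinated transcript grows the cache linearly with its length, and `max_new_tokens=128` in both scripts caps that growth. Allocator overhead is not included.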
xenova commented 6 months ago

Hi @kmn1024 👋 I did the conversions to ONNX, so I might have an explanation for this. I believe it's due to the additional output nodes, which return the computed attentions. I exported with these outputs so that users can generate word-level timestamps with these models (this might not have been the case for the earlier medium models).

If you don't need word-level timestamps, you can do the conversion yourself with Optimum:

optimum-cli export onnx -m distil-whisper/distil-small.en output
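Whether a particular export carries these extra outputs can be checked by listing its graph outputs. A minimal sketch using the `onnx` package, where `onnx.load(..., load_external_data=False)` reads the graph without pulling in external weight files. The helper names are hypothetical, and filtering on the substring "attention" is an assumption about how the attention outputs are named, not something this thread confirms:

```python
def list_onnx_outputs(onnx_path: str) -> list[str]:
    """Return the graph output names of an ONNX file (external weights not loaded)."""
    import onnx  # imported lazily; requires the `onnx` package

    model = onnx.load(onnx_path, load_external_data=False)
    return [output.name for output in model.graph.output]


def attention_outputs(output_names: list[str]) -> list[str]:
    """Keep outputs whose name mentions attention (naming convention assumed)."""
    return [name for name in output_names if "attention" in name.lower()]
```

Running both helpers on the decoder `.onnx` file in the downloaded `onnx_quantized` folder and on the matching file in the `output` folder produced by the command above should show whether the two exports differ in their attention outputs.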
kmn1024 commented 6 months ago
Thanks all! I can confirm that converting and quantizing from scratch works. The numbers are now:

Model              RTF                   GPU memory used
distil-medium-en   0.5082902994641076    1878.0 MiB
distil-small-en    0.3782055584150302    912.6875 MiB

P.S. The Optimum quantization command doesn't work out of the box; I had to skip the Conv nodes, as suggested in https://github.com/microsoft/onnxruntime/issues/15888.
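The exact quantization command used here isn't shown; the thread only says that Conv nodes had to be skipped. One way to do that is ONNX Runtime's `quantize_dynamic` with `op_types_to_quantize` restricted to MatMul, which leaves Conv nodes (Whisper's audio front end) in floating point. This is a sketch, not the author's method; the helper name and directory names are hypothetical, and the tokenizer/config JSON files still need to be copied next to the quantized models:

```python
from pathlib import Path


def quantize_skipping_conv(src_dir: str, dst_dir: str) -> list[str]:
    """Dynamically quantize every .onnx file in src_dir to int8, MatMul nodes only.

    Restricting op_types_to_quantize leaves Conv nodes in floating point.
    """
    # Imported lazily so this module loads without onnxruntime installed.
    from onnxruntime.quantization import QuantType, quantize_dynamic

    out_dir = Path(dst_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    written = []
    for model_path in sorted(Path(src_dir).glob("*.onnx")):
        target = out_dir / model_path.name
        quantize_dynamic(
            model_input=str(model_path),
            model_output=str(target),
            op_types_to_quantize=["MatMul"],  # assumption: quantize MatMul only
            weight_type=QuantType.QInt8,
        )
        written.append(str(target))
    return written
```

For example, `quantize_skipping_conv("output", "output_quantized")` would process the files written by the `optimum-cli export onnx` command earlier in the thread.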

sanchit-gandhi commented 6 months ago

Thanks for the great explanation @xenova!