Vaibhavs10 / insanely-fast-whisper

Apache License 2.0

Too high vRAM usage #191

Status: Open. Majdoddin opened this issue 9 months ago.

Majdoddin commented 9 months ago

On virtualized instances from Vast.ai (without CLI), running the code below on a 10 s audio clip uses all of the VRAM on an A100 40GB (runtime ~3 s) and about 51 GB of VRAM on an A100 80GB (runtime ~3.5 s).
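For scale, the audio input itself is small. The arithmetic below is a sketch that assumes Whisper's standard preprocessing (16 kHz mono audio padded to a 30 s window, a 160-sample hop between mel frames, and 128 mel bins for large-v3); none of these constants come from the report itself.

```python
# Back-of-the-envelope size of the audio input for one Whisper forward pass.
# Assumptions: 16 kHz mono input, a 30 s Whisper window, a 10 ms hop
# (160 samples) between mel frames, and 128 mel bins (the large-v3 setting).
SAMPLE_RATE = 16_000   # samples per second expected by Whisper
CLIP_SECONDS = 10      # length of the test clip in the report
WINDOW_SECONDS = 30    # every input is padded (or cut) to a 30 s window
HOP_LENGTH = 160       # samples between consecutive mel frames
N_MELS = 128           # mel bins used by whisper-large-v3
BYTES_FP16 = 2

clip_samples = SAMPLE_RATE * CLIP_SECONDS         # raw samples in the clip
window_samples = SAMPLE_RATE * WINDOW_SECONDS     # samples after padding
mel_frames = window_samples // HOP_LENGTH         # frames fed to the encoder
feature_bytes = N_MELS * mel_frames * BYTES_FP16  # fp16 log-mel tensor size

print(f"raw clip: {clip_samples} samples")
print(f"padded window: {window_samples} samples -> {mel_frames} mel frames")
print(f"input features: {feature_bytes / 1024:.0f} KiB in fp16")
```

Because every input is padded to the 30 s window, a 10 s clip gives the encoder the same input size as a 30 s clip, and that input is well under a megabyte.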

This VRAM usage seems far too high. Am I missing some setting? (`is_flash_attn_2_available()` returns `False`, if it matters, so the pipeline falls back to `"attn_implementation": "sdpa"`.)
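A quick estimate helps quantify "too high". This sketch assumes roughly 1.55 billion parameters for whisper-large-v3 (the 1550M size listed for the large Whisper checkpoints) and counts only the weights, so it is a lower bound rather than a prediction of peak usage.

```python
# Rough lower bound on the VRAM needed just to hold the model weights.
# Assumption: whisper-large-v3 has about 1.55 billion parameters,
# the "1550M" size listed for the large Whisper checkpoints.
PARAMS = 1.55e9
BYTES_PER_PARAM = {"float32": 4, "float16": 2}


def weight_gb(n_params: float, dtype: str) -> float:
    """Weights-only memory in GB (1 GB = 1e9 bytes).

    Excludes activations, the decoder's key/value cache and the CUDA context.
    """
    return n_params * BYTES_PER_PARAM[dtype] / 1e9


for dtype in ("float32", "float16"):
    print(f"{dtype} weights: {weight_gb(PARAMS, dtype):.1f} GB")

OBSERVED_GB = 51  # VRAM reported above on the A100 80GB
ratio = OBSERVED_GB / weight_gb(PARAMS, "float16")
print(f"observed usage is ~{ratio:.0f}x the fp16 weights")
```

With `torch_dtype=torch.float16`, as in the code below, the weights come to roughly 3 GB, so most of the reported 51 GB is something other than weights. Activations, the key/value cache, memory kept by PyTorch's caching allocator, and the CUDA context are the usual places to look.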

See also #192.

My code:

import time
from subprocess import CalledProcessError, run

import numpy as np
import torch
from transformers import pipeline
from transformers.utils import is_flash_attn_2_available

SAMPLE_RATE = 16000  # Whisper models expect 16 kHz mono audio

def load_my_audio(file: str, sr: int = SAMPLE_RATE):
    # Taken from the Whisper code base: loads the audio file into a NumPy array sampled at `sr` Hz.
    ...

device = "cuda:0" if torch.cuda.is_available() else 'cpu'

pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3", # select checkpoint from https://huggingface.co/openai/whisper-large-v3#model-details
    torch_dtype=torch.float16,
    device=device, # or mps for Mac devices
    model_kwargs={"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"},
)
np_array = load_my_audio("10s.wav")

generate_kwargs = {
    #'num_beams': 1,
    "language": 'it',
    "do_sample": False,
    #"temperature": 0,
    #"repetition_penalty": 3.0,
    # "condition_on_previous_text": True,
    #"task": "transcribe"
}

start = time.perf_counter()
outputs = pipe(
    # "10s.wav",
    np_array,
    chunk_length_s=30,
    batch_size=1,
    return_timestamps=False,
    generate_kwargs=generate_kwargs,
)
duration = time.perf_counter() - start
print(f"Transcription took {duration:.2f} s")
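Two things may make a single timed call like the one above misleading: the first call can include one-off CUDA and kernel initialisation, and dashboards such as nvidia-smi report memory that PyTorch's caching allocator has reserved (plus the CUDA context), not only memory held by live tensors. The helper below is a minimal sketch assuming PyTorch is installed; `measure` is a hypothetical name, not part of transformers or insanely-fast-whisper. It uses `torch.cuda.reset_peak_memory_stats`, `max_memory_allocated` and `max_memory_reserved` to separate the two kinds of memory.

```python
import statistics
import time

try:
    import torch
except ImportError:  # torch missing: timings still work, VRAM figures are None
    torch = None


def measure(fn, warmup: int = 1, repeats: int = 5):
    """Time fn() and, on CUDA, record peak GPU memory across all calls.

    Runs `warmup` untimed calls first (to absorb one-off initialisation),
    then `repeats` timed calls. Returns a tuple of
    (median seconds, peak allocated GiB, peak reserved GiB); the two memory
    figures are None when CUDA is unavailable.
    """
    cuda = torch is not None and torch.cuda.is_available()
    if cuda:
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()  # peak restarts from current usage

    for _ in range(warmup):
        fn()

    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        if cuda:
            torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
        times.append(time.perf_counter() - start)

    median_s = statistics.median(times)
    if not cuda:
        return median_s, None, None
    gib = 1024 ** 3
    allocated = torch.cuda.max_memory_allocated() / gib  # held by live tensors
    reserved = torch.cuda.max_memory_reserved() / gib    # held by the caching allocator
    return median_s, allocated, reserved


# Hypothetical usage with `pipe`, `np_array` and `generate_kwargs` from the code above:
# seconds, allocated, reserved = measure(
#     lambda: pipe(np_array, chunk_length_s=30, batch_size=1,
#                  return_timestamps=False, generate_kwargs=generate_kwargs)
# )
```

If the peak allocated figure stays close to the size of the weights while the reserved figure (or what nvidia-smi or the Vast.ai dashboard shows) is much larger, the difference is memory the allocator is caching rather than memory live tensors are using. Neither PyTorch figure includes the CUDA context, which nvidia-smi does count.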