shashikg / WhisperS2T

An Optimized Speech-to-Text Pipeline for the Whisper Model Supporting Multiple Inference Engines
MIT License

Add LoRA dynamic switching for inference #71

Open Jeevi10 opened 3 months ago

Jeevi10 commented 3 months ago

Add dynamic LoRA (Low-Rank Adaptation) switching functionality, allowing users to change LoRA adapters on the fly during inference without reloading the entire model.

StephennFernandes commented 3 months ago

@Jeevi10 Hey, can you link some resources on dynamic LoRA specifically for Whisper, mainly how this type of inference works and how to use LoRA to fine-tune Whisper?

Jeevi10 commented 2 months ago

@StephennFernandes Thank you for your reply.

Resources for dynamic LoRA:

https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/enc_dec#run-bart-with-lora
https://github.com/cccntu/minLoRA/tree/main
https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama#run-llama-with-several-lora-checkpoints
https://github.com/S-LoRA/S-LoRA

These are some example repos where I got the idea from. Unfortunately, I don't see any implementations specifically for Whisper.

Just to give you an idea, I created a running example using Hugging Face transformers and peft:

```python
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
from peft import PeftModel
import torch_tensorrt

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "distil-whisper/distil-large-v3"

base_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
)
base_model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

# Load the base model with the first adapter, then attach a second adapter
peft_model_id = "path to checkpoint adapter 1"
peft_model_id_2 = "path to checkpoint adapter2"
model = PeftModel.from_pretrained(base_model, peft_model_id, adapter_name='adapter 1', device_map="auto")
model.load_adapter(peft_model_id_2, adapter_name='adapter 2')

# Enable static cache and compile the forward pass
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
model = torch.compile(model, backend="torch_tensorrt", dynamic=False)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch.float16,
    device=f"cuda:{0}",
    model_kwargs={"attn_implementation": "flash_attention_2"},
)

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

def iterate_data(dataset):
    for i, item in enumerate(dataset):
        yield item["audio"]

# Set the batch size in accordance with your device
BATCH_SIZE = 16

predictions = []

# Run streamed inference with adapter 1
for out in pipe(iterate_data(dataset), batch_size=BATCH_SIZE):
    predictions.append(out["text"])

print(predictions)

# Switch to adapter 2 on the fly, without reloading the base model
pipe.model.set_adapter('adapter 2')

# Run streamed inference with adapter 2
for out in pipe(iterate_data(dataset), batch_size=BATCH_SIZE):
    predictions.append(out["text"])

print(predictions)
```
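The key point is that `load_adapter` and `set_adapter` only swap the small LoRA weight matrices: the Whisper base model stays resident on the GPU and nothing is reloaded between adapters. One thing to watch for is that, with the forward pass compiled and a static cache enabled, switching the active adapter may trigger a recompilation on the first batch after the switch.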

Whisper fine-tuning with LoRA:

https://github.com/Vaibhavs10/fast-whisper-finetuning
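For context, the core of that approach is standard peft usage: wrap the Whisper model with a `LoraConfig` that targets the attention projections and train only the adapter weights. Here is a minimal sketch; the rank, alpha, and target modules are illustrative assumptions, not values taken from that repo:

```python
# Minimal LoRA fine-tuning setup for Whisper with peft (assumed hyperparameters).
# Training itself would then run with a standard Seq2SeqTrainer.
from transformers import WhisperForConditionalGeneration
from peft import LoraConfig, get_peft_model

base = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2")

lora_config = LoraConfig(
    r=32,                                 # assumed LoRA rank
    lora_alpha=64,                        # assumed scaling factor
    target_modules=["q_proj", "v_proj"],  # attention projections inside Whisper
    lora_dropout=0.05,
    bias="none",
)

# Base weights are frozen; only the injected low-rank adapters are trainable,
# so the resulting checkpoint is small enough to hot-swap at inference time.
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()
```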

StephennFernandes commented 2 months ago

@Jeevi10 thanks for the heads up.

I'll try to write an update for WhisperS2T so it can use dynamic adapters.
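No promises on the final shape, but roughly what I have in mind, building on the existing HuggingFace backend. The adapter-related methods below are hypothetical and not part of the current WhisperS2T API:

```python
import whisper_s2t

# Existing entry point (HuggingFace backend).
model = whisper_s2t.load_model(
    model_identifier="distil-whisper/distil-large-v3",
    backend="HuggingFace",
)

# Hypothetical additions for dynamic adapters -- not implemented yet:
# model.load_adapter("path/to/adapter1", adapter_name="adapter1")
# model.load_adapter("path/to/adapter2", adapter_name="adapter2")
# model.set_adapter("adapter1")

out = model.transcribe_with_vad(
    ["sample.wav"],
    lang_codes=["en"],
    tasks=["transcribe"],
    initial_prompts=[None],
    batch_size=16,
)
```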

Jeevi10 commented 2 months ago

@StephennFernandes I am looking forward to it.