Jeevi10 opened this issue 3 months ago
@Jeevi10 Hey, can you link some resources on dynamic LoRA specifically for Whisper, mainly how this type of inference works and how to use LoRA to fine-tune Whisper?
@StephennFernandes Thank you for your reply.
- https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/enc_dec#run-bart-with-lora
- https://github.com/cccntu/minLoRA/tree/main
- https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama#run-llama-with-several-lora-checkpoints
- https://github.com/S-LoRA/S-LoRA
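On the fine-tuning side of the question, a minimal sketch with Hugging Face PEFT could look like the following; the rank, alpha, target modules, and output path are illustrative assumptions, not settings from this thread:

```python
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from peft import LoraConfig, get_peft_model

model_id = "distil-whisper/distil-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

# Attach LoRA to the attention projections; r/alpha values here are illustrative.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # common targets in Whisper's attention blocks
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the low-rank matrices are trainable

# ... train as usual (e.g. Seq2SeqTrainer on log-mel features and tokenized transcripts) ...

# save_pretrained writes only the adapter weights, which
# PeftModel.from_pretrained / load_adapter can attach to a base model later.
model.save_pretrained("whisper-lora-adapter-1")
```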
Here is the inference setup I am experimenting with, loading two LoRA adapters on one base model and switching between them:

```python
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
from peft import PeftModel
import torch_tensorrt  # registers the "torch_tensorrt" backend for torch.compile

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "distil-whisper/distil-large-v3"

base_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
)
base_model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

# Attach two LoRA adapters to the same base model.
peft_model_id = "path to checkpoint adapter 1"
peft_model_id_2 = "path to checkpoint adapter 2"
model = PeftModel.from_pretrained(base_model, peft_model_id, adapter_name="adapter_1")
model.load_adapter(peft_model_id_2, adapter_name="adapter_2")

# Static KV cache so torch.compile can trace generation with fixed shapes.
model.generation_config.cache_implementation = "static"
model = torch.compile(model, backend="torch_tensorrt", dynamic=False)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
    model_kwargs={"attn_implementation": "flash_attention_2"},
)

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

def iterate_data(dataset):
    for item in dataset:
        yield item["audio"]

BATCH_SIZE = 16
predictions = []

# First pass uses the adapter that is active after loading ("adapter_1").
for out in pipe(iterate_data(dataset), batch_size=BATCH_SIZE):
    predictions.append(out["text"])

print(predictions)

# Switch adapters in place; the base model stays loaded on the GPU.
pipe.model.set_adapter("adapter_2")

for out in pipe(iterate_data(dataset), batch_size=BATCH_SIZE):
    predictions.append(out["text"])

print(predictions)
```
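One caveat worth noting: after torch.compile, a set_adapter call changes which weights are active, which may trigger recompilation on the next forward pass. For per-request routing, the switch can also be wrapped in a small helper; transcribe_with_adapter below is an illustrative name, not an existing API:

```python
# Illustrative helper: pick the adapter per request, then transcribe.
# set_adapter is the standard PEFT call; the wrapper itself is hypothetical.
def transcribe_with_adapter(pipe, audio, adapter_name):
    pipe.model.set_adapter(adapter_name)  # swap the active LoRA weights in place
    return pipe(audio)["text"]

# usage, assuming the pipeline and adapters defined above:
# text = transcribe_with_adapter(pipe, sample["audio"], "adapter_2")
```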
@Jeevi10 thanks for the heads up.
I'll try to write an update for WhisperS2T so it can use dynamic adapters.
@StephennFernandes I am looking forward to it.
In short, the request is for dynamic LoRA (Low-Rank Adaptation) switching functionality, allowing users to change LoRA adapters on the fly during inference without reloading the entire model.
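For context on the mechanism: a LoRA layer computes y = Wx + (alpha/r) * B(Ax) with the base weight W frozen, so "dynamic" switching only selects which low-rank pair (A, B) is applied at call time. Below is a self-contained sketch of that idea, illustrative only and in the spirit of minLoRA and S-LoRA rather than their actual APIs; PEFT exposes the same behavior via set_adapter:

```python
import torch
import torch.nn as nn

class DynamicLoRALinear(nn.Module):
    """A frozen base linear layer plus several named low-rank adapters.

    Switching adapters is a dictionary lookup, so nothing is reloaded.
    Illustrative sketch only, not a library API.
    """

    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base.requires_grad_(False)  # base weights stay frozen
        self.rank = rank
        self.scaling = alpha / rank
        self.adapters = nn.ModuleDict()
        self.active = None

    def add_adapter(self, name: str):
        down = nn.Linear(self.base.in_features, self.rank, bias=False)  # A: project down
        up = nn.Linear(self.rank, self.base.out_features, bias=False)   # B: project up
        nn.init.zeros_(up.weight)  # standard LoRA init: adapter starts as a no-op
        self.adapters[name] = nn.Sequential(down, up)
        if self.active is None:
            self.active = name

    def set_adapter(self, name: str):
        self.active = name  # on-the-fly switch; no weights are reloaded

    def forward(self, x):
        y = self.base(x)
        if self.active is not None:
            y = y + self.scaling * self.adapters[self.active](x)
        return y

layer = DynamicLoRALinear(nn.Linear(512, 512))
layer.add_adapter("adapter_1")
layer.add_adapter("adapter_2")
layer.set_adapter("adapter_2")
out = layer(torch.randn(1, 512))  # applies adapter_2's low-rank delta
```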