Maximize GPU utilization by routing inference through multiple LoRAs in the same batch.
Explainer by @yacineMTB.
The trainable parameters of low-rank adapters (LoRAs) are small, so many of them can be held in VRAM at the same time. This means you can keep a single base model and change its behavior by swapping LoRAs. Hugging Face's PEFT library supports swapping adapters through its API.
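To make "small" concrete: a LoRA adapter on a layer with weight W of shape (d_out, d_in) trains two factors, A of shape (r, d_in) and B of shape (d_out, r), so it adds r·(d_in + d_out) parameters instead of d_in·d_out. The sketch below just does that arithmetic; the sizes (a 4096 × 4096 projection, rank 8) are illustrative assumptions, not the ranks of the adapters used later in this README.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters of one LoRA adapter: A is (rank, d_in), B is (d_out, rank)."""
    return rank * d_in + d_out * rank


def full_param_count(d_in: int, d_out: int) -> int:
    """Parameters of the dense weight matrix W that the adapter modifies."""
    return d_in * d_out


# Illustrative sizes (assumed): a 4096 x 4096 projection adapted with rank 8.
d, r = 4096, 8
lora = lora_param_count(d, d, r)
full = full_param_count(d, d)
print(lora, full, f"{lora / full:.2%}")  # 65536 16777216 0.39%
```

At well under one percent of each adapted matrix, keeping several adapters resident next to one copy of the base weights is cheap compared with loading several full models.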
But what if you wanted to run inference with all of your adapters at the same time? The LoRA operation is pretty simple: it computes an output with the same shape as the adapted layer's output, and the two are added together. That has got to be broadcastable, right?
It is! If each sequence in the batch is paired with its own LoRA adapter, you can build an operation that applies each adapter to its respective batch element. The result is multiple models that share the same base weights, all served in one batch.
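A minimal NumPy sketch of that idea for a single linear layer follows. It assumes the adapters' A and B factors are stacked into arrays and that every row of the batch carries an integer adapter index; the names `multi_lora_linear`, `A_bank`, `B_bank` and `adapter_ids` are hypothetical and not BLoRA's API. The base matmul runs once for the whole batch, and only the small low-rank factors are gathered per row.

```python
import numpy as np


def multi_lora_linear(x, W, A_bank, B_bank, adapter_ids, scale=1.0):
    """Shared base weight plus a per-row LoRA adapter (hypothetical helper).

    x:           (batch, seq, d_in)     activations, one sequence per row
    W:           (d_out, d_in)          frozen base weight shared by all rows
    A_bank:      (n_adapters, r, d_in)  stacked LoRA down-projections
    B_bank:      (n_adapters, d_out, r) stacked LoRA up-projections
    adapter_ids: (batch,) ints          adapter used by each row
    """
    base = x @ W.T                             # (batch, seq, d_out), one matmul for all rows
    A = A_bank[adapter_ids]                    # gather per row: (batch, r, d_in)
    B = B_bank[adapter_ids]                    # gather per row: (batch, d_out, r)
    down = np.einsum("bri,bsi->bsr", A, x)     # (batch, seq, r)
    up = np.einsum("bor,bsr->bso", B, down)    # (batch, seq, d_out), same shape as base
    return base + scale * up                   # LoRA output is simply added


rng = np.random.default_rng(0)
batch, seq, d_in, d_out, r, n_adapters = 5, 3, 16, 12, 4, 3
x = rng.standard_normal((batch, seq, d_in))
W = rng.standard_normal((d_out, d_in))
A_bank = rng.standard_normal((n_adapters, r, d_in))
B_bank = rng.standard_normal((n_adapters, d_out, r))
adapter_ids = np.array([0, 1, 2, 1, 2])  # 5 rows routed over 3 adapters, adapters may repeat

y = multi_lora_linear(x, W, A_bank, B_bank, adapter_ids, scale=2.0)

# Each row matches an ordinary single-adapter forward pass: W x + scale * B (A x).
for i, k in enumerate(adapter_ids):
    expected = x[i] @ W.T + 2.0 * (x[i] @ A_bank[k].T @ B_bank[k].T)
    assert np.allclose(y[i], expected)
```

The `scale` argument stands in for LoRA's scaling factor (PEFT's default is `lora_alpha / r`); the value 2.0 here is arbitrary.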
To clone the repository using git, run:
git clone https://github.com/sabetAI/BLoRA.git
cd BLoRA
Set up a virtual environment (recommended), then install the required packages:
python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
from transformers import LlamaForCausalLM, LlamaTokenizer

# Load the shared base model and its tokenizer.
model_path = "decapoda-research/llama-7b-hf"
model = LlamaForCausalLM.from_pretrained(model_path, trust_remote_code=True)
tokenizer = LlamaTokenizer.from_pretrained(model_path)
# Pad with token id 0 (pad_token itself expects a string, so set the id instead).
tokenizer.pad_token_id = 0
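from blora_utils import load_loras

# Inject the LoRA adapters (Hugging Face Hub checkpoints) into the base model.
# lora_map is used by prepare_batch to route each prompt to its adapter.
loras = [
    "jondurbin/airoboros-7b-gpt4-1.2-peft",
    "trl-lib/llama-7b-se-rl-peft",
    "winddude/wizardLM-LlaMA-LoRA-7B",
]
model, lora_map = load_loras(model, loras)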
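One plausible way to hold several adapters for a layer is to stack their factors and keep a name-to-index map, which is what the batched sketch earlier needs. The helper `build_adapter_bank` below is hypothetical; it is not `load_loras`, and the repo's actual `lora_map` structure may differ. It also assumes all adapters share the same rank.

```python
import numpy as np


def build_adapter_bank(adapters):
    """Hypothetical helper: stack named (A, B) pairs for one layer.

    adapters: dict mapping adapter name -> (A, B), A of shape (r, d_in), B of shape (d_out, r).
    Returns (A_bank, B_bank, name_to_index). Assumes every adapter has the same rank r.
    """
    names = list(adapters)  # insertion order defines each adapter's index
    name_to_index = {name: i for i, name in enumerate(names)}
    A_bank = np.stack([adapters[n][0] for n in names])  # (n_adapters, r, d_in)
    B_bank = np.stack([adapters[n][1] for n in names])  # (n_adapters, d_out, r)
    return A_bank, B_bank, name_to_index


rng = np.random.default_rng(0)
names = [
    "jondurbin/airoboros-7b-gpt4-1.2-peft",
    "trl-lib/llama-7b-se-rl-peft",
    "winddude/wizardLM-LlaMA-LoRA-7B",
]
# Random stand-in weights (rank 4, d_in 16, d_out 12), not the real checkpoints.
adapters = {n: (rng.standard_normal((4, 16)), rng.standard_normal((12, 4))) for n in names}
A_bank, B_bank, name_to_index = build_adapter_bank(adapters)
print(A_bank.shape, B_bank.shape, name_to_index[names[1]])  # (3, 4, 16) (3, 12, 4) 1
```

Adapters with different ranks could still share a bank by zero-padding: extra zero rows in A and zero columns in B leave the product BA unchanged.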
from blora_utils import prepare_batch

# Each input pairs a prompt with the adapter that should answer it.
# An adapter can serve several prompts in the same batch.
inputs = [
    ('Outline a five sentence short story where a character stumbles upon a secret room in their house that contains relics from their future.',
     'jondurbin/airoboros-7b-gpt4-1.2-peft'),
    ('Write a 6 line dialogue between a character and a magical creature that only they can see.',
     'trl-lib/llama-7b-se-rl-peft'),
    ('Describe a four sentence scene where a character discovers a hidden talent that changes their life forever.',
     'winddude/wizardLM-LlaMA-LoRA-7B'),
    ('Sculpt a three verse poem about the feeling of walking through a lush, vibrant garden in full bloom.',
     'trl-lib/llama-7b-se-rl-peft'),
    ('Develop an eight sentence short story about a character who can bring their dreams into reality, but only for a limited time.',
     'winddude/wizardLM-LlaMA-LoRA-7B'),
]
# Tokenize the prompts into one batch and associate each row with its adapter.
batch = prepare_batch(inputs, tokenizer, model, lora_map)
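Conceptually, preparing such a batch means two things: turning each (prompt, adapter name) pair into an adapter index, and padding the tokenized prompts to a common length. The pure-Python helpers below are hypothetical illustrations of those two steps, not the code of `prepare_batch`; left padding is assumed because it is the usual choice for batched generation with decoder-only models, and pad id 0 matches the tokenizer setting above.

```python
def adapter_ids_for_batch(inputs, name_to_index):
    """Hypothetical helper: map each (prompt, adapter_name) pair to an adapter index."""
    return [name_to_index[name] for _prompt, name in inputs]


def left_pad(token_ids, pad_id=0):
    """Hypothetical helper: left-pad token-id lists to one width and build an attention mask."""
    width = max(len(ids) for ids in token_ids)
    input_ids = [[pad_id] * (width - len(ids)) + ids for ids in token_ids]
    attention_mask = [[0] * (width - len(ids)) + [1] * len(ids) for ids in token_ids]
    return input_ids, attention_mask


loras = [
    "jondurbin/airoboros-7b-gpt4-1.2-peft",
    "trl-lib/llama-7b-se-rl-peft",
    "winddude/wizardLM-LlaMA-LoRA-7B",
]
name_to_index = {name: i for i, name in enumerate(loras)}
# Placeholder prompts with the same adapter assignment as the inputs above.
batch_inputs = [("prompt A", loras[0]), ("prompt B", loras[1]), ("prompt C", loras[2]),
                ("prompt D", loras[1]), ("prompt E", loras[2])]

print(adapter_ids_for_batch(batch_inputs, name_to_index))  # [0, 1, 2, 1, 2]
ids, mask = left_pad([[5, 6, 7], [8]])
print(ids, mask)  # [[5, 6, 7], [0, 0, 8]] [[1, 1, 1], [0, 0, 1]]
```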
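import torch

# Stream generation: each step yields the next token for every sequence in the batch.
# stream_output is not a parameter of stock transformers generate(); it presumably
# depends on the setup installed from requirements.txt.
outputs = []
for out in model.generate(**batch, max_length=200, stream_output=True):
    outputs.append(out)

# Stack the per-step tokens into a (batch, steps) tensor and decode each row.
batch_decoded = tokenizer.batch_decode(
    torch.cat([out.reshape(-1, 1) for out in outputs], dim=1)
)
print(
    "\n\n".join(
        lora + ":\n" + prompt + "\n" + decoded
        for (prompt, lora), decoded in zip(inputs, batch_decoded)
    )
)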
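The `torch.cat([out.reshape(-1, 1) ...], dim=1)` line turns a list of per-step tokens into one row per sequence. The NumPy sketch below shows the same reshaping, assuming (as the reshape suggests) that each streamed step has shape (batch,); the token values are made up.

```python
import numpy as np

# Assumed stream: one array of shape (batch,) per decoding step, holding the token
# chosen for every sequence at that step. Here batch = 2 and there are 3 steps.
steps = [np.array([11, 21]), np.array([12, 22]), np.array([13, 23])]

# Same idea as torch.cat([out.reshape(-1, 1) for out in outputs], dim=1):
# make each step a column, then place the columns side by side.
tokens = np.concatenate([s.reshape(-1, 1) for s in steps], axis=1)
print(tokens.shape)  # (2, 3)
print(tokens[0])     # [11 12 13]
```

Row 0 now holds every token generated for the first prompt, in order, which is the layout `batch_decode` expects. Stacking along a new axis (`np.stack(steps, axis=1)`, or `torch.stack(outputs, dim=1)`) gives the same result.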
https://github.com/sabetAI/BLoRA/assets/28828395/287b6cce-555e-4626-852c-1ad79672f27e
Thanks to @yacineMTB for reviewing 🙏.