sabetAI / BLoRA

batched loras

Benchmark using PyTorch - speed up LoRA operation on batch #1

Open yacineMTB opened 1 year ago

yacineMTB commented 1 year ago

Bear with me... I'm learning!

I thought that the Python for loop would slow things down, and that you could batch the operation instead. This is still a WIP; I'm trying to wrap my head around this.

https://github.com/sabetAI/bloras/blob/21839a61b883b1398b2418a7992f1c1175506874/blora_utils.py/#L1884-L1891

Each one of these, as far as I can tell, is equivalent:

[image: side-by-side comparison of the looped and batched implementations]
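
For reference, here's a minimal sketch of the two forms being compared (hypothetical shapes and names, not the repo's exact code):

```python
import torch

# Hypothetical shapes: batch size b, sequence length s, hidden dim d,
# LoRA rank r, with one adapter pair (A, B) per batch element.
b, s, d, r = 8, 128, 768, 16
x = torch.randn(b, s, d)
A = torch.randn(b, d, r)  # stacked lora_A weights, one per example
B = torch.randn(b, r, d)  # stacked lora_B weights, one per example

def lora_loop(x, A, B):
    # Apply each example's adapter separately in a Python loop.
    return torch.stack([x[i] @ A[i] @ B[i] for i in range(x.shape[0])])

def lora_batched(x, A, B):
    # Same computation as two batched matmuls over the stacked adapters.
    return torch.bmm(torch.bmm(x, A), B)

assert torch.allclose(lora_loop(x, A, B), lora_batched(x, A, B), atol=1e-4)
```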

I had a hunch that it would be faster to do this in one go, so I fashioned a small benchmark. I ran it on my system (RTX 3090):

```
Loop Mean: 0.00021034269332885743
Batched Mean: 7.178418636322021e-05
Loop Median: 0.0001556873321533203
Batched Median: 6.985664367675781e-05
loop_sum: 2.103426933288574
batched_sum: 0.7178418636322021
```
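
For anyone wanting to reproduce numbers like these, a timing harness along these lines would do it (a sketch, not the original benchmark; the explicit CUDA synchronization is needed so GPU kernels are actually timed rather than just their launch):

```python
import time
import torch

def bench(fn, *args, iters=10_000, warmup=10):
    # Warm up so one-time costs (allocation, autotuning) don't skew results.
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    times = []
    for _ in range(iters):
        start = time.time()
        fn(*args)
        torch.cuda.synchronize()  # wait for the kernel before reading the clock
        times.append(time.time() - start)
    t = torch.tensor(times)
    print(f"Mean: {t.mean().item()}")
    print(f"Median: {t.median().item()}")
    print(f"Sum: {t.sum().item()}")
```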

Please check that everything is equivalent; I'm quite new at this!

sidnb13 commented 12 months ago

I've been thinking about how to speed this up further, and I ended up with a solution very similar to yours. Is there any way of parallelizing across the number of LoRA weights n, maybe by writing a custom CUDA kernel?

sabetAI commented 12 months ago

Yes @sidnb13, you can stack the LoRAs into a single tensor and broadcast slices over their corresponding batch elements.
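
Concretely, that might look something like this (a hypothetical sketch, assuming each batch element carries an index into the adapter stack):

```python
import torch

# Hypothetical sketch: n distinct adapters shared across a batch of b examples.
n, b, s, d, r = 4, 8, 128, 768, 16
A_stack = torch.randn(n, d, r)           # all lora_A weights in one tensor
B_stack = torch.randn(n, r, d)           # all lora_B weights in one tensor
adapter_ids = torch.randint(0, n, (b,))  # which adapter each example uses

x = torch.randn(b, s, d)
# Slice each example's adapter out of the stack, then batch the matmuls.
A = A_stack[adapter_ids]             # (b, d, r)
B = B_stack[adapter_ids]             # (b, r, d)
out = torch.bmm(torch.bmm(x, A), B)  # (b, s, d)
```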

sidnb13 commented 11 months ago

@sabetAI I came across this function after asking for help in the PyG community: https://pyg-lib.readthedocs.io/en/latest/modules/ops.html#pyg_lib.ops.segment_matmul. It effectively vectorizes across the number of adapters. Some quick testing shows it's actually much slower than the looped approach, but maybe someone else can give it a go.
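
For context, segment_matmul multiplies contiguous row-segments of an input by different weight matrices, so the batch has to be sorted so rows sharing an adapter are contiguous. Roughly (a sketch with made-up sizes, assuming pyg-lib is installed):

```python
import torch
from pyg_lib.ops import segment_matmul

# Made-up sizes: 400 token embeddings sorted so each contiguous segment
# of rows shares one adapter's lora_A weight.
d, r = 768, 16
inputs = torch.randn(400, d)            # [N, d], grouped by adapter
ptr = torch.tensor([0, 100, 350, 400])  # segment boundaries for 3 adapters
A_stack = torch.randn(3, d, r)          # one lora_A per segment

# Rows ptr[i]:ptr[i+1] are multiplied by A_stack[i]; output is [400, r].
down = segment_matmul(inputs, ptr, A_stack)
```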

sabetAI commented 11 months ago

@sidnb13 nice try! segment_matmul is the perfect function for a BLoRA op; the kernel's probably just not optimized. I also attempted parallelizing the BLoRA op through matrix reshapes and stacking, but it seemed to take longer than a simple loop, unfortunately.