yacineMTB opened 1 year ago
Been thinking about how to further speed this up; I ended up with a solution very similar to yours. Is there any way of parallelizing across the number of LoRA weights n, maybe by writing a custom CUDA kernel?
Yes @sidnb13, you can stack the LoRAs into a single tensor and broadcast slices over their corresponding batch elements.
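Not the author's exact code, but a minimal sketch of that stacking idea in PyTorch; the shapes and names (`n_loras`, `lora_A`, `lora_B`, `adapter_idx`) are illustrative assumptions, not taken from the repo:

```python
import torch

# Assumed shapes (illustrative only):
# x:           (batch, seq, d_in)   -- input activations
# lora_A:      (n_loras, d_in, r)   -- all A matrices stacked into one tensor
# lora_B:      (n_loras, r, d_out)  -- all B matrices stacked into one tensor
# adapter_idx: (batch,)             -- which LoRA each batch element uses

def blora_forward(x, lora_A, lora_B, adapter_idx):
    # Gather each example's adapter weights:
    # A: (batch, d_in, r), B: (batch, r, d_out)
    A = lora_A[adapter_idx]
    B = lora_B[adapter_idx]
    # One pair of batched matmuls instead of a Python loop over adapters
    return torch.bmm(torch.bmm(x, A), B)
```

Note that the gather `lora_A[adapter_idx]` materializes a per-example copy of the adapter weights, which may be one reason a batched version can lose to a simple loop when adapters are few and the matrices are large.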
@sabetAI came across this function after asking for help in the PyG community: https://pyg-lib.readthedocs.io/en/latest/modules/ops.html#pyg_lib.ops.segment_matmul. It effectively vectorizes across the number of adapters. Some quick testing shows it's actually much slower than the looped approach, but maybe someone else can give it a go.
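For reference, a minimal usage sketch of `segment_matmul` following the example in the pyg-lib docs; the shapes here are illustrative:

```python
import torch
import pyg_lib

# 8 input rows split into two segments (5 rows, then 3 rows),
# each segment multiplied by its own weight matrix.
inputs = torch.randn(8, 16)    # (total_rows, d_in)
ptr = torch.tensor([0, 5, 8])  # segment boundaries into `inputs`
other = torch.randn(2, 16, 32) # (n_segments, d_in, d_out)

out = pyg_lib.ops.segment_matmul(inputs, ptr, other)
assert out.size() == (8, 32)
```

One caveat: `segment_matmul` expects rows to be grouped contiguously by segment, so for a batched-LoRA use case the batch elements would first need to be sorted by adapter id.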
@sidnb13 nice try! segment_matmul is the perfect function for a blora op; the kernel's probably just not optimized. I also attempted parallelizing the blora op through matrix reshapes and stacking, but it seemed to take longer than a simple loop, unfortunately.
Bear with me... I'm learning!
I thought that the Python for loop would slow things down, and that you could batch the operation instead. This is still WIP; I'm trying to wrap my head around it.
https://github.com/sabetAI/bloras/blob/21839a61b883b1398b2418a7992f1c1175506874/blora_utils.py#L1884-L1891
Each one of these, as far as I can tell, is equivalent.
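The linked lines aren't reproduced here, but a minimal sketch of the kind of equivalence being claimed (a per-example loop versus one batched `bmm`), under assumed shapes:

```python
import torch

batch, seq, d_in, d_out, r = 4, 8, 32, 32, 4
x = torch.randn(batch, seq, d_in)
A = torch.randn(batch, d_in, r)   # per-example LoRA A (already gathered)
B = torch.randn(batch, r, d_out)  # per-example LoRA B

# Looped version: apply each example's adapter separately
looped = torch.stack([x[i] @ A[i] @ B[i] for i in range(batch)])

# Batched version: one pair of batched matmuls
batched = torch.bmm(torch.bmm(x, A), B)

assert torch.allclose(looped, batched, atol=1e-5)
```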
I had a hunch that it would be faster to do this in one go, so I fashioned a small benchmark. I ran this on my system (RTX 3090).
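The benchmark itself isn't shown above; a rough sketch of how such a timing comparison might look (names and sizes are assumptions, and it assumes a CUDA device is available):

```python
import time
import torch

def bench(fn, iters=100):
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    return (time.perf_counter() - start) / iters

device = "cuda"
batch, seq, d_in, d_out, r = 16, 512, 4096, 4096, 8
x = torch.randn(batch, seq, d_in, device=device)
A = torch.randn(batch, d_in, r, device=device)
B = torch.randn(batch, r, d_out, device=device)

t_loop = bench(lambda: torch.stack([x[i] @ A[i] @ B[i] for i in range(batch)]))
t_bmm = bench(lambda: torch.bmm(torch.bmm(x, A), B))
print(f"loop: {t_loop * 1e3:.3f} ms  bmm: {t_bmm * 1e3:.3f} ms")
```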
Please check that everything is equivalent; I'm quite new at this!