davisyoshida / lorax

LoRA for arbitrary JAX models and functions
MIT License

Integration into EasyLM #2

Closed: versae closed this issue 1 year ago

versae commented 1 year ago

Hi! Cool project. I wonder how hard it would be to integrate this library into something like @young-geng's EasyLM. That would make using lorax really easy, since EasyLM would handle all the training.

davisyoshida commented 1 year ago

I've never used EasyLM, but the intention is that Lorax should be able to transform arbitrary JAX functions, so it should just work out of the box. If you try it and it doesn't work, let me know what the issue is and I'll try to help.
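For context, the transformation being discussed is the standard LoRA reparameterization: a frozen weight `W` is augmented with a trainable low-rank update `B A`. The sketch below illustrates the idea in plain numpy; it is not lorax's actual API, and the function and variable names are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, rank = 8, 8, 2

# Frozen pretrained weight and a rank-2 LoRA update.
W = rng.normal(size=(d_out, d_in))
A = rng.normal(size=(rank, d_in)) * 0.01  # trainable
B = np.zeros((d_out, rank))               # trainable; zero-init so training starts at W

def base_model(W, x):
    # Stand-in for an arbitrary function that multiplies by W somewhere.
    return W @ x

def lora_model(W, A, B, x):
    # Same computation with the update applied: (W + B A) x.
    # Writing it as W @ x + B @ (A @ x) avoids materializing W + B A.
    return base_model(W, x) + B @ (A @ x)

x = rng.normal(size=(d_in,))
# With B = 0, the LoRA model reproduces the frozen model exactly.
assert np.allclose(lora_model(W, A, B, x), base_model(W, x))
```

During training only `A` and `B` receive gradients, so the optimizer state scales with the rank rather than with `d_in * d_out`.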

young-geng commented 1 year ago

As the creator of EasyLM, I imagine that EasyLM might not be the best fit for LoRA-style training. EasyLM is mainly focused on scaling up training, and most of the heavy lifting goes into partitioning the model across many accelerators distributed over many hosts. Most of those utilities probably won't be very useful in the low-resource settings where LoRA is needed.

davisyoshida commented 1 year ago

@young-geng Sounds reasonable. I think there may be some use cases down the line for batched LoRAs on a shared trunk, but unless that picks up in popularity it won't be worth it for you to add support.
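The "batched LoRAs on a shared trunk" idea above can be sketched concretely: many adapters `(A_i, B_i)` share one frozen weight `W`, and the per-adapter outputs are computed without materializing a separate `W + B_i A_i` per adapter. A minimal numpy sketch, with illustrative shapes and names (in JAX this would typically be `jax.vmap` over the adapter parameters with `W` held fixed):

```python
import numpy as np

rng = np.random.default_rng(0)
n_adapters, d_in, d_out, rank = 4, 16, 16, 2

W = rng.normal(size=(d_out, d_in))             # shared frozen trunk weight
A = rng.normal(size=(n_adapters, rank, d_in))  # per-adapter LoRA factors
B = rng.normal(size=(n_adapters, d_out, rank))
x = rng.normal(size=(n_adapters, d_in))        # one input per adapter

# Per-adapter output (W + B_i A_i) x_i, with the trunk computed once:
shared = x @ W.T                                 # (n_adapters, d_out)
low_rank = np.einsum('nor,nri,ni->no', B, A, x)  # per-adapter correction
y = shared + low_rank

# Check against the naive computation that materializes each W + B_i A_i.
naive = np.stack([(W + B[i] @ A[i]) @ x[i] for i in range(n_adapters)])
assert np.allclose(y, naive)
```

The point of the factored form is that memory and compute for the trunk are paid once regardless of how many adapters are served, which is what would make this worthwhile at scale.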

Since this doesn't seem to be a Lorax issue per se, I'm going to go ahead and close it.

versae commented 1 year ago

Thanks for the discussion! I still see the utility of having LoRA for training large models on single TPU VMs. I'll play around :)