google / paxml

Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOP utilization rates.
Apache License 2.0

LoRA support in transformers #96

Open tanmayshishodia opened 5 months ago

tanmayshishodia commented 5 months ago

Hi, I went through the `praxis.transformers.StackedTransformer` layer and I don't see any support for LoRA.

That said, I was wondering whether there is a way to add a set of new LoRA weights to an existing paxConfig model. An example would be great: the tutorials don't cover any use case where model layers are updated after the model has been defined.

tanmayshishodia commented 5 months ago

https://github.com/google/paxml/pull/83

Refer to the PR above if you are looking to implement LoRA in PaxML for now.
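For readers who want the core idea before digging into a full integration, here is a minimal sketch in plain JAX of adding LoRA weights on top of an existing, frozen dense kernel. It is not the Praxis API and not the implementation from the PR; the helpers `init_lora_params`, `lora_dense`, `merge_lora` and `loss_fn` are hypothetical names, and the `alpha / rank` scaling with a zero-initialized `B` factor follows the common LoRA convention. It assumes a kernel of shape `(d_in, d_out)` applied as `x @ W`.

```python
import jax
import jax.numpy as jnp


def init_lora_params(key, d_in, d_out, rank):
    """Hypothetical helper: create LoRA factors for one frozen (d_in, d_out) kernel."""
    # A gets a small random init; B starts at zero so the adapted layer
    # initially computes exactly the same output as the frozen layer.
    lora_a = jax.random.normal(key, (d_in, rank)) * 0.01
    lora_b = jnp.zeros((rank, d_out))
    return {"lora_a": lora_a, "lora_b": lora_b}


def lora_dense(frozen_w, lora, x, alpha=16.0):
    """Hypothetical helper: y = x W + (alpha / r) * x A B, with gradients blocked through W."""
    rank = lora["lora_a"].shape[1]
    scale = alpha / rank
    base = x @ jax.lax.stop_gradient(frozen_w)     # original, frozen projection
    delta = (x @ lora["lora_a"]) @ lora["lora_b"]  # low-rank update; never materializes A @ B
    return base + scale * delta


def merge_lora(frozen_w, lora, alpha=16.0):
    """Hypothetical helper: fold the low-rank update into a single kernel for inference."""
    rank = lora["lora_a"].shape[1]
    return frozen_w + (alpha / rank) * (lora["lora_a"] @ lora["lora_b"])


def loss_fn(lora, frozen_w, x, y):
    # Mean-squared error; jax.grad differentiates only w.r.t. the first argument (lora),
    # so the pretrained kernel is never updated.
    pred = lora_dense(frozen_w, lora, x)
    return jnp.mean((pred - y) ** 2)
```

Calling `jax.grad(loss_fn)(lora, frozen_w, x, y)` returns a dict with the same `lora_a` / `lora_b` keys, so an optimizer only ever touches the small adapter factors. The `stop_gradient` is redundant in that call, but it keeps the kernel frozen if a caller differentiates a larger parameter tree that includes it. Wiring this into Praxis layers such as `StackedTransformer` is not shown here.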