Add a keras.io example for LoRA #1051

Closed (abheesht17 closed this issue 1 year ago)

abheesht17 commented 1 year ago

Paper: https://arxiv.org/abs/2106.09685

LoRA is a finetuning method which freezes the pre-trained model weights and injects trainable rank decomposition matrices into each layer of the Transformer architecture, greatly reducing the number of trainable parameters for downstream tasks.
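For reference, here is a minimal sketch of the idea (Keras 3 style; `LoRADense` is just an illustrative name, not a proposed API): a wrapper that freezes a pretrained Dense layer and adds the trainable rank decomposition alongside it.

```python
import keras
from keras import layers, ops


class LoRADense(layers.Layer):
    """Hypothetical wrapper: frozen pretrained Dense + trainable low-rank update."""

    def __init__(self, original_dense, rank=4, alpha=32, **kwargs):
        super().__init__(**kwargs)
        self.original_dense = original_dense
        self.original_dense.trainable = False  # freeze the pretrained weights
        self.rank = rank
        self.scale = alpha / rank

    def build(self, input_shape):
        # Output is W x + (alpha / rank) * (x A) B, with A: (in_dim, rank), B: (rank, out_dim).
        self.lora_a = self.add_weight(
            name="lora_a",
            shape=(input_shape[-1], self.rank),
            initializer="glorot_uniform",
        )
        self.lora_b = self.add_weight(
            name="lora_b",
            shape=(self.rank, self.original_dense.units),
            initializer="zeros",  # zero init: the wrapped layer starts identical to the original
        )

    def call(self, inputs):
        frozen_out = self.original_dense(inputs)
        lora_out = ops.matmul(ops.matmul(inputs, self.lora_a), self.lora_b)
        return frozen_out + self.scale * lora_out
```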

We will finetune GPT-2 (we can use the same Reddit dataset used in this example: https://keras.io/examples/generative/gpt2_text_generation_with_kerasnlp/#finetune-on-reddit-dataset).

After adding the example, we can add LoRA as a feature in KerasNLP:

def get_lora_model(model, rank, alpha, ...):
    ...
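To make that concrete, here is one hedged sketch of what such a helper could do, reusing the hypothetical `LoRADense` above and `keras.models.clone_model`'s `clone_function` to swap Dense layers for wrapped ones (the name and behavior are illustrative, nothing is settled):

```python
import keras
from keras import layers


def get_lora_model(model, rank=4, alpha=32):
    """Illustrative only: wrap every Dense layer of `model` with the
    LoRADense sketch above; all other layers are reused unchanged."""

    def clone_fn(layer):
        if isinstance(layer, layers.Dense):
            # LoRADense freezes the wrapped layer and shares its weights.
            return LoRADense(layer, rank=rank, alpha=alpha)
        return layer

    return keras.models.clone_model(model, clone_function=clone_fn)
```

Caveat: for a real GPT-2 backbone the Dense projections live inside the transformer decoder layers, so an actual implementation would need to reach into those blocks; this sketch only covers flat functional models.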

CC: @mattdangerw

jbischof commented 1 year ago

I need to read up on the technique more, but wouldn't

model.get_lora_model(...)

be a lot easier to implement than a God method to modify any model?

abheesht17 commented 1 year ago

@jbischof, you are right! It won't be easy to have a single method that works for every model, especially because layer names differ between models. The keras.io example I am working on will give us a good idea of how to proceed with API changes. ETA for the example: 1-2 weeks :)

jbischof commented 1 year ago

Awesome! One bit of prior art is our get_feature_extractor method in KerasCV (link). The short story is that many CV tasks need to extract specific layers from the backbone model rather than consuming the output directly (example).

mattdangerw commented 1 year ago

I think one question we are going to need to answer is whether we can do this in a nice stateless way, e.g. model.get_lora_model() vs. an actual mutator model.add_lora_weights(). We definitely want to avoid loading two copies of our model weights into memory (the whole point of LoRA is to lower memory requirements).

The implementations I have seen so far just mutate the model by inserting adapter layers, but if we could do something like functionally retrace the model so that weights are shared while the original model is unchanged, that would be very cool. It might be tricky given how much we delegate to our transformer blocks.
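As a rough illustration of the retrace idea (toy layers and shapes, not our actual blocks): calling the same layer instance in a new functional graph shares its variables, so the retraced model holds only one copy of the frozen weights while the original model stays untouched.

```python
import keras
from keras import layers

# Toy stand-in for a pretrained projection; in practice this would be a
# kernel inside one of our transformer blocks.
pretrained_dense = layers.Dense(16)
inputs = keras.Input(shape=(8,))
original_model = keras.Model(inputs, pretrained_dense(inputs))

# Retrace: reuse the *same* layer instance (no second copy of its weights)
# and add a zero-initialized low-rank branch next to it. The alpha / rank
# scaling factor is omitted for brevity.
pretrained_dense.trainable = False
lora_a = layers.Dense(4, use_bias=False)  # down-projection to rank 4
lora_b = layers.Dense(16, use_bias=False, kernel_initializer="zeros")

new_inputs = keras.Input(shape=(8,))
outputs = pretrained_dense(new_inputs) + lora_b(lora_a(new_inputs))
lora_model = keras.Model(new_inputs, outputs)

# The frozen kernel is shared, not duplicated, and only A/B are trainable.
assert any(w is pretrained_dense.kernel for w in lora_model.weights)
assert len(lora_model.trainable_weights) == 2
```

The mutate-in-place alternative would instead insert adapter layers into the existing model object, which is simpler but changes the original model.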