EricLBuehler / candle-lora

Low-rank adaptation (LoRA) for Candle.
MIT License

candle-lora

LoRA (low-rank adaptation) implemented in Rust for use with Candle. This technique replaces the fully trainable layers of a model with LoRA layers. Each LoRA layer wraps an original layer and freezes its weights, training only a small pair of low-rank matrices added alongside it. Because far fewer parameters are trainable, LoRA makes fine-tuning more efficient.
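As a minimal illustration of this wrapping, the dependency-free sketch below implements a LoRA linear layer over row-major `Vec<f32>` matrices. It is not candle-lora's API: the `LoraLinear` type, the `matvec` and `param_counts` helpers, and the `alpha / r` scaling (the convention used by the LoRA paper and peft) are assumptions for illustration. The frozen weight `W` is only read, while the trainable matrices `A` (r × d_in) and `B` (d_out × r) contribute the update `(alpha / r) · B (A x)`.

```rust
// Hypothetical, dependency-free sketch of a LoRA linear layer.
// This is NOT the candle-lora API; matrices are row-major Vec<f32>.

/// y = M x for a row-major matrix M with `rows` rows and `cols` columns.
fn matvec(m: &[f32], rows: usize, cols: usize, x: &[f32]) -> Vec<f32> {
    (0..rows)
        .map(|i| (0..cols).map(|j| m[i * cols + j] * x[j]).sum::<f32>())
        .collect()
}

struct LoraLinear {
    w: Vec<f32>, // frozen base weight, shape (d_out, d_in)
    a: Vec<f32>, // trainable down-projection, shape (r, d_in)
    b: Vec<f32>, // trainable up-projection, shape (d_out, r)
    d_in: usize,
    d_out: usize,
    r: usize,
    alpha: f32, // the low-rank update is scaled by alpha / r
}

impl LoraLinear {
    /// Wraps a frozen weight. As in the LoRA paper, B starts at zero,
    /// so the wrapped layer initially behaves exactly like the original.
    fn new(w: Vec<f32>, d_in: usize, d_out: usize, r: usize, alpha: f32) -> Self {
        assert_eq!(w.len(), d_out * d_in);
        // Deterministic placeholder init for A; real code would sample it randomly.
        let a = (0..r * d_in).map(|k| 0.01 * (k as f32 + 1.0)).collect();
        let b = vec![0.0; d_out * r];
        LoraLinear { w, a, b, d_in, d_out, r, alpha }
    }

    /// y = W x + (alpha / r) * B (A x)
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        let base = matvec(&self.w, self.d_out, self.d_in, x); // frozen path
        let down = matvec(&self.a, self.r, self.d_in, x); // length r
        let up = matvec(&self.b, self.d_out, self.r, &down); // length d_out
        let scale = self.alpha / self.r as f32;
        base.iter().zip(&up).map(|(y, u)| *y + scale * *u).collect()
    }

    /// Only A and B are trained; W stays frozen.
    fn trainable_params(&self) -> usize {
        self.a.len() + self.b.len()
    }
}

/// Parameter counts for a (d_out x d_in) layer: (full fine-tuning, LoRA at rank r).
fn param_counts(d_in: usize, d_out: usize, r: usize) -> (usize, usize) {
    (d_out * d_in, r * (d_in + d_out))
}
```

For a 4096 × 4096 layer at rank 8, `param_counts` gives 16,777,216 parameters for full fine-tuning versus 65,536 trainable LoRA parameters, which is where the efficiency gain comes from.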

However, running inference with a fine-tuned LoRA model adds overhead, because each LoRA layer must still evaluate the original layer in addition to the adapter. An algorithm known as weight merging removes this extra cost by folding the LoRA weights into the original weights, so inference costs the same as with the original model. Merging is reversible: the weights can also be unmerged.
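Merging can be written as `W_merged = W + (alpha / r) · B A`. Since `W x + (alpha / r) · B (A x) = (W + (alpha / r) · B A) x`, a merged layer produces the same output with a single matrix multiplication, and unmerging subtracts the same delta. The sketch below assumes row-major `Vec<f32>` matrices and the `alpha / r` scaling used by the LoRA paper and peft; the function names are hypothetical and not candle-lora's API.

```rust
// Hypothetical sketch of LoRA weight merging and unmerging (not the candle-lora API).
// Row-major matrices: w is (d_out, d_in), a is (r, d_in), b is (d_out, r).

/// y = M x for a row-major matrix M with `rows` rows and `cols` columns.
fn matvec(m: &[f32], rows: usize, cols: usize, x: &[f32]) -> Vec<f32> {
    (0..rows)
        .map(|i| (0..cols).map(|j| m[i * cols + j] * x[j]).sum::<f32>())
        .collect()
}

/// Computes delta = scale * (B A), with shape (d_out, d_in).
fn lora_delta(a: &[f32], b: &[f32], d_in: usize, d_out: usize, r: usize, scale: f32) -> Vec<f32> {
    let mut delta = vec![0.0f32; d_out * d_in];
    for i in 0..d_out {
        for j in 0..d_in {
            let mut acc = 0.0f32;
            for k in 0..r {
                acc += b[i * r + k] * a[k * d_in + j];
            }
            delta[i * d_in + j] = scale * acc;
        }
    }
    delta
}

/// Merge: W <- W + delta. Afterwards a single W x reproduces the LoRA output.
fn merge(w: &mut [f32], delta: &[f32]) {
    for (wi, di) in w.iter_mut().zip(delta) {
        *wi += *di;
    }
}

/// Unmerge: W <- W - delta, restoring the frozen base weight.
fn unmerge(w: &mut [f32], delta: &[f32]) {
    for (wi, di) in w.iter_mut().zip(delta) {
        *wi -= *di;
    }
}
```

With general floating-point weights, unmerging recovers the original `W` only up to rounding error, so keeping a copy of the base weights is the safer choice when exact restoration matters.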

Please see our recent paper, X-LoRA, which introduces a mixture-of-experts (MoE)-inspired method for densely gating LoRA adapters, powered by a self-reflection forward pass of the model. For inference, we have created mistral.rs, which is written in Rust and supports inference of X-LoRA as well as other models, including quantized ones.

Get started

1) To install, run the following:

cargo add --git https://github.com/EricLBuehler/candle-lora.git candle-lora candle-lora-macro

2) To allow candle-lora to swap layers, do the following for each model struct

Features

Conversion Ergonomics

candle-lora-macro makes using candle-lora as simple as adding two macros to your model structs and calling a method!

It is inspired by the simplicity of the Python peft library's get_peft_model method. Together, these macros mean that candle-lora can be added to any Candle model with minimal code changes!
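To show what such a conversion step automates, here is a hand-written sketch of swapping a model struct's layer fields for LoRA wrappers, analogous in spirit to peft's get_peft_model. All names (`LinearLike`, `Dense`, `LoraWrapper`, `Mlp`, `into_lora`) are hypothetical; this does not reproduce the code candle-lora-macro generates, and the adapter math is left out so the structure stays visible.

```rust
// Hypothetical sketch of the layer swap that candle-lora-macro automates.
// `LinearLike`, `Dense`, `LoraWrapper`, `Mlp`, and `into_lora` are illustrative only.

trait LinearLike {
    fn forward(&self, x: &[f32]) -> Vec<f32>;
    fn describe(&self) -> String;
}

/// A plain dense layer with a row-major (d_out, d_in) weight.
struct Dense {
    w: Vec<f32>,
    d_in: usize,
    d_out: usize,
}

impl LinearLike for Dense {
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        (0..self.d_out)
            .map(|i| (0..self.d_in).map(|j| self.w[i * self.d_in + j] * x[j]).sum::<f32>())
            .collect()
    }
    fn describe(&self) -> String {
        format!("dense({}x{})", self.d_out, self.d_in)
    }
}

/// Wraps an existing layer, which stays frozen inside it.
struct LoraWrapper {
    inner: Box<dyn LinearLike>,
    rank: usize,
}

impl LinearLike for LoraWrapper {
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        // With B initialized to zero the adapter term is zero, so delegating is
        // exact at conversion time; the trained B(Ax) term is omitted for brevity.
        self.inner.forward(x)
    }
    fn describe(&self) -> String {
        format!("lora(r={}, {})", self.rank, self.inner.describe())
    }
}

/// A model struct whose layer fields are trait objects, so they can be swapped.
struct Mlp {
    up: Box<dyn LinearLike>,
    down: Box<dyn LinearLike>,
}

impl Mlp {
    /// Hand-written conversion: every linear field is replaced by a LoRA
    /// wrapper around the original layer.
    fn into_lora(self, rank: usize) -> Mlp {
        Mlp {
            up: Box::new(LoraWrapper { inner: self.up, rank }),
            down: Box::new(LoraWrapper { inner: self.down, rank }),
        }
    }

    fn forward(&self, x: &[f32]) -> Vec<f32> {
        self.down.forward(&self.up.forward(x))
    }
}
```

Storing layers behind a common trait is what lets the swap happen without touching the model's forward code; writing a method like `into_lora` by hand for every struct is the kind of per-model boilerplate the macros are meant to spare you.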

LoRA transformers

See here for transformers from Candle that have LoRA integrated. Currently, the following transformers have been converted:

To use a LoRA transformer, simply replace the model from candle-transformers with its counterpart in candle-lora-transformers!

Saving and loading

candle_lora can retrieve the weights of the LoRA adapters via the get_tensors method, which #[auto_layer_convert] defines automatically. Its output is meant to be passed to candle_core::safetensors::save(). To load the adapters, load a VarBuilder over the saved weights and pass it to get_lora_model.
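The sketch below illustrates the idea behind saving adapters: only the small LoRA matrices are collected under string keys, and the frozen base weights are left out, so the saved file stays small. `SimpleTensor` stands in for candle's `Tensor`, and `adapter_tensors`, `load_adapter`, and the `<layer>.lora_a` / `<layer>.lora_b` key scheme are hypothetical; they are not candle-lora's `get_tensors` output or its naming. The serialization step itself (for example `candle_core::safetensors::save()`) is omitted so the example stays dependency-free.

```rust
// Hypothetical sketch of collecting and restoring LoRA adapter weights.
// The "<layer>.lora_a" / "<layer>.lora_b" naming is an assumption for illustration.
use std::collections::HashMap;

/// A stand-in for a tensor: shape plus row-major f32 data.
#[derive(Clone, Debug, PartialEq)]
struct SimpleTensor {
    shape: Vec<usize>,
    data: Vec<f32>,
}

struct LoraLayer {
    name: String,
    frozen_w: SimpleTensor, // base weight: never saved with the adapter
    a: SimpleTensor,
    b: SimpleTensor,
}

/// Collects only the adapter tensors, keyed by name, ready for a serializer.
fn adapter_tensors(layers: &[LoraLayer]) -> HashMap<String, SimpleTensor> {
    let mut out = HashMap::new();
    for layer in layers {
        out.insert(format!("{}.lora_a", layer.name), layer.a.clone());
        out.insert(format!("{}.lora_b", layer.name), layer.b.clone());
    }
    out
}

/// Loading reverses the mapping: each adapter is looked up by the same key.
fn load_adapter(layer: &mut LoraLayer, saved: &HashMap<String, SimpleTensor>) -> Result<(), String> {
    let key_a = format!("{}.lora_a", layer.name);
    let key_b = format!("{}.lora_b", layer.name);
    layer.a = saved.get(&key_a).cloned().ok_or_else(|| format!("missing tensor {}", key_a))?;
    layer.b = saved.get(&key_b).cloned().ok_or_else(|| format!("missing tensor {}", key_b))?;
    Ok(())
}
```

Because the base model's weights are loaded separately, only the adapter keys need to round-trip; any mismatch in naming between saver and loader surfaces as a missing-tensor error.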

candle_lora's weight naming is not compatible with peft yet.

Resources

candle-lora's LoRA conversion implementations are based on Hugging Face's peft library. See the original paper here, as well as Microsoft's implementation.