Lightning-AI / lit-llama

Implementation of the LLaMA language model based on nanoGPT. Supports flash attention, Int8 and GPTQ 4-bit quantization, LoRA and LLaMA-Adapter fine-tuning, and pre-training. Apache 2.0-licensed.

Switch between LoRA adapters? #193

Open totaltube opened 1 year ago

totaltube commented 1 year ago

Is there a way to switch between LoRA adapters? I.e., load several different adapters for different tasks and quickly switch between them to perform those tasks, as is possible with the peft library?

awaelchli commented 1 year ago

@totaltube I don't think we support this. Can you give an example? Note that LoRA and Adapter are two different things.

For example, to switch from LoRA to Adapter finetuning, you have to 1. remove the with lora() context manager, 2. replace the lit_llama.model.LLaMA model class with lit_llama.adapter.LLaMA, and 3. change mark_only_lora_as_trainable(model) to mark_only_adapter_as_trainable(model).
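For illustration, the two setups differ roughly like this (a sketch; hyperparameter values are just the defaults from the finetune scripts):

# LoRA finetuning (cf. finetune/lora.py)
from lit_llama.lora import lora, mark_only_lora_as_trainable
from lit_llama.model import LLaMA, LLaMAConfig

with lora(r=8, alpha=16, dropout=0.05, enabled=True):
    model = LLaMA(LLaMAConfig.from_name("7B"))
mark_only_lora_as_trainable(model)

# Adapter finetuning (cf. finetune/adapter.py)
from lit_llama.adapter import LLaMA, LLaMAConfig, mark_only_adapter_as_trainable

model = LLaMA(LLaMAConfig.from_name("7B"))
mark_only_adapter_as_trainable(model)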

lantiga commented 1 year ago

I think the goal there was not to switch from lora to adapter, but rather to switch different lora weights on top of the same base model

totaltube commented 1 year ago

> I think the goal there was not to switch from lora to adapter, but rather to switch different lora weights on top of the same base model

Correct. Here is an example from another library: https://github.com/huggingface/peft/blob/main/examples/multi_adapter_examples/PEFT_Multi_LoRA_Inference.ipynb

The goal is that I can fine-tune the model for different tasks independently and quickly switch weights, with no need to re-train on an extended dataset that includes the other tasks.
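For reference, the peft flow from that notebook looks roughly like this (all paths and adapter names below are placeholders):

from transformers import AutoModelForCausalLM
from peft import PeftModel

# placeholder identifiers; substitute your own base model and LoRA checkpoints
base_model = AutoModelForCausalLM.from_pretrained("path/to/base-llama")
model = PeftModel.from_pretrained(base_model, "path/to/lora_task_a", adapter_name="task_a")
model.load_adapter("path/to/lora_task_b", adapter_name="task_b")

model.set_adapter("task_a")  # inference now runs through task A's LoRA weights
model.set_adapter("task_b")  # switch tasks without reloading the base model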

awaelchli commented 1 year ago

I see now. Today, this can be achieved by storing checkpoint files for the different tasks and then loading them into the model with model.load_state_dict(). This is one way to switch between tasks. Another way would be to have a consolidated checkpoint with all weights from the different tasks indexed, and then one could expose a utility (similar to model.set_adapter in your linked example) that selects the right weights. This would work for both methods, LoRA and Adapter. Is this roughly what you were looking for?
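A minimal sketch of the consolidated-checkpoint idea (hypothetical; neither save_multi_task_checkpoint nor set_adapter exists in lit-llama today):

import torch

def save_multi_task_checkpoint(path, task_state_dicts):
    # task_state_dicts maps a task name to that task's LoRA/adapter weights,
    # e.g. {"summarize": {...}, "translate": {...}}
    torch.save(task_state_dicts, path)

def set_adapter(model, path, task_name):
    all_tasks = torch.load(path)
    # strict=False: each entry holds only the task-specific weights,
    # so the base model weights stay in place
    model.load_state_dict(all_tasks[task_name], strict=False)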

lucas-ventura commented 1 year ago

An example class showing how to achieve what @awaelchli is saying:

from pathlib import Path
from typing import Optional

import lightning as L
import torch

from lit_llama import Tokenizer
from lit_llama.adapter import LLaMA, LLaMAConfig
from lit_llama.utils import EmptyInitOnDevice, lazy_load

# generate() and generate_prompt() live in the repo's top-level generate.py and
# scripts/prepare_alpaca.py; generate_no_prompt is assumed to be a user-defined helper
from generate import generate
from scripts.prepare_alpaca import generate_prompt


class FinetunedAdapter:

    def __init__(
        self,
        adapter_path: Optional[Path] = None,
        pretrained_path: Optional[Path] = None,
        tokenizer_path: Optional[Path] = None,
        quantize: Optional[str] = None,
    ) -> None:
        if not adapter_path:
            adapter_path = Path("out/adapter/alpaca/lit-llama-adapter-finetuned.pth")
        if not pretrained_path:
            pretrained_path = Path("./checkpoints/lit-llama/7B/lit-llama.pth")
        if not tokenizer_path:
            tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")

        assert adapter_path.is_file()
        assert pretrained_path.is_file()
        assert tokenizer_path.is_file()

        self.fabric = L.Fabric(devices=1)
        dtype = (
            torch.bfloat16
            if self.fabric.device.type == "cuda" and torch.cuda.is_bf16_supported()
            else torch.float32
        )

        with EmptyInitOnDevice(
            device=self.fabric.device, dtype=dtype, quantization_mode=quantize
        ):
            self.model = LLaMA(LLaMAConfig())

        # 1. Load the pretrained weights
        pretrained_checkpoint = lazy_load(pretrained_path)
        self.model.load_state_dict(pretrained_checkpoint, strict=False)

        # 2. Load the fine-tuned adapter weights
        adapter_checkpoint = lazy_load(adapter_path)
        self.model.load_state_dict(adapter_checkpoint, strict=False)

        self.model.eval()
        self.model = self.fabric.setup_module(self.model)

        self.tokenizer = Tokenizer(tokenizer_path)

    def load_adapter(self, adapter_path: Path):
        """Swap in a different set of fine-tuned adapter weights on the fly."""
        assert adapter_path.is_file()

        adapter_checkpoint = lazy_load(adapter_path)
        # strict=False: the checkpoint contains only the adapter weights,
        # so the base model weights are left untouched
        self.model.load_state_dict(adapter_checkpoint, strict=False)

    def generate(
        self,
        instruction: str = "",
        input_text: str = "",
        max_new_tokens: int = 100,
        top_k: int = 200,
        temperature: float = 0.8,
        use_instruction: bool = True,
    ):
        if use_instruction:
            sample = {"instruction": instruction, "input": input_text}
            prompt = generate_prompt(sample)
        else:
            assert input_text, "input_text must be provided if use_instruction is False."
            assert (
                len(instruction) == 0
            ), "instruction must be empty if use_instruction is False."
            prompt = generate_no_prompt(input_text)

        encoded = self.tokenizer.encode(
            prompt, bos=True, eos=False, device=self.model.device
        )

        output = generate(
            self.model,
            idx=encoded,
            # size the cache for the prompt plus the newly generated tokens
            max_seq_length=encoded.size(0) + max_new_tokens,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            eos_id=self.tokenizer.eos_id,
        )

        output = self.tokenizer.decode(output)
        if use_instruction:
            # Alpaca-style prompts end with "### Response:"; keep only the answer
            output = output.split("### Response:")[1].strip()
        return output

And then:

adapter = FinetunedAdapter(adapter_path=adapter_path)
adapter.load_adapter(another_adapter_path)
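Note that load_state_dict(adapter_checkpoint, strict=False) only overwrites the keys present in the adapter checkpoint, so each load_adapter call swaps the task-specific weights while the pretrained base weights loaded in step 1 stay in place.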