ludwig-ai / ludwig

Low-code framework for building custom LLMs, neural networks, and other AI models
http://ludwig.ai
Apache License 2.0

Add support for RSLoRA and DoRA #3948

Closed · arnavgarg1 closed this 7 months ago

arnavgarg1 commented 7 months ago

Both of these options are configurable through the regular lora adapter. They can be enabled together, but they address different problems and shine in different circumstances.
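
For context, both options map onto fields of PEFT's standard `LoraConfig`, which Ludwig's `lora` adapter builds on. A minimal sketch of enabling them together (values are illustrative, and this assumes a recent PEFT version that exposes both flags):

```python
from peft import LoraConfig

# Sketch only: rank, alpha, and target modules are illustrative.
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    use_rslora=True,  # scale the update by lora_alpha / sqrt(r) instead of lora_alpha / r
    use_dora=True,    # decompose the update into magnitude and direction
    task_type="CAUSAL_LM",
)
```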

Rank Stabilized LoRA (RSLoRA)

When set to True, we use Rank-Stabilized LoRA, which sets the adapter scaling factor to lora_alpha/math.sqrt(r), since this has been shown to work better. Otherwise, the original default scaling of lora_alpha/r is used.

In equation form, the adapter scaling factor $\gamma$ becomes

$$\gamma_{\text{rsLoRA}} = \frac{\alpha}{\sqrt{r}} \qquad \text{instead of the default} \qquad \gamma_{\text{LoRA}} = \frac{\alpha}{r},$$

where $\alpha$ is `lora_alpha` and $r$ is the rank.

In particular, this is useful when using larger ranks, since it prevents the gradients from collapsing as the rank increases, so higher ranks can actually lead to better performance (which is not the case with the default scaling used in the original LoRA paper). Paper: https://arxiv.org/pdf/2312.03732.pdf.

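As a quick standalone illustration (not Ludwig or PEFT code; `adapter_scaling` is a made-up helper), here is how the two scaling rules diverge as the rank grows:

```python
import math

def adapter_scaling(lora_alpha: float, r: int, use_rslora: bool = False) -> float:
    """Scaling factor applied to the low-rank update B @ A."""
    return lora_alpha / math.sqrt(r) if use_rslora else lora_alpha / r

# Going from rank 8 to rank 64, the default scaling drops by 8x,
# while the rank-stabilized scaling only drops by sqrt(8) ~= 2.8x.
print(adapter_scaling(16, 8), adapter_scaling(16, 8, use_rslora=True))    # 2.0, ~5.66
print(adapter_scaling(16, 64), adapter_scaling(16, 64, use_rslora=True))  # 0.25, 2.0
```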

Weight-Decomposed Low-Rank Adaptation (DoRA)

This technique decomposes the weight update into two parts: magnitude and direction. The direction is handled by the normal LoRA update, whereas the magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, especially at low ranks. Right now, DoRA only supports non-quantized linear layers, and it introduces more overhead than pure LoRA. For more information, see https://arxiv.org/abs/2402.09353.

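To make the decomposition concrete, here is a rough sketch of what a DoRA-adapted linear layer computes (illustrative only; PEFT's actual implementation differs in details such as how the weight norm is handled and where dropout is applied):

```python
import torch

def dora_linear(x, base_weight, lora_A, lora_B, magnitude, scaling):
    """Illustrative sketch of a DoRA-adapted linear layer (not PEFT's actual code).

    base_weight: (out_features, in_features) frozen pretrained weight
    lora_A:      (r, in_features)   trainable, as in plain LoRA
    lora_B:      (out_features, r)  trainable, as in plain LoRA
    magnitude:   (out_features,)    trainable per-output-feature magnitude
    """
    # Direction: the usual LoRA-updated weight, normalized per output feature.
    adapted = base_weight + scaling * (lora_B @ lora_A)
    direction = adapted / adapted.norm(p=2, dim=1, keepdim=True)
    # Magnitude: a separately learned scale applied to each output feature.
    weight = magnitude.unsqueeze(1) * direction
    return x @ weight.T
```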

In practice, this is what the difference looks like when a model is loaded with regular LoRA vs. DoRA. In particular, note the new learnable lora_magnitude_vector parameter (one scalar per output feature of each adapted layer) that appears when DoRA is enabled.

Tiny-Random Llama with LoRA

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(32000, 8, padding_idx=0)
        (layers): ModuleList(
          (0): LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=8, out_features=8, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=8, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=8, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): Linear(in_features=8, out_features=8, bias=False)
              (v_proj): lora.Linear(
                (base_layer): Linear(in_features=8, out_features=8, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=8, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=8, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (o_proj): Linear(in_features=8, out_features=8, bias=False)
              (rotary_emb): LlamaRotaryEmbedding()
            )
            (mlp): LlamaMLP(
              (gate_proj): Linear(in_features=8, out_features=32, bias=False)
              (up_proj): Linear(in_features=8, out_features=32, bias=False)
              (down_proj): Linear(in_features=32, out_features=8, bias=False)
              (act_fn): SiLU()
            )
            (input_layernorm): LlamaRMSNorm()
            (post_attention_layernorm): LlamaRMSNorm()
          )
        )
        (norm): LlamaRMSNorm()
      )
      (lm_head): Linear(in_features=8, out_features=32000, bias=False)
    )
  )
)

Tiny-Random Llama with DoRA

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(32000, 8, padding_idx=0)
        (layers): ModuleList(
          (0): LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=8, out_features=8, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=8, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=8, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ParameterDict(  (default): Parameter containing: [torch.FloatTensor of size 8])
              )
              (k_proj): Linear(in_features=8, out_features=8, bias=False)
              (v_proj): lora.Linear(
                (base_layer): Linear(in_features=8, out_features=8, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=8, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=8, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ParameterDict(  (default): Parameter containing: [torch.FloatTensor of size 8])
              )
              (o_proj): Linear(in_features=8, out_features=8, bias=False)
              (rotary_emb): LlamaRotaryEmbedding()
            )
            (mlp): LlamaMLP(
              (gate_proj): Linear(in_features=8, out_features=32, bias=False)
              (up_proj): Linear(in_features=8, out_features=32, bias=False)
              (down_proj): Linear(in_features=32, out_features=8, bias=False)
              (act_fn): SiLU()
            )
            (input_layernorm): LlamaRMSNorm()
            (post_attention_layernorm): LlamaRMSNorm()
          )
        )
        (norm): LlamaRMSNorm()
      )
      (lm_head): Linear(in_features=8, out_features=32000, bias=False)
    )
  )
)
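
For reference, printouts like the two above can be reproduced with something along these lines (the tiny-random model id and hyperparameters are illustrative):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("HuggingFaceM4/tiny-random-LlamaForCausalLM")
config = LoraConfig(r=8, lora_alpha=16, use_dora=True, task_type="CAUSAL_LM")
print(get_peft_model(base, config))  # set use_dora=False to get the plain LoRA printout
```
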
github-actions[bot] commented 7 months ago

Unit Test Results

4 files ±0   4 suites ±0   10m 0s ⏱️ (-17m 48s)
12 tests (-2 972): 9 passed (-2 962), 3 skipped (-9), 0 failed (-1)
40 runs (-2 960): 28 passed (-2 953), 12 skipped (-6), 0 failed (-1)

Results for commit f2945d43. ± Comparison against base commit 867d699a.

♻️ This comment has been updated with latest results.