jiaweizzhao / GaLore

GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection
Apache License 2.0

Galore + Lora? #9

Closed · nivibilla closed this 3 months ago

nivibilla commented 3 months ago

Hi,

Sorry if this is a stupid question, but is it possible to use the 8-bit GaLore optimizer in combination with LoRA adapters?

Thanks

PenutChen commented 3 months ago

I think it's technically possible, since GaLore is essentially an optimizer. But I have doubts about the resulting model quality after a double low-rank decomposition (LoRA on the weights plus GaLore on the gradients). Here is my sample code:

import torch
import torch.nn as nn
from datasets import load_dataset
from galore_torch import GaLoreAdamW8bit
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from transformers import AutoModelForCausalLM as ModelImp
from transformers import PreTrainedModel as ModelCls
from transformers import Trainer, TrainingArguments, get_cosine_schedule_with_warmup

def main():
    model_path = "Llama-7B"

    # Load the base model in bf16; the KV cache is disabled because it
    # is incompatible with gradient checkpointing.
    model: ModelCls = ModelImp.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        use_cache=False,
    )

    # Wrap the base model with LoRA adapters on the attention projections.
    peft_config = LoraConfig(
        r=64,
        lora_alpha=8,
        lora_dropout=0.0,
        inference_mode=False,
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )
    model: PeftModel = get_peft_model(model, peft_config)

    print(model)

    args = TrainingArguments(
        "Llama-7B-GaLore",
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        max_steps=100,
        evaluation_strategy="steps",
        save_strategy="steps",
        load_best_model_at_end=True,
        save_total_limit=1,
        gradient_checkpointing=True,
        logging_steps=1,
        eval_steps=1,
        save_steps=1,
        log_level="detail",
    )

    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()

    # Dataset of pre-tokenized samples; hold out 16 for evaluation.
    ds = load_dataset("parquet", data_files="tokens.parquet", cache_dir="Cache")
    ds = ds["train"].train_test_split(16)

    # Pass the GaLore optimizer and scheduler to the Trainer directly;
    # with ["lora"] as the target key, GaLore is applied to the LoRA
    # adapter weights injected by PEFT.
    trainer = Trainer(
        model,
        args,
        train_dataset=ds["train"],
        eval_dataset=ds["test"],
        optimizers=load_galore_optimizer(model, ["lora"]),
    )
    trainer.train()
    trainer.save_model()

def load_galore_optimizer(model: ModelCls, target_modules_list=["attn", "mlp"]):
    # Collect the weights of every nn.Linear whose module name contains
    # one of the target keys; only these receive GaLore's low-rank
    # gradient projection.
    galore_params = []
    for module_name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        if not any(target_key in module_name for target_key in target_modules_list):
            continue
        print(module_name)
        galore_params.append(module.weight)

    # All remaining trainable parameters are optimized normally;
    # frozen base weights are skipped.
    id_galore_params = {id(p) for p in galore_params}
    regular_params = [
        p for p in model.parameters()
        if id(p) not in id_galore_params and p.requires_grad
    ]

    # GaLore's projection hyperparameters apply only to the second group.
    param_groups = [
        dict(params=regular_params),
        dict(
            params=galore_params,
            rank=1024,
            update_proj_gap=500,
            scale=0.25,
            proj_type="std",
        ),
    ]

    optimizer = GaLoreAdamW8bit(param_groups, lr=0.01)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=10, num_training_steps=90
    )

    return optimizer, scheduler

if __name__ == "__main__":
    main()
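
For reference, here is a quick sanity-check sketch (the helper name is just illustrative, not part of the script above) that prints which Linear weights a target key matches. On the PEFT-wrapped model above, ["lora"] should match the lora_A/lora_B layers that PEFT injects:

import torch.nn as nn

def show_galore_targets(model, target_modules_list=("lora",)):
    # Same matching rule as load_galore_optimizer: any nn.Linear whose
    # module name contains one of the target keys gets GaLore applied.
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(
            key in name for key in target_modules_list
        ):
            print(name, tuple(module.weight.shape))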
nivibilla commented 3 months ago

Sorry for the delay in responding. Thank you, I will give it a try!

NickyDark1 commented 3 months ago

Is it possible with Gemma? And for ["attn", "mlp"], could you explain how to find the optimal ones?

Thank you so much.

PenutChen commented 3 months ago

@NickyDark1 It works with basically all models. attn and mlp refer to the self-attention blocks and the MLP blocks, respectively. The self-attention block usually contains fewer parameters, while the MLP block holds more. For full fine-tuning, target both attn and mlp so that most of the weights are updated, since they account for the bulk of the parameters and gradient computation. If overfitting is a concern, you can optimize only attn or only mlp. A minimal way to list a model's candidate module names is sketched below.
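
If you want to see which module names a given model actually exposes (e.g., for Gemma), a sketch along these lines works for any causal LM; google/gemma-2b is just an illustrative checkpoint:

import torch.nn as nn
from transformers import AutoModelForCausalLM

# Illustrative checkpoint; any causal LM can be inspected the same way.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")

# Print every Linear module name; substrings such as "attn" or "mlp"
# (matching e.g. self_attn.q_proj or mlp.gate_proj) can then serve as
# target keys for the GaLore parameter group.
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        print(name)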