jiaweizzhao / GaLore

GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection

GaLore in HuggingFace #20

Open · IamExperimenting opened this issue 3 months ago

IamExperimenting commented 3 months ago

Hi team, many thanks for GaLore. I'm currently using Hugging Face for fine-tuning and I'm curious how to integrate GaLore with it.

This isn't a bug report; I'm just interested in using GaLore with Hugging Face.

@jiaweizzhao

from galore_torch import GaLoreAdamW, GaLoreAdamW8bit, GaLoreAdafactor
from peft import LoraConfig
import os
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTTrainer

lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

model_id = "google/gemma-2b"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'])
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0}, token=os.environ['HF_TOKEN'])

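# NOTE: `data` below is assumed to be a dataset that was loaded and prepared earlier (not shown in this snippet)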
trainer = SFTTrainer(
    model=model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=10,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit"
    ),
    peft_config=lora_config,
)
trainer.train()

Should I just replace optim="paged_adamw_8bit" with optim=GaLoreAdamW8bit? Could you please provide a sample script?

PenutChen commented 3 months ago

The official integration of the GaLore optimizer with Hugging Face is underway, as seen in https://github.com/huggingface/transformers/pull/29588.

For now, you can pass the GaLore optimizer in the optimizers argument of the trainer, e.g., https://github.com/jiaweizzhao/GaLore/issues/9#issuecomment-1989719401.

IamExperimenting commented 3 months ago

@PenutChen thanks, I looked into it. I see it is a combination of GaLore and LoRA, right?

I assume GaLore alone is enough to fine-tune all the parameters in the model.

In that case, should I remove LoRA from the target_modules_list and just leave it at the default, like the code below?

trainer = Trainer(
    model,
    args,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    optimizers=load_galore_optimizer(model),
)

penut85420 commented 3 months ago

@IamExperimenting Yes, the LoRA-related part of the code can be completely removed.

IamExperimenting commented 3 months ago

@penut85420 I have removed the LoRA-related code from the provided example and pasted it below. Can you have a look and confirm whether it is good to go?

import torch
import torch.nn as nn
from datasets import load_dataset
from galore_torch import GaLoreAdamW8bit
from transformers import AutoModelForCausalLM as ModelImp
from transformers import PreTrainedModel as ModelCls
from transformers import Trainer, TrainingArguments, get_cosine_schedule_with_warmup

def main():
    model_path = "Llama-7B"

    model: ModelCls = ModelImp.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        use_cache=False,
    )

    print(model)

    args = TrainingArguments(
        "Llama-7B-GaLore",
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        max_steps=100,
        evaluation_strategy="steps",
        save_strategy="steps",
        load_best_model_at_end=True,
        save_total_limit=1,
        gradient_checkpointing=True,
        logging_steps=1,
        eval_steps=1,
        save_steps=1,
        log_level="detail",
    )

    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()

    ds = load_dataset("parquet", data_files="tokens.parquet", cache_dir="Cache")
    ds = ds["train"].train_test_split(16)

    trainer = Trainer(
        model,
        args,
        train_dataset=ds["train"],
        eval_dataset=ds["test"],
        optimizers=load_galore_optimizer(model),
    )
    trainer.train()
    trainer.save_model()

def load_galore_optimizer(model: ModelCls, target_modules_list=["attn", "mlp"]):
    galore_params = []
    for module_name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        if not any(target_key in module_name for target_key in target_modules_list):
            continue
        print(module_name)
        galore_params.append(module.weight)

    id_galore_params = {id(p) for p in galore_params}
    regular_params = [p for p in model.parameters() if id(p) not in id_galore_params]

    param_groups = [
        dict(params=regular_params),
        dict(
            params=galore_params,
            rank=1024,
            update_proj_gap=500,
            scale=0.25,
            proj_type="std",
        ),
    ]

    optimizer = GaLoreAdamW8bit(param_groups, lr=0.01)
    scheduler = get_cosine_schedule_with_warmup(optimizer, 10, 90)

    return optimizer, scheduler

if __name__ == "__main__":
    main()

penut85420 commented 3 months ago

@IamExperimenting you're basically right, just two things to note:

  1. This is not the layer-wise GaLore implementation, which is too complex to demo in a few lines.
  2. The learning rate of the optimizer and the step arguments of the learning-rate scheduler are hard-coded; you should set them yourself (see the sketch after this snippet):

optimizer = GaLoreAdamW8bit(param_groups, lr=0.01)
scheduler = get_cosine_schedule_with_warmup(optimizer, 10, 90)
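For illustration, here is a minimal sketch of the same helper with the learning rate and scheduler steps exposed as arguments (the function name and default values below are my own, not from the original example):

import torch.nn as nn
from galore_torch import GaLoreAdamW8bit
from transformers import get_cosine_schedule_with_warmup

def load_galore_optimizer_configurable(
    model, lr=1e-5, warmup_steps=10, max_steps=100,
    rank=1024, target_modules_list=("attn", "mlp"),
):
    # Collect the weights of the targeted Linear layers; these get the GaLore low-rank projection
    galore_params = [
        module.weight
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear) and any(key in name for key in target_modules_list)
    ]
    id_galore_params = {id(p) for p in galore_params}
    regular_params = [p for p in model.parameters() if id(p) not in id_galore_params]

    param_groups = [
        dict(params=regular_params),
        dict(params=galore_params, rank=rank, update_proj_gap=500, scale=0.25, proj_type="std"),
    ]
    optimizer = GaLoreAdamW8bit(param_groups, lr=lr)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps
    )
    return optimizer, scheduler

That way the scheduler's num_training_steps can be kept in sync with the max_steps you pass to TrainingArguments.
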
IamExperimenting commented 3 months ago

@penut85420 oh... so I won't be able to fully leverage GaLore?

Basically, I want to fine-tune the Mistral 7B model (all of its parameters) with my domain data.

Sorry for so many questions.

IamExperimenting commented 3 months ago

@penut85420 could you please help me? I want to fine-tune all of the model's parameters (Mistral, Gemma).

geronimi73 commented 3 months ago

@IamExperimenting GaLore was just merged into Transformers: https://huggingface.co/docs/transformers/main/en/trainer#galore

IamExperimenting commented 3 months ago

@geronimi73 @penut85420 can I assume that the code below will fine-tune all parameters of the model?

args = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw_layerwise",
    optim_target_modules=["attn", "mlp"]
)

PenutChen commented 3 months ago

@IamExperimenting you can use model.num_parameters() to check the number of trainable parameters:

n_params = model.num_parameters(only_trainable=True)
print(f"Trainable Parameters: {n_params:,}")

For the Mistral 7B model, it will print Trainable Parameters: 7,241,732,096 for full-parameter tuning.

PenutChen commented 3 months ago

BTW, GaLore in Hugging Face Transformers has been formally released 🎉 https://huggingface.co/blog/galore

geronimi73 commented 3 months ago

> can I assume that the code below will fine-tune all parameters of the model?

I suggest you start with the example in the HF docs, see if it works, and then adapt it to your needs.
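For reference, here is a minimal full-parameter GaLore sketch along the lines of the example in the HF docs (the model, dataset, and trl SFTTrainer arguments shown here are assumptions based on the versions current at the time; adjust them to your setup):

import datasets
import torch
import trl
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM

# Any text dataset works here; imdb is just a small, convenient example
train_dataset = datasets.load_dataset("imdb", split="train")

args = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw",                  # "galore_adamw_layerwise" is the layer-wise variant
    optim_target_modules=["attn", "mlp"],  # GaLore is applied to Linear layers whose names match
)

model_id = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

trainer = trl.SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    dataset_text_field="text",
    max_seq_length=512,
)
trainer.train()

No LoRA/PEFT config is involved, so all model parameters stay trainable; GaLore only changes how the optimizer states are stored and updated.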