I think it's technically possible, since GaLore is essentially an optimizer. But I doubt the resulting model's performance after the double low-rank decomposition (GaLore's gradient projection applied on top of LoRA's already low-rank adapters). Here is my sample code:
```python
import torch
import torch.nn as nn
from datasets import load_dataset
from galore_torch import GaLoreAdamW8bit
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from transformers import AutoModelForCausalLM as ModelImp
from transformers import PreTrainedModel as ModelCls
from transformers import Trainer, TrainingArguments, get_cosine_schedule_with_warmup


def main():
    model_path = "Llama-7B"
    model: ModelCls = ModelImp.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        use_cache=False,
    )

    # Wrap the base model with LoRA adapters on the attention projections.
    peft_config = LoraConfig(
        r=64,
        lora_alpha=8,
        lora_dropout=0.0,
        inference_mode=False,
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )
    model: PeftModel = get_peft_model(model, peft_config)
    print(model)

    args = TrainingArguments(
        "Llama-7B-GaLore",
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        max_steps=100,
        evaluation_strategy="steps",
        save_strategy="steps",
        load_best_model_at_end=True,
        save_total_limit=1,
        gradient_checkpointing=True,
        logging_steps=1,
        eval_steps=1,
        save_steps=1,
        log_level="detail",
    )
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()

    ds = load_dataset("parquet", data_files="tokens.parquet", cache_dir="Cache")
    ds = ds["train"].train_test_split(16)

    trainer = Trainer(
        model,
        args,
        train_dataset=ds["train"],
        eval_dataset=ds["test"],
        # Apply GaLore to the injected LoRA layers (module names contain "lora").
        optimizers=load_galore_optimizer(model, ["lora"]),
    )
    trainer.train()
    trainer.save_model()


def load_galore_optimizer(model: ModelCls, target_modules_list=("attn", "mlp")):
    # Collect the weights of every nn.Linear whose name contains one of the
    # target keys; these receive GaLore's low-rank gradient projection.
    galore_params = []
    for module_name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        if not any(target_key in module_name for target_key in target_modules_list):
            continue
        print(module_name)
        galore_params.append(module.weight)

    # Everything else that is still trainable goes into a regular parameter group.
    id_galore_params = {id(p) for p in galore_params}
    regular_params = [
        p
        for p in model.parameters()
        if id(p) not in id_galore_params and p.requires_grad
    ]
    param_groups = [
        dict(params=regular_params),
        dict(
            params=galore_params,
            rank=1024,
            update_proj_gap=500,
            scale=0.25,
            proj_type="std",
        ),
    ]
    optimizer = GaLoreAdamW8bit(param_groups, lr=0.01)
    # 10 warmup steps, then cosine decay over the 100 training steps (max_steps above).
    scheduler = get_cosine_schedule_with_warmup(optimizer, 10, 100)
    return optimizer, scheduler


if __name__ == "__main__":
    main()
```
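A note on the design: passing `["lora"]` into `load_galore_optimizer` means the name filter only matches the adapter layers that PEFT injects (module names like `...q_proj.lora_A.default`), so GaLore's gradient projection runs on top of LoRA's low-rank parameterization, which is the double decomposition I mentioned above. A minimal sanity check, assuming the `model` built in the script above:

```python
import torch.nn as nn

# List the nn.Linear modules the ["lora"] filter matches; only these
# adapter weights end up in the GaLore parameter group.
for name, module in model.named_modules():
    if isinstance(module, nn.Linear) and "lora" in name:
        print(name, tuple(module.weight.shape))
```

Note also that each LoRA matrix has inner dimension `r=64`, so a GaLore `rank` of 1024 cannot actually be realized on those weights; the projection is effectively limited by the smaller matrix dimension.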
Sorry for the delay in responding. Thank you, I will give it a try!
Is it possible with Gemma? And for ["attn", "mlp"], could you explain how to find the optimal ones?
Thank you so much.
@NickyDark1 It works with basically all models. `attn` and `mlp` indicate the self-attention blocks and the MLP blocks, respectively. The self-attention block usually contains fewer parameters, while the MLP block has more. For full fine-tuning, choose both `attn` and `mlp`, since together they account for most of the parameters and the largest gradients. You can choose to optimize only `attn` or only `mlp` to avoid overfitting if necessary.
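If you want to see exactly what `attn` and `mlp` would match for a given architecture, print the `nn.Linear` module names. A minimal sketch, assuming the Hugging Face checkpoint id `google/gemma-2b`:

```python
from collections import Counter

import torch.nn as nn
from transformers import AutoModelForCausalLM

# Print each distinct Linear module name once, so you can see which layers
# the substrings "attn" and "mlp" would select.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
counts = Counter()
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        # Drop the per-layer index so repeated layers collapse to one entry.
        counts[".".join(p for p in name.split(".") if not p.isdigit())] += 1
for name, n in counts.items():
    print(f"{n:3d} x {name}")
```

For Gemma you should see the `self_attn.{q,k,v,o}_proj` projections matched by `attn` and the `mlp.{gate,up,down}_proj` projections matched by `mlp`.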
Hi,
Sorry if this is a stupid question, but is it possible to use the 8-bit GaLore optimiser in combination with LoRA adapters?
Thanks