huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

When I used GaLore with ORPO, the learning rate was set to 8e-6, but the learning rate logged during training was 0.001 #1638

Closed. Minami-su closed this issue 4 days ago.

Minami-su commented 1 month ago
trainer = ORPOTrainer(
        model=model,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],

        #peft_config=peft_config,
        tokenizer=tokenizer,
        args= ORPOConfig(
            max_length=cutoff_len,
            max_prompt_length=cutoff_len//2,
            beta=0.1,
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=0,
            num_train_epochs=num_epochs,
            lr_scheduler_type="cosine",
            learning_rate=8e-6,
            bf16=True,
            logging_steps=10,
            optim = "galore_adamw_8bit_layerwise",
            optim_target_modules=[r".*attn.*", r".*mlp.*"],
            optim_args="rank=1024, update_proj_gap=500, scale=0.25",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=100 if val_set_size > 0 else None,
            save_steps=100,
            output_dir=output_dir,
            save_total_limit=2,
            gradient_checkpointing=True, 
            gradient_checkpointing_kwargs={'use_reentrant':True},
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            report_to="wandb" if use_wandb else None,
            run_name=wandb_run_name if use_wandb else None,
            do_train=True,
            remove_unused_columns=False,
        )
    )

Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !
  0%|          | 0/495 [00:00<?, ?it/s]
Could not estimate the number of tokens of the input, floating-point operations will not be computed
{'loss': 0.3557, 'grad_norm': 0.0, 'learning_rate': 0.001, 'rewards/chosen': -0.015678538009524345, 'rewards/rejected': -0.012379011139273643, 'rewards/accuracies': 0.19999998807907104, 'rewards/margins': -0.003299527335911989, 'logps/rejected': -0.12379010766744614, 'logps/chosen': -0.15678536891937256, 'logits/rejected': 0.7921055555343628, 'logits/chosen': 0.791210412979126, 'nll_loss': 0.2719877064228058, 'log_odds_ratio': -0.8374900817871094, 'log_odds_chosen': -0.25091928243637085, 'epoch': 0.06}
{'loss': 0.2634, 'grad_norm': 0.0, 'learning_rate': 0.001, 'rewards/chosen': -0.012010233476758003, 'rewards/rejected': -0.009977776557207108, 'rewards/accuracies': 0.29999998211860657, 'rewards/margins': -0.0020324576180428267, 'logps/rejected': -0.09977775812149048, 'logps/chosen': -0.12010233104228973, 'logits/rejected': 0.7489851713180542, 'logits/chosen': 0.7482139468193054, 'nll_loss': 0.1832979917526245, 'log_odds_ratio': -0.8010236620903015, 'log_odds_chosen': -0.16869042813777924, 'epoch': 0.12}
{'loss': 0.2482, 'grad_norm': 0.0, 'learning_rate': 0.001, 'rewards/chosen': -0.011346157640218735, 'rewards/rejected': -0.01022450439631939, 'rewards/accuracies': 0.4833333492279053, 'rewards/margins': -0.0011216530110687017, 'logps/rejected': -0.102245032787323, 'logps/chosen': -0.11346157640218735, 'logits/rejected': 0.7105721831321716, 'logits/chosen': 0.7108334898948669, 'nll_loss': 0.17242279648780823, 'log_odds_ratio': -0.7573043704032898, 'log_odds_chosen': -0.08471358567476273, 'epoch': 0.18}
{'loss': 0.2444, 'grad_norm': 0.0, 'learning_rate': 0.001, 'rewards/chosen': -0.012975988909602165, 'rewards/rejected': -0.013058923184871674, 'rewards/accuracies': 0.550000011920929, 'rewards/margins': 8.293241262435913e-05, 'logps/rejected': -0.13058921694755554, 'logps/chosen': -0.12975989282131195, 'logits/rejected': 0.6808757781982422, 'logits/chosen': 0.6832461953163147, 'nll_loss': 0.1756206750869751, 'log_odds_ratio': -0.687309741973877, 'log_odds_chosen': 0.04155167192220688, 'epoch': 0.24}
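As a sanity check (a minimal sketch, not part of the run above, assuming the trainer object constructed in the snippet), the configured value can be read back directly, since ORPOConfig subclasses TrainingArguments; the mismatch only shows up in what the progress logs report:

# sketch: read back the value actually stored in the config
print(trainer.args.learning_rate)  # 8e-06, as set above, yet the logs report 0.001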
younesbelkada commented 1 month ago

Hi @Minami-su, I think this has been fixed in a recent transformers release. What transformers version are you using?
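You can print the installed version with:

import transformers; print(transformers.__version__)

and pip install -U transformers pulls the latest release (just a suggestion, in case the fix landed there).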

Minami-su commented 1 month ago

> Hi @Minami-su, I think this has been fixed in a recent transformers release. What transformers version are you using?

I've tried transformers 4.41.1 and 4.42.0.dev0, and neither worked.

Minami-su commented 1 month ago

I suspect the problem isn't in TRL's ORPO itself, because the same thing happens with a script that doesn't use it (see the SimPO reproduction below).

Minami-su commented 1 month ago

SimPO: https://github.com/princeton-nlp/SimPO/blob/main/scripts/simpo_trainer.py

Command:

CUDA_VISIBLE_DEVICES=1 python ft-SimPo-galore.py \
    --base_model Meta-Llama-3-8B-Instruct-hard-iter2 \
    --data_path gpt4_1k_hard_SFT_trl_iter3.json \
    --output_dir gpt4_1k_hard_SFT_split_3_galore \
    --batch_size 6 \
    --micro_batch_size 6 \
    --num_epochs 3 \
    --learning_rate 8e-6 \
    --cutoff_len 1400 \
    --val_set_size 0 \
    --train_on_inputs \
    --group_by_length

ft-SimPo-galore.py:


import fire
import torch
import transformers
from datasets import load_dataset
import os
from simpo_trainer import SimPOTrainer
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, DataCollatorForLanguageModeling
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    set_peft_model_state_dict,
)
import signal
import sys
import os
import bitsandbytes as bnb

# torch.autograd.set_detect_anomaly(True)
def find_all_linear_names(model):
    # cls = bnb.nn.Linear8bitLt
    cls = bnb.nn.Linear4bit
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_DISABLED"] = "true"
def train(
    # model/data params
    base_model: str = "",  # the only required argument
    data_path: str = "yahma/alpaca-cleaned",
    output_dir: str = "./lora-alpaca",
    # training hyperparams
    batch_size: int = 128,
    micro_batch_size: int = 4,
    num_epochs: int = 3,
    learning_rate: float = 3e-4,
    cutoff_len: int = 256,
    max_prompt_length: int = 512,
    val_set_size: int = 2000,
    # lora hyperparams
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    beta: float = 2.5,
    gamma: float = 1.4,
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    add_eos_token: bool = False,
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    # wandb params
    wandb_project: str = "",
    wandb_run_name: str = "",
    wandb_watch: str = "",  # options: false | gradients | all
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
    prompt_template_name: str = "alpaca2",  # The prompt template to use, will default to alpaca.
):
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Training Alpaca-LoRA model with params:\n"
            f"base_model: {base_model}\n"
            f"data_path: {data_path}\n"
            f"output_dir: {output_dir}\n"
            f"batch_size: {batch_size}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"num_epochs: {num_epochs}\n"
            f"learning_rate: {learning_rate}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"val_set_size: {val_set_size}\n"
            f"lora_r: {lora_r}\n"
            f"lora_alpha: {lora_alpha}\n"
            f"lora_dropout: {lora_dropout}\n"
            f"train_on_inputs: {train_on_inputs}\n"
            f"add_eos_token: {add_eos_token}\n"
            f"group_by_length: {group_by_length}\n"
            f"wandb_project: {wandb_project}\n"
            f"wandb_run_name: {wandb_run_name}\n"
            f"wandb_watch: {wandb_watch}\n"
            f"wandb_log_model: {wandb_log_model}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
            f"prompt template: {prompt_template_name}\n"
        )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
    gradient_accumulation_steps = batch_size // micro_batch_size

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size

    # Check if parameter passed or if set within environ
    use_wandb = len(wandb_project) > 0 or (
        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
    )
    # Only overwrite environ if wandb param passed
    if len(wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = wandb_project
    if len(wandb_watch) > 0:
        os.environ["WANDB_WATCH"] = wandb_watch
    if len(wandb_log_model) > 0:
        os.environ["WANDB_LOG_MODEL"] = wandb_log_model

    tokenizer = AutoTokenizer.from_pretrained(base_model,trust_remote_code=True)

    if base_model.find("qwen") != -1 or base_model.find("Qwen") != -1:
        tokenizer.add_special_tokens({"bos_token": "<|im_start|>"})
        tokenizer.add_special_tokens({"eos_token": "<|im_end|>"})
        tokenizer.add_special_tokens({"pad_token": "<|endoftext|>"})

    elif base_model.find("Llama-3") != -1:
        tokenizer.add_special_tokens({"eos_token": "<|eot_id|>"})
        tokenizer.add_special_tokens({"pad_token": "<|eot_id|>"})
    #tokenizer.padding_side = "left"  # Allow batched inference
    def save_model(signal, frame):
        print("\nSaving the model...")
        model.save_pretrained(output_dir)
        sys.exit(0)

    def process(row):
        # row["chosen"] = row["chosen"]
        # row["rejected"] = row["rejected"]
        row["prompt"] = row["prompt"]
        row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
        row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
        # row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
        # row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
        return row

    print(tokenizer.pad_token_id)
    print(tokenizer.pad_token)
    print(tokenizer.bos_token_id)
    print(tokenizer.bos_token)
    print(tokenizer.eos_token_id)
    print(tokenizer.eos_token)
    if data_path.endswith(".json") or data_path.endswith(".jsonl"):
        data = load_dataset("json", data_files=data_path)
    else:
        data = load_dataset(data_path)
    data = data.map(process,num_proc=1)
    if val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = train_val["train"].shuffle()
        val_data = train_val["test"].shuffle()
    else:
        train_data = data["train"].shuffle()
        val_data = None
    print(len(train_data))
    model = AutoModelForCausalLM.from_pretrained(base_model,
                                            trust_remote_code=True,
                                             attn_implementation="flash_attention_2",
                                             torch_dtype=torch.bfloat16,
                                             device_map=device_map,
                                             )

    if resume_from_checkpoint:
        # Check the available weights and load them
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # Full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only LoRA model - LoRA config above has to fit
            resume_from_checkpoint = (
                False  # So the trainer won't try loading its state
            )
        # The two files above have a different name depending on how they were saved, but are actually the same.
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            set_peft_model_state_dict(model, adapters_weights)
        else:
            print(f"Checkpoint {checkpoint_name} not found")

    if not ddp and torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True
   # data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    training_args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=0,#(num_epochs*len(train_data)/batch_size)//10,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            bf16=True,
            logging_steps=10,
            optim = "galore_adamw_8bit_layerwise",
            optim_target_modules=[r".*attn.*", r".*mlp.*"],
            optim_args="rank=1024, update_proj_gap=500, scale=0.25,proj_type=std",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=100 if val_set_size > 0 else None,
            save_steps=200,
            output_dir=output_dir,
            save_total_limit=2,
            gradient_checkpointing=True, 
            gradient_checkpointing_kwargs={'use_reentrant':True},
            load_best_model_at_end=True if val_set_size > 0 else False,
            #ddp_find_unused_parameters=False if ddp else None,
            report_to="wandb" if use_wandb else None,
            run_name=wandb_run_name if use_wandb else None,
            do_train=True,
            remove_unused_columns=False,

        )
    ref_model=None
    trainer = SimPOTrainer(
        model=model,
        ref_model=ref_model, # pass in to bypass DPO Trainer check for ref model but is not actually used
        #model_init_kwargs=model_kwargs,
        args=training_args,
        beta=beta,
        gamma=gamma,
        train_dataset=train_data,
        eval_dataset=val_data,
        tokenizer=tokenizer,
        max_length=cutoff_len,
        max_prompt_length=cutoff_len//2,
        #loss_type=training_args.loss_type,
    )
    signal.signal(signal.SIGINT, save_model)

    trainer.train()
    model.save_pretrained(output_dir)

    print(
        "\n If there's a warning about missing keys above, please disregard :)"
    )

if __name__ == "__main__":
    fire.Fire(train)

output:

{'loss': 1.7515, 'grad_norm': 0.0, 'learning_rate': 0.001, 'rewards/chosen': -0.8618572354316711, 'rewards/rejected': -0.702195405960083, 'rewards/accuracies': 0.0833333358168602, 'rewards/margins': -0.15966185927391052, 'logps/rejected': -0.28087812662124634, 'logps/chosen': -0.34474286437034607, 'logits/rejected': 0.4580184519290924, 'logits/chosen': 0.4565271735191345, 'epoch': 0.06}
{'loss': 1.6516, 'grad_norm': 0.0, 'learning_rate': 0.001, 'rewards/chosen': -0.9827778935432434, 'rewards/rejected': -0.9460451006889343, 'rewards/accuracies': 0.4166666865348816, 'rewards/margins': -0.03673281893134117, 'logps/rejected': -0.3784180283546448, 'logps/chosen': -0.3931111693382263, 'logits/rejected': 0.47764071822166443, 'logits/chosen': 0.47824805974960327, 'epoch': 0.12}
{'loss': 1.5422, 'grad_norm': 0.0, 'learning_rate': 0.001, 'rewards/chosen': -1.096510648727417, 'rewards/rejected': -1.19998037815094, 'rewards/accuracies': 0.6000000238418579, 'rewards/margins': 0.10346974432468414, 'logps/rejected': -0.479992151260376, 'logps/chosen': -0.4386042654514313, 'logits/rejected': 0.5151220560073853, 'logits/chosen': 0.5120567679405212, 'epoch': 0.18}
{'loss': 1.4713, 'grad_norm': 0.0, 'learning_rate': 0.001, 'rewards/chosen': -1.1208069324493408, 'rewards/rejected': -1.3170217275619507, 'rewards/accuracies': 0.7500000596046448, 'rewards/margins': 0.1962147206068039, 'logps/rejected': -0.5268087387084961, 'logps/chosen': -0.4483228325843811, 'logits/rejected': 0.5599152445793152, 'logits/chosen': 0.5529050827026367, 'epoch': 0.24}
{'loss': 1.2346, 'grad_norm': 0.0, 'learning_rate': 0.001, 'rewards/chosen': -1.534162163734436, 'rewards/rejected': -2.1145620346069336, 'rewards/accuracies': 0.800000011920929, 'rewards/margins': 0.5803996920585632, 'logps/rejected': -0.8458248376846313, 'logps/chosen': -0.6136649250984192, 'logits/rejected': 0.6048283576965332, 'logits/chosen': 0.5854496359825134, 'epoch': 0.3}

Minami-su commented 1 month ago

trl version 0.8.6, transformers version 4.41.1
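For reference, both versions can be printed from Python (trl and transformers each expose __version__; this is just a suggested check, not part of the original report):

import trl, transformers
print(trl.__version__, transformers.__version__)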

younesbelkada commented 1 month ago

Thanks, will have a look!

github-actions[bot] commented 1 week ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.