huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

KTO training produces NaN rewards #1447

Closed · claralp closed 7 months ago

claralp commented 7 months ago

During training with the KTOTrainer I occasionally get NaN values as rewards.
I am running the training as a job on Microsoft Azure with one GPU (NVIDIA A100 80GB PCIe).
Ultimately these issues cause my Azure job to crash and retry...

The log output I get from the KTOTrainer:

{'loss': 0.0202, 'grad_norm': 0.016206106171011925, 'learning_rate': 8.595282268751656e-06, 'rewards/chosen': 11.158143043518066, 'rewards/rejected': -29.0671443939209, 'rewards/margins': 40.22528839111328, 'kl': 0.0, 'logps/chosen': -15.192975044250488, 'logps/rejected': -180.1438446044922, 'epoch': 0.43}
{'loss': 0.0155, 'grad_norm': 4.091757774353027, 'learning_rate': 8.568778160614896e-06, 'rewards/chosen': 10.752923965454102, 'rewards/rejected': -26.606868743896484, 'rewards/margins': 37.35979461669922, 'kl': 0.0, 'logps/chosen': -13.974691390991211, 'logps/rejected': -156.9815673828125, 'epoch': 0.44}
{'loss': 0.0124, 'grad_norm': 0.06709074974060059, 'learning_rate': 8.542274052478135e-06, 'rewards/chosen': 10.838713645935059, 'rewards/rejected': -29.24416732788086, 'rewards/margins': 40.08287811279297, 'kl': 0.0, 'logps/chosen': -10.99155044555664, 'logps/rejected': -165.8121795654297, 'epoch': 0.44}
{'loss': 0.0113, 'grad_norm': 14.28693675994873, 'learning_rate': 8.515769944341374e-06, 'rewards/chosen': 11.07004451751709, 'rewards/rejected': -30.99440574645996, 'rewards/margins': 42.064453125, 'kl': 0.0, 'logps/chosen': -13.967004776000977, 'logps/rejected': -176.50094604492188, 'epoch': 0.45}
{'loss': 0.0193, 'grad_norm': 3.899095296859741, 'learning_rate': 8.489265836204611e-06, 'rewards/chosen': 10.825413703918457, 'rewards/rejected': -34.434303283691406, 'rewards/margins': 45.25971984863281, 'kl': 0.0, 'logps/chosen': -12.9598388671875, 'logps/rejected': -186.38381958007812, 'epoch': 0.46}
{'loss': 0.0109, 'grad_norm': 0.009407841600477695, 'learning_rate': 8.46276172806785e-06, 'rewards/chosen': nan, 'rewards/rejected': -33.95360565185547, 'rewards/margins': nan, 'kl': 0.0, 'logps/chosen': nan, 'logps/rejected': -176.4713897705078, 'epoch': 0.47}
{'loss': 0.0324, 'grad_norm': 17.832523345947266, 'learning_rate': 8.43625761993109e-06, 'rewards/chosen': 10.286358833312988, 'rewards/rejected': -33.60068893432617, 'rewards/margins': 43.887046813964844, 'kl': 0.0, 'logps/chosen': -20.224634170532227, 'logps/rejected': -184.18112182617188, 'epoch': 0.48}
{'loss': 0.0029, 'grad_norm': 0.03802444413304329, 'learning_rate': 8.409753511794329e-06, 'rewards/chosen': 10.086004257202148, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.0, 'logps/chosen': -12.816671371459961, 'logps/rejected': nan, 'epoch': 0.48}
{'loss': 0.012, 'grad_norm': 2.815098524093628, 'learning_rate': 8.383249403657568e-06, 'rewards/chosen': 10.549690246582031, 'rewards/rejected': -31.304590225219727, 'rewards/margins': 41.85428237915039, 'kl': 0.0, 'logps/chosen': -13.178544998168945, 'logps/rejected': -169.447509765625, 'epoch': 0.49}
{'loss': 0.0074, 'grad_norm': 0.001768477726727724, 'learning_rate': 8.356745295520805e-06, 'rewards/chosen': 11.22235107421875, 'rewards/rejected': -33.09156799316406, 'rewards/margins': 44.31391906738281, 'kl': 0.0, 'logps/chosen': -13.94648265838623, 'logps/rejected': -178.08566284179688, 'epoch': 0.5}
{'loss': 0.0055, 'grad_norm': 8.117822647094727, 'learning_rate': 8.330241187384045e-06, 'rewards/chosen': 11.166982650756836, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.0, 'logps/chosen': -14.707374572753906, 'logps/rejected': nan, 'epoch': 0.51}
{'loss': 0.0206, 'grad_norm': 1.6973105669021606, 'learning_rate': 8.303737079247284e-06, 'rewards/chosen': 10.326757431030273, 'rewards/rejected': -33.753868103027344, 'rewards/margins': 44.08062744140625, 'kl': 0.0, 'logps/chosen': -19.15297508239746, 'logps/rejected': -181.1234130859375, 'epoch': 0.52}
{'loss': 0.0136, 'grad_norm': 9.740607261657715, 'learning_rate': 8.277232971110523e-06, 'rewards/chosen': 10.298160552978516, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.0, 'logps/chosen': -15.718103408813477, 'logps/rejected': nan, 'epoch': 0.52}

my pip freeze:

accelerate==0.28.0
aiohttp==3.9.3
aiosignal==1.3.1
async-timeout==4.0.3
attrs==23.2.0
bitsandbytes==0.43.0
certifi==2024.2.2
charset-normalizer==3.3.2
datasets==2.18.0
dill==0.3.8
docstring_parser==0.16
filelock==3.13.1
frozenlist==1.4.1
fsspec==2024.2.0
huggingface-hub==0.21.4
idna==3.6
Jinja2==3.1.3
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
networkx==3.2.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.4.99
nvidia-nvtx-cu12==12.1.105
packaging==24.0
pandas==2.2.1
peft==0.9.0
protobuf==5.26.0
psutil==5.9.8
pyarrow==15.0.1
pyarrow-hotfix==0.6
Pygments==2.17.2
python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.1
regex==2023.12.25
requests==2.31.0
rich==13.7.1
safetensors==0.4.2
sentencepiece==0.2.0
shtab==1.7.1
six==1.16.0
sympy==1.12
tokenizers==0.15.2
torch==2.2.1
tqdm==4.66.2
transformers==4.38.2
triton==2.2.0
trl @ git+https://github.com/huggingface/trl@a2aa0f0b09671eaf81a945eb5e4913165fee92fa
typing_extensions==4.10.0
tyro==0.7.3
tzdata==2024.1
urllib3==2.2.1
xxhash==3.4.1
yarl==1.9.4

the training script I use:

from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, BitsAndBytesConfig

from trl import KTOConfig, KTOTrainer, ModelConfig
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
import torch

# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the KTO training script.
    """

    dataset_path: Optional[str] = field(default=None, metadata={"help": "the online dataset to use, should include keys: [prompt, completion, label] OR [messages, completion, label]"})
    data_files: Optional[str] = field(default=None, metadata={"help": "the file(s) including data to use, this looks for 'data/{data_files}_train/test.jsonl.gz'. Datasets should include keys: [prompt, completion, label] OR [messages, completion, label]"})
    file_type: Optional[str] = field(default=None, metadata={"help": "the file type to open, e.g. 'json', 'csv'"})
    max_tokens: Optional[int] = field(default=4096, metadata={"help": "the maximum number of tokens returned by the data collator"})
    # debugging
    sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})

if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
    script_args, kto_args, model_args = parser.parse_args_into_dataclasses()
    print(f"train with {script_args}, \n{model_args}")

    # Peft & Quantisation
    quantization_config = BitsAndBytesConfig(load_in_8bit=model_args.load_in_8bit)
    peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=model_args.lora_r, lora_alpha=model_args.lora_alpha, lora_dropout=model_args.lora_dropout)

    # Load the trainable model
    model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path,
                                                 quantization_config = quantization_config,
                                                 torch_dtype = getattr(torch, model_args.torch_dtype) if model_args.torch_dtype is not None else None,
                                                 device_map = "auto")

    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, peft_config = peft_config)
    model.print_trainable_parameters()

    # Reference Model
    model_ref = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path,
                                                     quantization_config = quantization_config,
                                                     torch_dtype = getattr(torch, model_args.torch_dtype) if model_args.torch_dtype is not None else None,
                                                     device_map = "auto")

    model_ref = prepare_model_for_kbit_training(model_ref)
    model_ref = get_peft_model(model_ref, peft_config=peft_config)

    # Load Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    tokenizer.truncation_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load the desired dataset
    if script_args.dataset_path is not None:
        dataset = load_dataset(script_args.dataset_path)
    elif script_args.data_files is not None and script_args.file_type is not None:
        dataset = load_dataset(script_args.file_type, data_files={"train": f"./data/{script_args.data_files}_train.jsonl.gz", "test": f"./data/{script_args.data_files}_test.jsonl.gz"})    
    else:
        print("either dataset_path or data_files & file_type have to be defined")
        exit(1)

    if script_args.sanity_check:
        dataset["train"] = dataset["train"].select(range(1000))

    # Create Split if not existing already
    if "test" not in dataset:
        dataset = dataset["train"].train_test_split(train_size=0.9)

    # apply chat template if not preformatted
    if "prompt" not in dataset["train"].features:
        dataset = dataset.map(lambda x: {"prompt": tokenizer.apply_chat_template(x["messages"], tokenize=False, add_generation_prompt=False)})

    # Set max. lengths for DefaultDataCollator

    max_prompt_len, max_compl_len, max_len = 0, 0, 0
    tokenizer.model_max_length = script_args.max_tokens
    tokenizer.max_model_input_sizes = script_args.max_tokens

    for sample in dataset["train"]:

        compl_len = len(tokenizer(sample["completion"], truncation=True)["input_ids"])
        total_len = len(tokenizer(sample["prompt"] + sample["completion"], truncation=True)["input_ids"])
        prompt_len = total_len - compl_len

        max_prompt_len = max(max_prompt_len, prompt_len)
        max_compl_len = max(max_compl_len, compl_len)
        max_len = max(max_len, total_len)

    kto_args.max_prompt_length = max_prompt_len
    kto_args.max_completion_length = max_compl_len
    kto_args.max_length = max_len

    print(dataset)
    print(f"max_prompt_length={kto_args.max_prompt_length}, max_completion_length={kto_args.max_completion_length}, max_len={kto_args.max_length}")

    # set desired/undesired weights

    desired_weight = len(dataset["train"]) / (2 * len(dataset["train"].filter(lambda d: d["label"])))
    undesired_weight = len(dataset["train"]) / (2 * len(dataset["train"].filter(lambda d: not d["label"])))

    kto_args.desirable_weight = desired_weight
    kto_args.undesirable_weight = undesired_weight

    # initialize the KTO trainer
    kto_trainer = KTOTrainer(
        model,
        model_ref,
        args=kto_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        tokenizer=tokenizer
    )

    # train
    kto_trainer.train()

the call arguments:

python train_kto.py \
    --model_name_or_path DiscoResearch/DiscoLM_German_7b_v1 \
    --data_files wp_rag_kto_20k \
    --file_type json \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --num_train_epochs 3 \
    --learning_rate 1e-5 \
    --gradient_accumulation_steps 2 \
    --logging_steps 10 \
    --evaluation_strategy epoch \
    --output_dir kto_finetuned \
    --optim adamw_bnb_8bit \
    --warmup_steps 10 \
    --logging_first_step \
    --use_peft \
    --lora_r 8 \
    --lora_alpha 16 \
    --report_to none \
    --disable_tqdm False \
    --beta 0.5 \
    --torch_dtype bfloat16 \
    --bf16 \
    --load_in_8bit

Maybe @lewtun can help

lewtun commented 7 months ago

cc also @kashif

kashif commented 7 months ago

@claralp depending on the batch size it could be that some of the metrics are NaN; this should not affect the training etc., and special attention has been paid to make sure the loss etc. is robust to these NaNs when doing back-prop.
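
A minimal sketch of that behaviour (an illustration only, not the actual KTOTrainer code): when a micro-batch happens to contain only desirable or only undesirable completions, the per-side reward metric is the mean of an empty tensor and gets logged as NaN, while a sum-based loss over the non-empty side stays finite:

import torch

# Illustrative only -- not the TRL implementation.
chosen_rewards = torch.tensor([0.4, 0.7, 0.2])
rejected_rewards = torch.tensor([])  # batch with no undesirable samples

print(rejected_rewards.mean())  # tensor(nan) -> the logged "rewards/rejected" metric is NaN

# A sum over the empty side simply contributes nothing, so the loss stays finite
loss = (1 - torch.sigmoid(chosen_rewards)).sum() + (1 - torch.sigmoid(-rejected_rewards)).sum()
print(loss)  # finite even though the metric above is NaN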

kashif commented 7 months ago

@claralp I do not think NaNs in a dict should cause this to crash... do you have some crash back-traces?

claralp commented 7 months ago

@kashif there are no errors or warnings in stdout/stderr; it just stops at some point after the NaN rewards appear, so I cannot provide a stack trace here.
However, the Azure execution wrapper log shows a blocking process:

2024-03-19T03:33:30.165457Z ERROR Execution::wait_for_completion{parent_span=Span { name: "Execution::spawn", level: Level(Info), target: "execution_wrapper::execution", id: Id(6755674318962691), module_path: "execution_wrapper::execution", line: 163, file: "executor/execution-wrapper/src/execution/mod.rs" } process_manager=Mutex { data: ProcessManager { dangling_processes: [], user_process_groups: [34] } }}: execution_wrapper::execution::process_manager: Failed blocking user process detected, process name: echo, process pid: 34, code: None success_return_code=Zero { additional_codes: [] } code=None
2024-03-19T03:33:31.167084Z ERROR Execution::wait_for_completion{parent_span=Span { name: "Execution::spawn", level: Level(Info), target: "execution_wrapper::execution", id: Id(6755674318962691), module_path: "execution_wrapper::execution", line: 163, file: "executor/execution-wrapper/src/execution/mod.rs" } process_manager=Mutex { data: ProcessManager { dangling_processes: [], user_process_groups: [34] } }}: execution_wrapper::execution: Execution process terminated by a signal, which may be due to failure in other user processes on the same node or node ran out of memory. local_rank=0 name=echo

The lifecycler log shows only a preemption signal:

2024-03-19T03:33:29.494161Z  WARN run_lifecycler:run_service_and_step_through_lifecycle:step_through_lifecycle: lifecycler::lifecycle: Received abort message, exiting lifecycle abort_message=AbortMessage { error: Some(Error { code: "ReceivedPreemptionSignal", message: "{\"Compliant\":\"Job was terminated due to: Runtime received a preemption signal.\"}", target: "", node_info: None, category: UserError, error_details: [], inner_error: None }), broadcast_abort: true, request_timeout: 25 }

PhilipMay commented 7 months ago

I think this could be the "normal" low-priority Azure preemption? :-(

claralp commented 7 months ago

Important note here: the crash only appears after the training shows NaN values; otherwise it doesn't.
I have even seen cases where all metrics converge to NaN values:

{'loss': 0.0, 'grad_norm': 281.6248474121094, 'learning_rate': 9.856115107913668e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.17875319719314575, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.11}
{'loss': 0.0, 'grad_norm': 192.08326721191406, 'learning_rate': 9.848121502797762e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 1.0570355653762817, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.11}
{'loss': 0.0, 'grad_norm': 33.55568313598633, 'learning_rate': 9.840127897681853e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 2.1016669273376465, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.12}
{'loss': 0.0, 'grad_norm': 44.5154914855957, 'learning_rate': 9.832134292565947e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 2.197722911834717, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.12}
{'loss': 0.0, 'grad_norm': 10.592936515808105, 'learning_rate': 9.82414068745004e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 1.0713751316070557, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.13}
{'loss': 0.0, 'grad_norm': 61.1552734375, 'learning_rate': 9.81614708233413e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.3863883912563324, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.13}

Could there be anything wrong with the hyperparameter choice, @kashif?

kashif commented 7 months ago

@claralp so the main hyperparameter that could affect this is the batch size, as it needs a good mix of good and bad examples, as well as for the KL estimates... your learning rate is tiny, so that should be fine... what is your batch size when you get all NaNs?

Also, does this happen if you try locally, outside of Azure?
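
To illustrate the batch-mix point, a quick back-of-the-envelope sketch (assuming labels are drawn independently per batch element; not TRL code) of how often a batch ends up containing only one class, which is exactly when the per-side metrics come out as NaN:

def prob_single_class_batch(p_desirable: float, batch_size: int) -> float:
    """Probability that a random batch contains only desirable or only undesirable samples."""
    p = p_desirable
    return (1 - p) ** batch_size + p ** batch_size

for bs in (4, 8, 16, 32):
    print(bs, round(prob_single_class_batch(0.5, bs), 4))
# 4 0.125, 8 0.0078, 16 0.0, 32 0.0 -- with a 50/50 mix, single-class batches
# become rare quickly, but small batches or imbalanced data make them common.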

claralp commented 7 months ago

The output below is from a test with very unbalanced data, namely 2k desired completions and 10k undesired ones.
I know that a ratio between 4:3 and 1:1 is required for proper training.
This is just an experiment to see whether missing pos/neg samples in a batch might be the reason behind NaN values as rewards. But here I get NaN losses even without NaN rewards (see the log excerpt after the sketch below)...
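
Rough arithmetic for this split (my own back-of-the-envelope numbers, reusing the weight formula from the training script above):

n_total, n_pos, n_neg = 12_000, 2_000, 10_000

desirable_weight = n_total / (2 * n_pos)    # = 3.0
undesirable_weight = n_total / (2 * n_neg)  # = 0.6
# weighted totals: 3.0 * 2000 = 6000 vs 0.6 * 10000 = 6000, i.e. a 1:1 ratio

# chance that a batch of 8 contains no desirable sample at all
p_pos = n_pos / n_total                     # ~= 0.167
print((1 - p_pos) ** 8)                     # ~= 0.23, so roughly every fourth batch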

{'loss': 1.0431, 'grad_norm': 42.099464416503906, 'learning_rate': 1.0000000000000002e-06, 'rewards/chosen': 0.0, 'rewards/rejected': 0.0, 'rewards/margins': 0.0, 'kl': 0.0, 'logps/chosen': -37.16696548461914, 'logps/rejected': -87.62107849121094, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 41.9438362121582, 'learning_rate': 2.0000000000000003e-06, 'rewards/chosen': 0.0, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.0, 'logps/chosen': -32.92508316040039, 'logps/rejected': nan, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 29.28327178955078, 'learning_rate': 3e-06, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.15479230880737305, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 32.70748519897461, 'learning_rate': 4.000000000000001e-06, 'rewards/chosen': 0.06518054008483887, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.43951892852783203, 'logps/chosen': -31.101844787597656, 'logps/rejected': nan, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 44.989227294921875, 'learning_rate': 5e-06, 'rewards/chosen': 0.3087962865829468, 'rewards/rejected': 0.23543643951416016, 'rewards/margins': 0.07335984706878662, 'kl': 1.230994462966919, 'logps/chosen': -32.83413314819336, 'logps/rejected': -74.81724548339844, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 55.32667541503906, 'learning_rate': 6e-06, 'rewards/chosen': 0.3336696922779083, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.3016533851623535, 'logps/chosen': -38.598453521728516, 'logps/rejected': nan, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 32.44403839111328, 'learning_rate': 7e-06, 'rewards/chosen': 0.8524215817451477, 'rewards/rejected': 0.5893988609313965, 'rewards/margins': 0.2630227208137512, 'kl': 0.7648882865905762, 'logps/chosen': -35.86614227294922, 'logps/rejected': -93.13447570800781, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 26.85154914855957, 'learning_rate': 8.000000000000001e-06, 'rewards/chosen': 0.8056153059005737, 'rewards/rejected': 0.40718716382980347, 'rewards/margins': 0.39842814207077026, 'kl': 1.3891675472259521, 'logps/chosen': -34.07681655883789, 'logps/rejected': -113.53411102294922, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 25.181703567504883, 'learning_rate': 9e-06, 'rewards/chosen': nan, 'rewards/rejected': 0.9289813041687012, 'rewards/margins': nan, 'kl': 1.279036521911621, 'logps/chosen': nan, 'logps/rejected': -132.0060272216797, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 36.62141799926758, 'learning_rate': 1e-05, 'rewards/chosen': 1.4094278812408447, 'rewards/rejected': 0.8396401405334473, 'rewards/margins': 0.5697878003120422, 'kl': 2.0255985260009766, 'logps/chosen': -30.87615394592285, 'logps/rejected': -102.92286682128906, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 43.035221099853516, 'learning_rate': 9.997300215982722e-06, 'rewards/chosen': 1.5928469896316528, 'rewards/rejected': 1.5922844409942627, 'rewards/margins': 0.0005625784397125244, 'kl': 2.884922981262207, 'logps/chosen': -39.46299362182617, 'logps/rejected': -121.78970336914062, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 33.07608413696289, 'learning_rate': 9.994600431965443e-06, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 3.1301448345184326, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 43.48128128051758, 'learning_rate': 9.991900647948165e-06, 'rewards/chosen': 2.113973617553711, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 3.475428819656372, 'logps/chosen': -26.679065704345703, 'logps/rejected': nan, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 31.501819610595703, 'learning_rate': 9.989200863930886e-06, 'rewards/chosen': 2.6266024112701416, 'rewards/rejected': 2.2295963764190674, 'rewards/margins': 0.3970060348510742, 'kl': 4.643209934234619, 'logps/chosen': -42.25154495239258, 'logps/rejected': -95.91471862792969, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 34.09553527832031, 'learning_rate': 9.986501079913607e-06, 'rewards/chosen': 2.7660703659057617, 'rewards/rejected': 2.6509010791778564, 'rewards/margins': 0.11516910791397095, 'kl': 4.8384199142456055, 'logps/chosen': -49.93422317504883, 'logps/rejected': -73.00190734863281, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 30.591957092285156, 'learning_rate': 9.983801295896329e-06, 'rewards/chosen': 3.131122350692749, 'rewards/rejected': 2.9620559215545654, 'rewards/margins': 0.1690664291381836, 'kl': 4.498130798339844, 'logps/chosen': -29.836196899414062, 'logps/rejected': -105.75230407714844, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 13.737163543701172, 'learning_rate': 9.98110151187905e-06, 'rewards/chosen': nan, 'rewards/rejected': 3.1204824447631836, 'rewards/margins': nan, 'kl': 6.049262523651123, 'logps/chosen': nan, 'logps/rejected': -96.40724182128906, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 30.375396728515625, 'learning_rate': 9.978401727861771e-06, 'rewards/chosen': nan, 'rewards/rejected': 3.636046886444092, 'rewards/margins': nan, 'kl': 6.3599958419799805, 'logps/chosen': nan, 'logps/rejected': -97.00442504882812, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 27.26076889038086, 'learning_rate': 9.975701943844493e-06, 'rewards/chosen': 4.384129524230957, 'rewards/rejected': 3.9822707176208496, 'rewards/margins': 0.40185898542404175, 'kl': 8.23063850402832, 'logps/chosen': -24.248661041259766, 'logps/rejected': -105.89572143554688, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 18.513507843017578, 'learning_rate': 9.973002159827214e-06, 'rewards/chosen': 4.265963077545166, 'rewards/rejected': 3.8863425254821777, 'rewards/margins': 0.3796207308769226, 'kl': 6.635190010070801, 'logps/chosen': -24.802963256835938, 'logps/rejected': -68.99553680419922, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 25.997692108154297, 'learning_rate': 9.970302375809935e-06, 'rewards/chosen': 5.037494659423828, 'rewards/rejected': 4.227317810058594, 'rewards/margins': 0.8101770877838135, 'kl': 8.07493782043457, 'logps/chosen': -24.345657348632812, 'logps/rejected': -74.88150024414062, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 26.245861053466797, 'learning_rate': 9.967602591792658e-06, 'rewards/chosen': 4.526309490203857, 'rewards/rejected': 4.603299140930176, 'rewards/margins': -0.07698965072631836, 'kl': 8.698637008666992, 'logps/chosen': -22.94290542602539, 'logps/rejected': -99.22356414794922, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 22.14063835144043, 'learning_rate': 9.964902807775378e-06, 'rewards/chosen': 5.355809211730957, 'rewards/rejected': 4.891297340393066, 'rewards/margins': 0.464511513710022, 'kl': 8.954204559326172, 'logps/chosen': -23.850910186767578, 'logps/rejected': -87.7445068359375, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 25.642059326171875, 'learning_rate': 9.962203023758101e-06, 'rewards/chosen': 5.606294631958008, 'rewards/rejected': 6.807004928588867, 'rewards/margins': -1.2007099390029907, 'kl': 9.733396530151367, 'logps/chosen': -24.039264678955078, 'logps/rejected': -119.2092514038086, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 10.412492752075195, 'learning_rate': 9.959503239740822e-06, 'rewards/chosen': 5.953470230102539, 'rewards/rejected': 5.025949954986572, 'rewards/margins': 0.9275206327438354, 'kl': 10.74533462524414, 'logps/chosen': -16.727996826171875, 'logps/rejected': -80.9796142578125, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 17.695709228515625, 'learning_rate': 9.956803455723542e-06, 'rewards/chosen': nan, 'rewards/rejected': 6.109594345092773, 'rewards/margins': nan, 'kl': 11.900070190429688, 'logps/chosen': nan, 'logps/rejected': -121.30842590332031, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 35.035892486572266, 'learning_rate': 9.954103671706265e-06, 'rewards/chosen': 6.687896251678467, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 12.4317626953125, 'logps/chosen': -16.511333465576172, 'logps/rejected': nan, 'epoch': 0.02}

claralp commented 7 months ago

kashif commented 1 hour ago @claralp so the main hyperparam that could affect this is the batch size as it needs a good mix of good and bad examples, as well as for the KL estimates... your learning rate is tiny so that should be good... what is your batch size when you get all nans?

The batch size is 8 and gradient accumulation steps is 2, as in the config above.

also does this happen if you try locally outside of the azure

currently checking this

claralp commented 7 months ago

Closed with #1499 and #1514.