unslothai / unsloth

Finetune Llama 3.2, Mistral, Phi & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0

NaN during llama3 finetuning #427

Open mano3-1 opened 5 months ago

mano3-1 commented 5 months ago

Hi,

I'm currently fine-tuning llama3-instruct-8b on a custom dataset using Unsloth's FastLanguageModel, with Hugging Face's SFTTrainer for training. Surprisingly, the gradient norm and evaluation loss become NaN after a few steps. I've seen a blog post from Unsloth mentioning that NaNs may appear due to a bug, but it also says the bug has now been fixed by Hugging Face and Unsloth (here, under the Llama-3 Quirks section). So I not only updated Unsloth and Hugging Face but also added the "pad_token" mentioned in the post. Despite these attempts, the NaN problem still persists. Is there something else that I'm missing? Can someone help me out with this?

Here's the code snippet of how I'm loading the model:

  model, tokenizer = FastLanguageModel.from_pretrained(
      model_name = model_name,
      max_seq_length = args.max_seq_length,
      dtype = compute_dtype,
      load_in_4bit = args.use_4bit,
      # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
  )
  tokenizer.add_special_tokens({"pad_token": "<|reserved_special_token_0|>"})
  model.config.pad_token_id = tokenizer.pad_token_id # updating model config
  tokenizer.padding_side = 'right'
  model = FastLanguageModel.get_peft_model(
      model,
      r = args.lora_r, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
      target_modules = lora_modules,
      lora_alpha = args.lora_alpha,
      lora_dropout = args.lora_dropout, # Supports any, but = 0 is optimized
      bias = args.lora_bias,    # Supports any, but = "none" is optimized
      # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
      use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
      random_state = 3407,
      use_rslora = False,  # We support rank stabilized LoRA
      loftq_config = None, # And LoftQ
  )
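
A quick diagnostic that can be run right after this snippet (a minimal sketch reusing the model and tokenizer defined above) is to check whether the embedding row of the chosen reserved token looks untrained, which is the quirk the blog post describes:

  import torch

  # Sketch: compare the pad token's embedding row norm against the average row norm.
  # A near-zero row would suggest the token's embedding was never trained.
  pad_id = tokenizer.pad_token_id
  embeddings = model.get_input_embeddings().weight
  row_norms = embeddings.norm(dim=-1).float()
  print("pad token id:", pad_id)
  print("pad row norm:", row_norms[pad_id].item())
  print("mean row norm:", row_norms.mean().item())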

Following is the training code:

  training_arguments = TrainingArguments(
      output_dir=output_dir,
      num_train_epochs=args.epochs,
      per_device_train_batch_size=args.per_device_train_batch_size,
      per_device_eval_batch_size=args.per_device_eval_batch_size,
      gradient_accumulation_steps=args.gradient_accumulation_steps,
      optim=args.optimizer,
      save_steps=args.save_steps,
      logging_steps=args.logging_steps,
      learning_rate=args.learning_rate,
      weight_decay=args.weight_decay,
      fp16=fp16,
      bf16=bf16,
      max_grad_norm=args.max_grad_norm,
      max_steps=args.max_steps,
      warmup_ratio=args.warmup_ratio,
      # group_by_length=args.group_by_length,
      lr_scheduler_type=args.lr_scheduler_type,
      logging_strategy="steps",
      report_to="tensorboard",
      evaluation_strategy="steps",
      # ddp_find_unused_parameters=False,
  )
  trainer = SFTTrainer(
      model=model,
      train_dataset=train_dataset,
      eval_dataset=eval_dataset,
      dataset_text_field="chats",
      max_seq_length=args.max_seq_length,
      tokenizer=tokenizer,
      args=training_arguments,
      packing=packing
  )
danielhanchen commented 5 months ago

Are you training on embed_tokens and lm_head?

mano3-1 commented 5 months ago

Hi @danielhanchen,

Thank you for your response. I'm unsure about the inner workings of get_peft_model in Unsloth, but assuming it behaves like other PEFT methods, it should freeze the base model, including the embedding matrix, correct? Consequently, I believe my scripts are only training the LoRA parameters. I attempted to use Unsloth's fix_untrained_tokens, but it didn't work out for me. Additionally, I noticed that Unsloth's blog mentions the llama-3-8b base model, whereas I'm using llama-3-8b-instruct. The Instruct model's reserved tokens shouldn't cause any issues since, unlike the base model's, they are fine-tuned, right?
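
One quick way to confirm this (a minimal sketch, assuming model is the PEFT-wrapped model from the snippet above) is to list which parameters actually have gradients enabled:

  # Sketch: report whether embed_tokens or lm_head ended up trainable.
  trainable = [name for name, param in model.named_parameters() if param.requires_grad]
  print(len(trainable), "trainable parameter tensors")
  print("embed_tokens trainable:", any("embed_tokens" in name for name in trainable))
  print("lm_head trainable:", any("lm_head" in name for name in trainable))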

lapp0 commented 5 months ago

@mano3-1 what does the traceback say if you run

with torch.autograd.detect_anomaly():
    trainer.train()
mano3-1 commented 5 months ago

Hi @lapp0, Here is the traceback:

Traceback (most recent call last):
  File "/home/ubuntu/LLMOps/train/train.py", line 501, in <module>
    main()
  File "/home/ubuntu/LLMOps/train/train.py", line 497, in main
    training_function(args)
  File "/home/ubuntu/LLMOps/train/train.py", line 445, in training_function
    trainer.train()
  File "/opt/conda/envs/LLMOps/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 361, in train
    output = super().train(*args, **kwargs)
  File "/opt/conda/envs/LLMOps/lib/python3.10/site-packages/transformers/trainer.py", line 1859, in train
    return inner_training_loop(
  File "<string>", line 361, in _fast_inner_training_loop
  File "/opt/conda/envs/LLMOps/lib/python3.10/site-packages/transformers/trainer.py", line 3147, in training_step
    self.accelerator.backward(loss)
  File "/opt/conda/envs/LLMOps/lib/python3.10/site-packages/accelerate/accelerator.py", line 2013, in backward
    loss.backward(**kwargs)
  File "/opt/conda/envs/LLMOps/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/opt/conda/envs/LLMOps/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'Fast_CrossEntropyLossBackward' returned nan values in its 0th output.
lapp0 commented 5 months ago

I'm running into issues with back-propagation in Unsloth as well, though I'm using a custom loss function and Mistral instead of Llama-3. It works fine with AutoModelForCausalLM & get_peft_model, but with Unsloth I get

RuntimeError: Function 'LoRA_MLPBackward' returned nan values in its 0th output.

  File "<string>", line 361, in _fast_inner_training_loop
  File "/opt/conda/lib/python3.10/site-packages/trl/trainer/policy_trainer_base.py", line 549, in training_step
    return super().training_step(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 3147, in training_step
    self.accelerator.backward(loss)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/accelerator.py", line 2013, in backward
    loss.backward(**kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward
    torch.autograd.backward(
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 301, in apply
    return user_fn(self, *args)
  File "/opt/conda/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 142, in decorate_bwd
    return bwd(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/unsloth/models/_utils.py", line 348, in backward
    torch.autograd.backward(output, dY)
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'LoRA_MLPBackward' returned nan values in its 0th output.

I'd be interested in the cause of your issue; perhaps it is the same as mine. If I figure anything out with mine, I'll let you know.

mano3-1 commented 5 months ago

Hi @lapp0, it seems like we are both facing a similar issue. I tried removing Unsloth from my code and training with plain Hugging Face utilities, and that went well. But I would really like to keep Unsloth in the loop because the memory savings are significant. Do you think this is from Unsloth's side, or is it something caused by our scripts?

lapp0 commented 5 months ago

I'm not sure. The backward step where yours fails is in a different layer of the model than mine, but the only thing our scripts have in common is Unsloth.

How about some debug details?

1) Could you please share a full reproduction script that would allow me and Daniel to run it locally? This means the whole source file along with your run command.

2) What is the output of pip3 freeze?

mano3-1 commented 5 months ago

Here is the pip freeze: requirements.txt

Here is the full training script: link

This is how I trigger the training scripts: python train.py --max_seq_length 4000 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --sm_train_dir "/opt/ml/processing/train" --sm_validation_dir "/opt/ml/processing/test" --hf_token <yourtoken> --run_experiment False --lora_r 32 --lora_alpha 8 --unsloth True --logging_steps 8 --save_steps 8

You can set hf_token to the string "None" if you are loading Unsloth models, I guess.

lapp0 commented 5 months ago

requirements.txt isn't the same as pip freeze. pip3 freeze will list the versions of all installed packages.

danielhanchen commented 5 months ago

Oh no, sorry guys - I will take a look

lapp0 commented 5 months ago

Thanks @danielhanchen

Here is my reproduction script as well, run on a 4090 with CUDA 12.1. @mano3-1 has a standard SFT script, so his is probably worth looking at first.

"""
INSTALL DEPS:
pip install torch==2.3.0
pip install transformers tensorboardX bitsandbytes peft accelerate flash_attn --upgrade
pip install "unsloth[cu121] @ git+https://github.com/unslothai/unsloth.git"
pip install "git+https://github.com/lapp0/trl.git@ppov2"
pip install -U xformers --index-url https://download.pytorch.org/whl/cu121
pip install torch==2.3.0  # ensure correct torch still used
"""
import multiprocessing

from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    PreTrainedModel,
    DataCollatorWithPadding,
    BitsAndBytesConfig
)
import torch
from trl.trainer.ppov2_trainer import PPOConfig, PPOTrainer, PolicyAndValueWrapper
from peft import get_peft_model, LoraConfig

base_model_uri = "HuggingFaceH4/mistral-7b-sft-beta"
reward_model_uri = "weqweasdas/RM-Mistral-7B"

################
# Model & Tokenizer
################
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(
    base_model_uri,
    padding_side="left",
    trust_remote_code=True,
)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

reward_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(
    reward_model_uri,
    num_labels=1,
    quantization_config=quantization_config,
    attn_implementation="flash_attention_2",
)

value_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(
    reward_model_uri,
    num_labels=1,
    quantization_config=quantization_config,
    attn_implementation="flash_attention_2",
)
value_model = get_peft_model(
    value_model,
    LoraConfig(
        r=16,
        lora_alpha=64,
        lora_dropout=0,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )
)

from unsloth import FastLanguageModel
base_policy, _ = FastLanguageModel.from_pretrained(
    model_name=base_model_uri,
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)
base_policy = FastLanguageModel.get_peft_model(
    base_policy,
    r=16,
    lora_alpha=64,
    lora_dropout=0,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    bias="none",
    use_gradient_checkpointing="unsloth",
    max_seq_length=2048
)
"""
# Creating base_policy like this works, unsloth doesn't
from transformers import AutoModelForCausalLM
base_policy = AutoModelForCausalLM.from_pretrained(
    base_model_uri,
    num_labels=1,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
    attn_implementation="flash_attention_2",
)
lora_config = LoraConfig(
    r=16,
    lora_alpha=64,
    lora_dropout=0,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
base_policy = get_peft_model(base_policy, lora_config)
"""

# trl.trainer.peft_module_casting_to_bf16(base_model)

base_model = PolicyAndValueWrapper(base_policy, value_model)

################
# Dataset
################
raw_datasets = load_dataset("HuggingFaceH4/ultrachat_200k")
train_dataset = raw_datasets["train_sft"]
eval_dataset = raw_datasets["test_sft"]

def prepare_dataset(dataset, tokenizer):
    """pre-tokenize the dataset before training; only collate during training"""

    def tokenize(element):
        input_ids = tokenizer.apply_chat_template(
            element["messages"][:1],
            padding=False,
            add_generation_prompt=True,
        )
        return {"input_ids": input_ids, "lengths": len(input_ids)}

    return dataset.map(
        tokenize,
        remove_columns=dataset.column_names,
        num_proc=multiprocessing.cpu_count(),
        load_from_cache_file=False,
    )

train_dataset = prepare_dataset(train_dataset, tokenizer).filter(lambda x: x["lengths"] <= 1024)
eval_dataset = prepare_dataset(eval_dataset, tokenizer).filter(lambda x: x["lengths"] <= 1024)

collator = DataCollatorWithPadding(tokenizer)

###############
# Training
################
config = PPOConfig(
    output_dir="./ppov2_experiment_v2",
    report_to="tensorboard",
    update_generation_steps=16,
    gradient_accumulation_steps=8,
    per_device_train_batch_size=2,
    push_to_hub=True,
    hub_model_id="lapp0/ppov2_experiment_v2",
    logging_steps=1,
    learning_rate=3e-6,
    save_steps=4,
    non_eos_penalty=True,
    response_length=128,
    optim="paged_adamw_8bit",
    bf16=True,
    fp16=False,
    truncate_token="eos",
    gradient_checkpointing=True,
    # gradient_checkpointing_kwargs={"use_reentrant": False},
)

trainer = PPOTrainer(
    model=base_model,
    args=config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    reward_model=reward_model,
    data_collator=collator,
    tokenizer=tokenizer,
)
with torch.autograd.detect_anomaly():
    trainer.train()

trainer.generate_completions()

pip3 freeze:

``` accelerate==0.30.0 aiohttp==3.9.5 aiosignal==1.3.1 anaconda-anon-usage @ file:///croot/anaconda-anon-usage_1710965072196/work anyio==4.3.0 archspec @ file:///croot/archspec_1709217642129/work argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 arrow==1.3.0 asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work astunparse==1.6.3 async-lru==2.0.4 async-timeout==4.0.3 attrs @ file:///croot/attrs_1695717823297/work Babel==2.15.0 bash_kernel==0.9.3 beautifulsoup4 @ file:///croot/beautifulsoup4-split_1681493039619/work bitsandbytes==0.43.1 bleach==6.1.0 boltons @ file:///croot/boltons_1677628692245/work Brotli @ file:///tmp/abs_ecyw11_7ze/croots/recipe/brotli-split_1659616059936/work certifi @ file:///croot/certifi_1707229174982/work/certifi cffi @ file:///croot/cffi_1700254295673/work chardet @ file:///home/builder/ci_310/chardet_1640804867535/work charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work click @ file:///croot/click_1698129812380/work comm==0.2.2 conda @ file:///croot/conda_1689269889729/work conda-build @ file:///croot/conda-build_1710789183177/work conda-content-trust @ file:///croot/conda-content-trust_1693490622020/work conda-libmamba-solver @ file:///croot/conda-libmamba-solver_1691418897561/work/src conda-package-handling @ file:///croot/conda-package-handling_1690999929514/work conda_index @ file:///croot/conda-index_1706633791028/work conda_package_streaming @ file:///croot/conda-package-streaming_1690987966409/work cryptography @ file:///croot/cryptography_1710350347627/work datasets==2.19.1 debugpy==1.8.1 decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work defusedxml==0.7.1 dill==0.3.8 distro @ file:///croot/distro_1701455004953/work dnspython==2.6.1 docstring_parser==0.16 einops==0.8.0 exceptiongroup @ file:///croot/exceptiongroup_1706031385326/work executing @ file:///opt/conda/conda-bld/executing_1646925071911/work expecttest==0.2.1 fastjsonschema==2.19.1 filelock @ file:///croot/filelock_1700591183607/work flash-attn==2.5.8 fqdn==1.5.1 frozenlist==1.4.1 fsspec==2024.3.1 gmpy2 @ file:///tmp/build/80754af9/gmpy2_1645455533097/work h11==0.14.0 httpcore==1.0.5 httpx==0.27.0 huggingface-hub==0.23.0 hypothesis==6.100.1 idna @ file:///croot/idna_1666125576474/work iniconfig==2.0.0 ipykernel==6.29.4 ipython @ file:///croot/ipython_1704833016303/work ipywidgets==8.1.2 isoduration==20.11.0 jedi @ file:///tmp/build/80754af9/jedi_1644315229345/work Jinja2 @ file:///croot/jinja2_1706733616596/work json5==0.9.25 jsonpatch @ file:///croot/jsonpatch_1710807507480/work jsonpointer==2.1 jsonschema @ file:///croot/jsonschema_1699041609003/work jsonschema-specifications @ file:///croot/jsonschema-specifications_1699032386549/work jupyter==1.0.0 jupyter-archive==3.4.0 jupyter-console==6.6.3 jupyter-events==0.10.0 jupyter-http-over-ws==0.0.8 jupyter-lsp==2.2.5 jupyter_client==8.6.1 jupyter_core==5.7.2 jupyter_server==2.14.0 jupyter_server_terminals==0.5.3 jupyterlab==4.1.8 jupyterlab_pygments==0.3.0 jupyterlab_server==2.27.1 jupyterlab_widgets==3.0.10 lark==1.1.9 libarchive-c @ file:///tmp/build/80754af9/python-libarchive-c_1617780486945/work libmambapy @ file:///croot/mamba-split_1712091911343/work/libmambapy markdown-it-py==3.0.0 MarkupSafe @ file:///croot/markupsafe_1704205993651/work matplotlib-inline @ file:///opt/conda/conda-bld/matplotlib-inline_1662014470464/work mdurl==0.1.2 menuinst @ file:///croot/menuinst_1706732933928/work mistune==3.0.2 mkl-fft @ file:///croot/mkl_fft_1695058164594/work mkl-random @ 
file:///croot/mkl_random_1695059800811/work mkl-service==2.4.0 more-itertools @ file:///croot/more-itertools_1700662129964/work mpmath @ file:///croot/mpmath_1690848262763/work multidict==6.0.5 multiprocess==0.70.16 nbclient==0.10.0 nbconvert==7.16.4 nbformat==5.10.4 nbzip==0.1.0 nest-asyncio==1.6.0 networkx @ file:///croot/networkx_1690561992265/work ninja==1.11.1.1 notebook==7.1.3 notebook_shim==0.2.4 numpy @ file:///croot/numpy_and_numpy_base_1708638617955/work/dist/numpy-1.26.4-cp310-cp310-linux_x86_64.whl#sha256=d8cd837ed43e87f77e6efaa08e8de927ca030a1c9c5d04624432d6fb9a74a5ee nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==8.9.2.26 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.20.5 nvidia-nvjitlink-cu12==12.4.127 nvidia-nvtx-cu12==12.1.105 optree==0.11.0 overrides==7.7.0 packaging @ file:///croot/packaging_1710807400464/work pandas==2.2.2 pandocfilters==1.5.1 parso @ file:///opt/conda/conda-bld/parso_1641458642106/work peft==0.10.0 pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work pillow @ file:///croot/pillow_1707233021655/work pkginfo @ file:///croot/pkginfo_1679431160147/work platformdirs @ file:///croot/platformdirs_1692205439124/work pluggy==1.5.0 prometheus_client==0.20.0 prompt-toolkit @ file:///croot/prompt-toolkit_1704404351921/work protobuf==3.20.3 psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work pyarrow==16.0.0 pyarrow-hotfix==0.6 pycosat @ file:///croot/pycosat_1696536503704/work pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work Pygments @ file:///croot/pygments_1684279966437/work pyOpenSSL @ file:///croot/pyopenssl_1708380408460/work PySocks @ file:///home/builder/ci_310/pysocks_1640793678128/work pytest==8.2.0 python-dateutil==2.9.0.post0 python-etcd==0.4.5 python-json-logger==2.0.7 pytz @ file:///croot/pytz_1695131579487/work PyYAML @ file:///croot/pyyaml_1698096049011/work pyzmq==26.0.3 qtconsole==5.5.2 QtPy==2.4.1 referencing @ file:///croot/referencing_1699012038513/work regex==2024.4.28 requests @ file:///croot/requests_1707355572290/work rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 rich==13.7.1 rpds-py @ file:///croot/rpds-py_1698945930462/work ruamel.yaml @ file:///croot/ruamel.yaml_1666304550667/work ruamel.yaml.clib @ file:///croot/ruamel.yaml.clib_1666302247304/work safetensors==0.4.3 Send2Trash==1.8.3 sentencepiece==0.2.0 shtab==1.7.1 six @ file:///tmp/build/80754af9/six_1644875935023/work sniffio==1.3.1 sortedcontainers==2.4.0 soupsieve @ file:///croot/soupsieve_1696347547217/work stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work sympy @ file:///croot/sympy_1701397643339/work tensorboardX==2.6.2.2 terminado==0.18.1 tinycss2==1.3.0 tokenizers==0.19.1 tomli @ file:///opt/conda/conda-bld/tomli_1657175507142/work toolz @ file:///croot/toolz_1667464077321/work torch==2.3.0+cu121 torchaudio==2.3.0 torchelastic==0.2.2 torchvision==0.18.0 tornado==6.4 tqdm @ file:///croot/tqdm_1679561862951/work traitlets @ file:///croot/traitlets_1671143879854/work transformers==4.40.2 triton==2.3.0 trl @ git+https://github.com/lapp0/trl.git@649aff0d142987b9e6a9ecea7ece562074d3f7c6 truststore @ 
file:///croot/truststore_1695244293384/work types-dataclasses==0.6.6 types-python-dateutil==2.9.0.20240316 typing_extensions==4.11.0 tyro==0.8.3 tzdata==2024.1 unsloth @ git+https://github.com/unslothai/unsloth.git@a93a885c286934c9c7467324054ca3f9d526a2bd uri-template==1.3.0 urllib3 @ file:///croot/urllib3_1707770551213/work wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work webcolors==1.13 webencodings==0.5.1 websocket-client==1.8.0 widgetsnbextension==4.0.10 xformers==0.0.26.post1 xxhash==3.4.1 yarl==1.9.4 zstandard @ file:///croot/zstandard_1677013143055/work ```
mano3-1 commented 5 months ago

Hi @lapp0, although I named it requirements.txt, I generated it with pip freeze. If you check the file, you will find the versions of all the libraries.

lapp0 commented 5 months ago

Sorry about my confusion @mano3-1

I reviewed and compared our installed packages. Nothing noteworthy stands out in the shared dependencies, other than that the issue might be related to the use of xformers. I'll experiment with this later.

``` markdown-it-py==3.0.0 nvidia-curand-cu12==10.3.2.106 beautifulsoup4 parso nvidia-cudnn-cu12==8.9.2.26 tokenizers==0.19.1 unsloth Pygments ptyprocess nvidia-cuda-runtime-cu12==12.1.105 typing_extensions==4.11.0 matplotlib-inline decorator pure-eval nvidia-cusparse-cu12==12.1.0.106 PyYAML tomli sentencepiece==0.2.0 six async-timeout==4.0.3 prompt-toolkit jsonschema soupsieve referencing nvidia-cufft-cu12==11.0.2.54 PySocks traitlets mdurl==0.1.2 fsspec==2024.3.1 Brotli xxhash==3.4.1 nvidia-cublas-cu12==12.1.3.1 tyro==0.8.3 platformdirs packaging pycparser cffi protobuf==3.20.3 pyarrow-hotfix==0.6 nvidia-cusolver-cu12==11.4.5.107 wcwidth nvidia-cuda-nvrtc-cu12==12.1.105 asttokens jedi safetensors==0.4.3 exceptiongroup aiosignal==1.3.1 nvidia-nvjitlink-cu12==12.4.127 nvidia-cuda-cupti-cu12==12.1.105 bitsandbytes==0.43.1 rpds-py jsonschema-specifications pexpect nvidia-nvtx-cu12==12.1.105 peft==0.10.0 ```
danielhanchen commented 5 months ago

Thanks for the code repro - I'll test this out - sorry about the issue again!

DementedWeasel1971 commented 5 months ago

Also facing the same issue while using Colab and the standard notebook in the Unsloth folder. Thought I'd add that.

mano3-1 commented 5 months ago

Hey, I'm curious whether anyone has figured out a fix for this?

danielhanchen commented 5 months ago

Sorry guys, I just started debugging this. I also updated Unsloth, so it might be better now (hopefully). For local installations, please update Unsloth via:

pip uninstall unsloth -y
pip install --upgrade --force-reinstall --no-cache-dir git+https://github.com/unslothai/unsloth.git

Colab / Kaggle should be fine with a restart.

@DementedWeasel1971 When you said the Colab notebook we provided broke, could you point to exactly which one? Thanks.

@mano3-1 Extremely weird actually - I reran Colab with Instruct and it seems fine - would you be able to run just the conversational notebook for Llama-3 here: https://colab.research.google.com/drive/1XamvWYinY6FOSX9GLvnqSjjsNflxdhNc?usp=sharing

@lapp0 I'm currently running your PPO example here: https://colab.research.google.com/drive/1fgJv0eKlRKexOl2RqcxoiZ-HhGrdNWQW?usp=sharing (will wait for it to complete)

lapp0 commented 5 months ago

Thanks so much for looking into it! Unfortunately, I'm still getting NaN on the first training step:

{'loss': 1.9125, 'grad_norm': nan, 'learning_rate': 2.9999207167208437e-06, 'objective/kl': 0.0, 'objective/entropy': 99.8125, 'objective/non_score_reward': 0.0, 'objective/rlhf_reward': -0.58380126953125, 'objective/scores': -0.58380126953125, 'policy/approxkl_avg': 0.0, 'policy/clipfrac_avg': 0.0, 'loss/policy_avg': -6.116561479529992e-09, 'loss/value_avg': 19.12525177001953, 'val/clipfrac_avg': 0.0011788890697062016, 'val/num_eos_tokens': 0.5, 'timer/training_step': 2.293384313583374, 'epoch': 0.0}

Please let me know if there's any other debug details that would help.

Also fyi, to speed up debugging you can set update_generation_steps=1.

Edit:

I pushed a bad commit to my branch; I've reverted the broken change. It should be good to try again with the head of https://github.com/lapp0/trl.git@ppov2.

mano3-1 commented 5 months ago

Hi, I followed @danielhanchen's notebook and compared its parameters with mine. When I change the optimizer from paged_adamw_32bit to adamw_8bit, the NaN issues no longer come up.

@lapp0 I can see paged AdamW in your script; perhaps change it to adamw_8bit and try again.
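
For reference, the only difference on my side was the optim argument; the rest of my TrainingArguments come from args as in the earlier snippet (a sketch with placeholder values for everything except optim):

  from transformers import TrainingArguments

  # Sketch: identical to the earlier TrainingArguments except for `optim`.
  # Values below are placeholders for what I normally pass in via `args`.
  training_arguments = TrainingArguments(
      output_dir="outputs",            # placeholder
      per_device_train_batch_size=1,
      logging_steps=8,
      save_steps=8,
      optim="adamw_8bit",              # was "paged_adamw_32bit", which produced NaNs for me
      # fp16/bf16, scheduler, etc. unchanged from the earlier snippet
  )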

danielhanchen commented 5 months ago

@mano3-1 I changed it, but weirdly it still produces NaNs. I also tried installing the February release of Unsloth with torch==2.1.1, and it's not working either.

@lapp0 Do you know when it was last working? (which Unsloth version)

lapp0 commented 5 months ago

@mano3-1 I just tried using a non-paged optimizer as you suggested, but unfortunately it didn't resolve the issue.

@danielhanchen this is a brand new trainer adapted from https://github.com/huggingface/trl/pull/1540 based on https://arxiv.org/pdf/2403.17031

It has never been run successfully with Unsloth before, but it runs with PEFT + BnB. Shouldn't the forward and backward passes be identical to PEFT + BnB, or are there steps where precision loss occurs?

@mano3-1 @danielhanchen It's interesting that mano isn't getting NaN but you are. Perhaps there is something different between your environments?

Here's mine for context:

``` accelerate==0.30.1 aiohttp==3.9.5 aiosignal==1.3.1 anaconda-anon-usage @ file:///croot/anaconda-anon-usage_1710965072196/work archspec @ file:///croot/archspec_1709217642129/work asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work astunparse==1.6.3 async-timeout==4.0.3 attrs @ file:///croot/attrs_1695717823297/work beautifulsoup4 @ file:///croot/beautifulsoup4-split_1681493039619/work bitsandbytes==0.43.1 boltons @ file:///croot/boltons_1677628692245/work Brotli @ file:///tmp/abs_ecyw11_7ze/croots/recipe/brotli-split_1659616059936/work certifi @ file:///croot/certifi_1707229174982/work/certifi cffi @ file:///croot/cffi_1700254295673/work chardet @ file:///home/builder/ci_310/chardet_1640804867535/work charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work click @ file:///croot/click_1698129812380/work conda @ file:///croot/conda_1689269889729/work conda-build @ file:///croot/conda-build_1710789183177/work conda-content-trust @ file:///croot/conda-content-trust_1693490622020/work conda-libmamba-solver @ file:///croot/conda-libmamba-solver_1691418897561/work/src conda-package-handling @ file:///croot/conda-package-handling_1690999929514/work conda_index @ file:///croot/conda-index_1706633791028/work conda_package_streaming @ file:///croot/conda-package-streaming_1690987966409/work cryptography @ file:///croot/cryptography_1710350347627/work datasets==2.19.1 decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work dill==0.3.8 distro @ file:///croot/distro_1701455004953/work dnspython==2.6.1 docstring_parser==0.16 einops==0.8.0 exceptiongroup @ file:///croot/exceptiongroup_1706031385326/work executing @ file:///opt/conda/conda-bld/executing_1646925071911/work expecttest==0.2.1 filelock @ file:///croot/filelock_1700591183607/work flash-attn==2.5.7 frozenlist==1.4.1 fsspec==2024.3.1 gmpy2 @ file:///tmp/build/80754af9/gmpy2_1645455533097/work huggingface-hub==0.23.0 hypothesis==6.100.1 idna @ file:///croot/idna_1666125576474/work ipython @ file:///croot/ipython_1704833016303/work jedi @ file:///tmp/build/80754af9/jedi_1644315229345/work Jinja2 @ file:///croot/jinja2_1706733616596/work jsonpatch @ file:///croot/jsonpatch_1710807507480/work jsonpointer==2.1 jsonschema @ file:///croot/jsonschema_1699041609003/work jsonschema-specifications @ file:///croot/jsonschema-specifications_1699032386549/work lark==1.1.9 libarchive-c @ file:///tmp/build/80754af9/python-libarchive-c_1617780486945/work libmambapy @ file:///croot/mamba-split_1712091911343/work/libmambapy markdown-it-py==3.0.0 MarkupSafe @ file:///croot/markupsafe_1704205993651/work matplotlib-inline @ file:///opt/conda/conda-bld/matplotlib-inline_1662014470464/work mdurl==0.1.2 menuinst @ file:///croot/menuinst_1706732933928/work mkl-fft @ file:///croot/mkl_fft_1695058164594/work mkl-random @ file:///croot/mkl_random_1695059800811/work mkl-service==2.4.0 more-itertools @ file:///croot/more-itertools_1700662129964/work mpmath @ file:///croot/mpmath_1690848262763/work multidict==6.0.5 multiprocess==0.70.16 networkx @ file:///croot/networkx_1690561992265/work ninja==1.11.1.1 numpy @ file:///croot/numpy_and_numpy_base_1708638617955/work/dist/numpy-1.26.4-cp310-cp310-linux_x86_64.whl#sha256=d8cd837ed43e87f77e6efaa08e8de927ca030a1c9c5d04624432d6fb9a74a5ee nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==8.9.2.26 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 
nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.18.1 nvidia-nvjitlink-cu12==12.4.127 nvidia-nvtx-cu12==12.1.105 optree==0.11.0 packaging @ file:///croot/packaging_1710807400464/work pandas==2.2.2 parso @ file:///opt/conda/conda-bld/parso_1641458642106/work peft==0.11.1 pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work pillow @ file:///croot/pillow_1707233021655/work pkginfo @ file:///croot/pkginfo_1679431160147/work platformdirs @ file:///croot/platformdirs_1692205439124/work pluggy @ file:///tmp/build/80754af9/pluggy_1648024709248/work prompt-toolkit @ file:///croot/prompt-toolkit_1704404351921/work protobuf==3.20.3 psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work pyarrow==16.1.0 pyarrow-hotfix==0.6 pycosat @ file:///croot/pycosat_1696536503704/work pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work Pygments @ file:///croot/pygments_1684279966437/work pyOpenSSL @ file:///croot/pyopenssl_1708380408460/work PySocks @ file:///home/builder/ci_310/pysocks_1640793678128/work python-dateutil==2.9.0.post0 python-etcd==0.4.5 pytz @ file:///croot/pytz_1695131579487/work PyYAML @ file:///croot/pyyaml_1698096049011/work referencing @ file:///croot/referencing_1699012038513/work regex==2024.5.15 requests @ file:///croot/requests_1707355572290/work rich==13.7.1 rpds-py @ file:///croot/rpds-py_1698945930462/work ruamel.yaml @ file:///croot/ruamel.yaml_1666304550667/work ruamel.yaml.clib @ file:///croot/ruamel.yaml.clib_1666302247304/work safetensors==0.4.3 sentencepiece==0.2.0 shtab==1.7.1 six @ file:///tmp/build/80754af9/six_1644875935023/work sortedcontainers==2.4.0 soupsieve @ file:///croot/soupsieve_1696347547217/work stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work sympy @ file:///croot/sympy_1701397643339/work tensorboardX==2.6.2.2 tokenizers==0.19.1 tomli @ file:///opt/conda/conda-bld/tomli_1657175507142/work toolz @ file:///croot/toolz_1667464077321/work torch==2.1.0 torchaudio==2.3.0 torchelastic==0.2.2 torchvision==0.18.0 tqdm @ file:///croot/tqdm_1679561862951/work traitlets @ file:///croot/traitlets_1671143879854/work transformers==4.40.2 triton==2.1.0 trl @ git+https://github.com/lapp0/trl.git@3e681b6756d4f73283100d51f2ea20578d4f969a truststore @ file:///croot/truststore_1695244293384/work types-dataclasses==0.6.6 typing_extensions==4.11.0 tyro==0.8.4 tzdata==2024.1 unsloth @ git+https://github.com/unslothai/unsloth.git@d1d57ff99079d0ada0fde31cb67c637dd7ac27cc urllib3 @ file:///croot/urllib3_1707770551213/work wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp310-cp310-manylinux2014_x86_64.whl#sha256=44c33373976705b1f3c5729a5ed24165b21536e3d3eedc58dd60ce68d3603f89 xxhash==3.4.1 yarl==1.9.4 zstandard @ file:///croot/zstandard_1677013143055/work ```
danielhanchen commented 5 months ago

@mano3-1 Wait, if it works for you - weird. It might be a paged-optimizer issue.

@lapp0 Hmm, very weird indeed - yes, I only edit the forward and backward passes, but I'm assuming the wrapping mechanisms are causing issues, i.e. maybe PolicyAndValueWrapper - the best way is to inject NaN gradient checks throughout the entire codebase to pinpoint the issue.

If I had to guess, the cross-entropy loss part is causing issues, since I manually shift the labels and append things there.

I also turned off Unsloth gradient checkpointing, and it still doesn't work.
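
To make that cross-entropy guess concrete, here is one way shifted-label cross entropy can go NaN (a minimal sketch in plain PyTorch, not Unsloth's fused kernel): if every label in a batch ends up masked, the mean is taken over zero valid tokens.

  import torch
  import torch.nn.functional as F

  # Sketch: all labels masked with ignore_index -> mean over zero tokens -> NaN loss,
  # and the NaN then propagates into every gradient during backward.
  logits = torch.randn(2, 5, 32, requires_grad=True)   # (batch, seq, vocab)
  labels = torch.full((2, 5), -100)                     # every position ignored
  shift_logits = logits[:, :-1, :].reshape(-1, 32)
  shift_labels = labels[:, 1:].reshape(-1)
  loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
  print(loss)       # tensor(nan, grad_fn=...)
  loss.backward()   # gradients become NaN as well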

lapp0 commented 5 months ago

@danielhanchen One other thing that hasn't been tried or tested by the Unsloth community is interleaving training and generation, which this script does. I have a feeling that is a possible culprit. I'll experiment with training only on pre-generated samples when I get a chance.

Also, I don't think PolicyAndValueWrapper is the issue: I have another variant without any value model, with mostly similar code, and it also shows a NaN grad_norm.

For NaN gradient checks, I'm already running with:

with torch.autograd.detect_anomaly():
    trainer.train()

Do you know a good way to inject hooks that apply more extensive and detailed NaN checks?
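
One approach I'm considering (a rough sketch; I haven't verified that full backward hooks interact cleanly with Unsloth's custom autograd functions) is to register a hook on every module and fail on the first non-finite gradient:

  import torch

  def install_nan_hooks(model):
      """Raise on the first module whose backward pass produces a non-finite gradient."""
      def make_hook(name):
          def hook(module, grad_input, grad_output):
              for label, grads in (("grad_input", grad_input), ("grad_output", grad_output)):
                  for g in grads:
                      if g is not None and not torch.isfinite(g).all():
                          raise RuntimeError(f"non-finite {label} in {name}")
          return hook
      for name, module in model.named_modules():
          module.register_full_backward_hook(make_hook(name))

  # install_nan_hooks(base_model)  # then run trainer.train() as usual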

lapp0 commented 5 months ago

Edit: I was mistaken about the source of the problem. However, I did discover that if my per_device_batch_size is 1, I don't get the error. I'm not sure what the reason might be.

danielhanchen commented 5 months ago

@lapp0 Apologies for the delay! OK, weird - so it might be something related to batching. Do you know if generation also uses per_device_batch_size internally?

lapp0 commented 5 months ago

@danielhanchen I'm now pretty confident that the issue relates to padding. The error doesn't occur with batch size N > 1 if the sequences are all the same length (no padding). The code fills the log-probs at positions that aren't attended to with a sentinel value:

INVALID_LOGPROB = 1.0

...

    def forward(self, model, query_responses):
        attention_mask = query_responses != self.tokenizer.pad_token_id
        position_ids = attention_mask.cumsum(1) - attention_mask.long()
        input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
        return model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
            output_hidden_states=True,
            use_cache=False,
        )

...

        output = self.forward(self.model, query_responses)
        output.logits.mean().backward(retain_graph=True)
        logits = output.logits[:, context_length - 1: -1]
        logits /= self.args.temperature + 1e-7
        new_all_logprobs = F.log_softmax(logits, dim=-1)
        new_logprobs = torch.gather(new_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1)
        new_logprobs = torch.masked_fill(
            new_logprobs, padding_mask, INVALID_LOGPROB
        )

I'm wondering whether Unsloth includes logits that aren't covered by the attention mask in the backward pass?

I'll do some more experimentation.
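
The experiment I have in mind (a rough sketch; output, query_responses, and self.tokenizer refer to the forward method above) is to check the padded-position logits and then backpropagate only through the attended ones:

        attention_mask = query_responses != self.tokenizer.pad_token_id

        # Are the logits at padded positions even finite after the forward pass?
        pad_logits = output.logits[~attention_mask]
        print("non-finite logits at padded positions:", (~torch.isfinite(pad_logits)).sum().item())

        # If padded positions are the culprit, backpropagating only through attended
        # logits (instead of output.logits.mean() as above) should avoid the NaNs.
        attended_logits = output.logits[attention_mask]
        attended_logits.mean().backward(retain_graph=True)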

lapp0 commented 5 months ago

I found the issue and created a reproduction script! https://github.com/unslothai/unsloth/issues/533

danielhanchen commented 5 months ago

Thanks for the investigation - I'll take a look!