LAION-AI / Open-Assistant

OpenAssistant is a chat-based assistant that understands tasks, can interact with third-party systems, and retrieve information dynamically to do so.
https://open-assistant.io
Apache License 2.0

Supervised fine-tuning: "RuntimeError: expected scalar type Half but found Float" during evaluation #3569

Open · theophilegervet opened this issue 1 year ago

theophilegervet commented 1 year ago

While running supervised fine-tuning with

python trainer_sft.py --configs lora-llama-13b webgpt_dataset_only

and the following config

lora-llama-13b:
  dtype: fp16
  log_dir: "llama_lora_log_13b"
  learning_rate: 5e-5
  model_name: openlm-research/open_llama_13b
  output_dir: llama_model_13b_lora
  weight_decay: 0.0
  max_length: 2048
  warmup_steps: 300
  gradient_checkpointing: true
  gradient_accumulation_steps: 1
  per_device_train_batch_size: 6
  per_device_eval_batch_size: 1
  eval_steps: 500
  num_train_epochs: 12
  save_total_limit: 2
  save_strategy: epoch
  use_flash_attention: true
  residual_dropout: 0.0
  deepspeed_config: configs/zero_config.json
  peft_model: true
  peft_type: "lora"
  use_custom_sampler: true

Training runs fine, but evaluation raises the following error (at the first eval step):

File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/transformers/trainer.py", line 2234, in _maybe_log_save_evaluate
metrics = self.evaluate(
File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/transformers/trainer.py", line 2939, in evaluate
output = eval_loop(
File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/transformers/trainer.py", line 3120, in evaluation_loop
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
File "/home/tgervet/Open-Assistant/model/model_training/trainer_sft.py", line 107, in prediction_step
loss, logits, labels, labels_mask = self._compute_loss(model, inputs)
File "/home/tgervet/Open-Assistant/model/model_training/trainer_sft.py", line 87, in _compute_loss
outputs = model(
File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/peft/peft_model.py", line 530, in forward
return self.base_model(
File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 687, in forward
outputs = self.model(
File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 577, in forward
layer_outputs = decoder_layer(
File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 292, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/tgervet/Open-Assistant/model/model_training/models/patching_llama.py", line 28, in llama_forward_with_flash_attn
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/peft/tuners/lora.py", line 350, in forward
result += self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: expected scalar type Half but found Float

with the following environment:

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
_openmp_mutex             5.1                       1_gnu  
accelerate                0.21.0                   pypi_0    pypi
aiohttp                   3.8.4                    pypi_0    pypi
aiosignal                 1.3.1                    pypi_0    pypi
appdirs                   1.4.4                    pypi_0    pypi
async-timeout             4.0.2                    pypi_0    pypi
attrs                     23.1.0                   pypi_0    pypi
beautifulsoup4            4.12.2                   pypi_0    pypi
bitsandbytes              0.40.0.post4             pypi_0    pypi
brotli                    1.0.9                    pypi_0    pypi
bzip2                     1.0.8                h7b6447c_0  
ca-certificates           2023.5.7             hbcca054_0    conda-forge
cattrs                    23.1.2                   pypi_0    pypi
certifi                   2023.5.7                 pypi_0    pypi
charset-normalizer        3.2.0                    pypi_0    pypi
click                     8.1.5                    pypi_0    pypi
cmake                     3.26.4                   pypi_0    pypi
cudatoolkit-dev           11.7.0               h1de0b5d_6    conda-forge
datasets                  2.13.1                   pypi_0    pypi
deepspeed                 0.9.5                    pypi_0    pypi
dill                      0.3.6                    pypi_0    pypi
docker-pycreds            0.4.0                    pypi_0    pypi
einops                    0.6.1                    pypi_0    pypi
evaluate                  0.4.0                    pypi_0    pypi
exceptiongroup            1.1.2                    pypi_0    pypi
fastlangid                1.0.11                   pypi_0    pypi
fasttext                  0.9.2                    pypi_0    pypi
filelock                  3.12.2                   pypi_0    pypi
flash-attn                1.0.8                    pypi_0    pypi
frozenlist                1.4.0                    pypi_0    pypi
fsspec                    2023.6.0                 pypi_0    pypi
gdown                     4.7.1                    pypi_0    pypi
gitdb                     4.0.10                   pypi_0    pypi
gitpython                 3.1.32                   pypi_0    pypi
grpcio                    1.51.3                   pypi_0    pypi
hjson                     3.1.0                    pypi_0    pypi
huggingface-hub           0.16.4                   pypi_0    pypi
idna                      3.4                      pypi_0    pypi
inflate64                 0.3.1                    pypi_0    pypi
jinja2                    3.1.2                    pypi_0    pypi
joblib                    1.3.1                    pypi_0    pypi
jsonschema                4.18.3                   pypi_0    pypi
jsonschema-specifications 2023.6.1                 pypi_0    pypi
langcodes                 3.3.0                    pypi_0    pypi
ld_impl_linux-64          2.38                 h1181459_1  
libffi                    3.4.4                h6a678d5_0  
libgcc-ng                 11.2.0               h1234567_1  
libgomp                   11.2.0               h1234567_1  
libstdcxx-ng              11.2.0               h1234567_1  
libuuid                   1.41.5               h5eee18b_0  
lit                       16.0.6                   pypi_0    pypi
loguru                    0.6.0                    pypi_0    pypi
markdown-it-py            3.0.0                    pypi_0    pypi
markupsafe                2.1.3                    pypi_0    pypi
mdurl                     0.1.2                    pypi_0    pypi
model-training            1.0.0                    pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
msgpack                   1.0.5                    pypi_0    pypi
multidict                 6.0.4                    pypi_0    pypi
multiprocess              0.70.14                  pypi_0    pypi
multivolumefile           0.2.3                    pypi_0    pypi
ncurses                   6.4                  h6a678d5_0  
networkx                  3.1                      pypi_0    pypi
ninja                     1.11.1                   pypi_0    pypi
nltk                      3.8.1                    pypi_0    pypi
numpy                     1.25.1                   pypi_0    pypi
nvidia-cublas-cu11        11.10.3.66               pypi_0    pypi
nvidia-cuda-cupti-cu11    11.7.101                 pypi_0    pypi
nvidia-cuda-nvrtc-cu11    11.7.99                  pypi_0    pypi
nvidia-cuda-runtime-cu11  11.7.99                  pypi_0    pypi
nvidia-cudnn-cu11         8.5.0.96                 pypi_0    pypi
nvidia-cufft-cu11         10.9.0.58                pypi_0    pypi
nvidia-curand-cu11        10.2.10.91               pypi_0    pypi
nvidia-cusolver-cu11      11.4.0.1                 pypi_0    pypi
nvidia-cusparse-cu11      11.7.4.91                pypi_0    pypi
nvidia-nccl-cu11          2.14.3                   pypi_0    pypi
nvidia-nvtx-cu11          11.7.91                  pypi_0    pypi
oasst-data                1.0.0                    pypi_0    pypi
openssl                   3.0.9                h7f8727e_0  
packaging                 23.1                     pypi_0    pypi
pandas                    2.0.3                    pypi_0    pypi
pathtools                 0.1.2                    pypi_0    pypi
peft                      0.2.0                    pypi_0    pypi
pip                       23.1.2          py310h06a4308_0  
protobuf                  4.23.4                   pypi_0    pypi
psutil                    5.9.5                    pypi_0    pypi
py-cpuinfo                9.0.0                    pypi_0    pypi
py7zr                     0.20.5                   pypi_0    pypi
pyarrow                   12.0.1                   pypi_0    pypi
pybcj                     1.0.1                    pypi_0    pypi
pybind11                  2.10.4                   pypi_0    pypi
pycryptodomex             3.18.0                   pypi_0    pypi
pydantic                  1.10.7                   pypi_0    pypi
pygments                  2.15.1                   pypi_0    pypi
pyppmd                    1.0.0                    pypi_0    pypi
pysocks                   1.7.1                    pypi_0    pypi
python                    3.10.12              h955ad1f_0  
python-dateutil           2.8.2                    pypi_0    pypi
python-rapidjson          1.10                     pypi_0    pypi
pytz                      2023.3                   pypi_0    pypi
pyyaml                    6.0                      pypi_0    pypi
pyzstd                    0.15.9                   pypi_0    pypi
ray                       2.5.1                    pypi_0    pypi
readline                  8.2                  h5eee18b_0  
referencing               0.29.1                   pypi_0    pypi
regex                     2023.6.3                 pypi_0    pypi
requests                  2.31.0                   pypi_0    pypi
responses                 0.18.0                   pypi_0    pypi
rich                      13.4.2                   pypi_0    pypi
rpds-py                   0.8.10                   pypi_0    pypi
scikit-learn              1.3.0                    pypi_0    pypi
scipy                     1.11.1                   pypi_0    pypi
sentencepiece             0.1.99                   pypi_0    pypi
sentry-sdk                1.28.1                   pypi_0    pypi
setproctitle              1.3.2                    pypi_0    pypi
setuptools                67.8.0          py310h06a4308_0  
six                       1.16.0                   pypi_0    pypi
smmap                     5.0.0                    pypi_0    pypi
soupsieve                 2.4.1                    pypi_0    pypi
sqlite                    3.41.2               h5eee18b_0  
sympy                     1.12                     pypi_0    pypi
tabulate                  0.9.0                    pypi_0    pypi
texttable                 1.6.7                    pypi_0    pypi
threadpoolctl             3.2.0                    pypi_0    pypi
tk                        8.6.12               h1ccaba5_0  
tokenizers                0.13.3                   pypi_0    pypi
torch                     2.0.1                    pypi_0    pypi
torchtyping               0.1.4                    pypi_0    pypi
tqdm                      4.65.0                   pypi_0    pypi
transformers              4.28.0.dev0              pypi_0    pypi
triton                    2.0.0                    pypi_0    pypi
tritonclient              2.35.0                   pypi_0    pypi
trlx                      0.7.0                    pypi_0    pypi
typeguard                 4.0.0                    pypi_0    pypi
typing-extensions         4.7.1                    pypi_0    pypi
tzdata                    2023.3                   pypi_0    pypi
urllib3                   2.0.3                    pypi_0    pypi
wandb                     0.15.5                   pypi_0    pypi
wheel                     0.38.4          py310h06a4308_0  
xxhash                    3.2.0                    pypi_0    pypi
xz                        5.4.2                h5eee18b_0  
yarl                      1.9.2                    pypi_0    pypi
zlib                      1.2.13               h5eee18b_0  

Any idea what could be causing this and how to fix it?
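
The only workaround I can think of (untested) would be forcing the peft LoRA adapter weights to the base model's fp16 dtype before evaluation, since the dtype mismatch in the traceback happens inside the lora_A/lora_B linears. A minimal sketch, assuming a PeftModel wrapping an fp16 base model and that the adapters are what remain in fp32:

import torch

def cast_lora_adapters_to_fp16(model):
    # Assumption: the frozen base model is already fp16 and only the injected
    # LoRA adapter modules (named lora_A / lora_B by peft) are still fp32,
    # which would make F.linear see mixed dtypes during evaluation.
    for name, module in model.named_modules():
        if "lora_" in name and hasattr(module, "weight"):
            module.to(torch.float16)
    return model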

andreaskoepf commented 1 year ago

It's interesting that it occurs during eval. I asked @jordiclive and he said that he has trained several LLaMA LoRA models in fp16, including 7B. If you want to debug this issue and investigate the cause, you could set eval_steps in the configuration to 1.

jordiclive commented 1 year ago

@theophilegervet Yes, that is strange. I didn't encounter this error when training 7B (decapoda-research/llama-7b-hf) or 13B in fp16, rather than openlm-research/open_llama_13b.

If you set eval_steps to 1 and change the dataset, does it still occur?

Perhaps also try with peft==0.3.0.

theophilegervet commented 1 year ago

Thank you @jordiclive! peft==0.3.0 fixes the issue with lora-llama-13b and openlm-research/open_llama_13b.

I still have the issue with llama-7b though. decapoda-research/llama-7b-hf gives ValueError: Tokenizer class LLaMATokenizer does not exist or is not currently imported, so I'm using huggyllama/llama-7b instead.

I get the following error:

Traceback (most recent call last):
  File "/home/tgervet/Open-Assistant/model/model_training/trainer_sft.py", line 477, in <module>
    main()
  File "/home/tgervet/Open-Assistant/model/model_training/trainer_sft.py", line 471, in main
    trainer.train(resume_from_checkpoint=training_conf.resume_from_checkpoint)
  File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/transformers/trainer.py", line 1532, in train
    return inner_training_loop(
  File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/transformers/trainer.py", line 1863, in _inner_training_loop
    self.accelerator.clip_grad_norm_(
  File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/accelerate/accelerator.py", line 1925, in clip_grad_norm_
    self.unscale_gradients()
  File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/accelerate/accelerator.py", line 1888, in unscale_gradients
    self.scaler.unscale_(opt)
  File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py", line 284, in unscale_
    optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
  File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py", line 212, in _unscale_grads_
    raise ValueError("Attempting to unscale FP16 gradients.")
ValueError: Attempting to unscale FP16 gradients.

This happens with both use_flash_attention: true and use_flash_attention: false.

I think I need to address this issue because I'm trying to train a reward model with

python trainer_rm.py --configs defaults_rm oasst-rm-1-pythia-6.9b --wandb-entity tgervet

and get the same error there

Traceback (most recent call last):
  File "/home/tgervet/Open-Assistant/model/model_training/trainer_rm.py", line 334, in <module>
    main()
  File "/home/tgervet/Open-Assistant/model/model_training/trainer_rm.py", line 328, in main
    trainer.train(resume_from_checkpoint=training_conf.resume_from_checkpoint)
  File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/transformers/trainer.py", line 1639, in train
    return inner_training_loop(
  File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/transformers/trainer.py", line 1939, in _inner_training_loop
    self.scaler.unscale_(self.optimizer)
  File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py", line 284, in unscale_
    optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
  File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py", line 212, in _unscale_grads_
    raise ValueError("Attempting to unscale FP16 gradients.")
ValueError: Attempting to unscale FP16 gradients.

Replacing dtype: fp16 with dtype: fp32 gives an OOM error.
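
For reference, the workaround I've usually seen suggested for "Attempting to unscale FP16 gradients." is to keep only the trainable parameters in fp32, so the AMP GradScaler can unscale their gradients while the frozen base model stays in fp16. A rough sketch of what I mean (a hypothetical helper, not something from this repo):

import torch

def cast_trainable_params_to_fp32(model):
    # torch.cuda.amp.GradScaler refuses to unscale fp16 gradients, so any
    # parameter that actually receives gradients needs to live in fp32.
    # Frozen fp16 parameters are left untouched, so for a LoRA setup the
    # extra memory stays small.
    for param in model.parameters():
        if param.requires_grad and param.dtype == torch.float16:
            param.data = param.data.float()
    return model

For a full (non-LoRA) fine-tune this amounts to fp32 weights again, so it wouldn't help with the OOM; it only really applies to the LoRA case.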

Could you please share your environment so I can debug the delta?

theophilegervet commented 1 year ago

I've tried the updated environment you provided:

bitsandbytes==0.41.0
deepspeed==0.10.0
peft==0.4.0
transformers==4.31.0
flash-attn==2.0.0.post1

but I still hit the same issue.

andreaskoepf commented 1 year ago

I saw an error similar to the one you described when running without deepspeed. To run with deepspeed you need to replace python on the command line with deepspeed, e.g. deepspeed trainer_sft.py --configs rope_scaling_test --deepspeed. Could you please try this?

theophilegervet commented 1 year ago

With the following deepspeed command

deepspeed trainer_sft.py --configs llama-7b webgpt_dataset_only --deepspeed

I get an OOM error on a 40GB A100 (even with batch size 1 and sequence length 128):

Traceback (most recent call last):
  File "/home/tgervet/Open-Assistant/model/model_training/trainer_sft.py", line 477, in <module>
    main()
  File "/home/tgervet/Open-Assistant/model/model_training/trainer_sft.py", line 471, in main
    trainer.train(resume_from_checkpoint=training_conf.resume_from_checkpoint)
  File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/transformers/trainer.py", line 1532, in train
    return inner_training_loop(
  File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/transformers/trainer.py", line 1655, in _inner_training_loop
    model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
  File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/accelerate/accelerator.py", line 1198, in prepare
    result = self._prepare_deepspeed(*args)
  File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/accelerate/accelerator.py", line 1537, in _prepare_deepspeed
    engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
  File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/deepspeed/__init__.py", line 171, in initialize
    engine = DeepSpeedEngine(args=args,
  File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 310, in __init__
    self._configure_optimizer(optimizer, model_parameters)
  File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1209, in _configure_optimizer
    self.optimizer = self._configure_zero_optimizer(basic_optimizer)
  File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1444, in _configure_zero_optimizer
    optimizer = DeepSpeedZeroOptimizer(
  File "/home/tgervet/miniconda3/envs/open-assistant/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 346, in __init__
    self.device).clone().float().detach())
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 25.10 GiB (GPU 0; 39.42 GiB total capacity; 25.13 GiB already allocated; 13.67 GiB free; 25.14 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

It seems like deepspeed is upcasting the parameters to float32 (the .clone().float() in the traceback). Might this explain why I was getting the float16 error without deepspeed?
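
As a rough sanity check of that 25 GiB allocation (my own back-of-the-envelope numbers; the ~6.74B parameter count for LLaMA-7B is an assumption):

# ZeRO stage 1/2 keeps an fp32 master copy of every trainable parameter,
# which lines up with the single ~25 GiB allocation in the traceback above.
params = 6.74e9                         # approximate parameter count of LLaMA-7B
fp32_master_copy_bytes = params * 4     # 4 bytes per fp32 value
print(fp32_master_copy_bytes / 2**30)   # ≈ 25.1 GiB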

@jordiclive Were you training with or without deepspeed?

@andreaskoepf @jordiclive I'm not sure how to proceed. Supervised fine-tuning of the lora-llama-13b model works fine for me on a 40GB A100. The float16 error only appears for non-LoRA models. Maybe we could set up reward model training with LoRA too?