axolotl-ai-cloud / axolotl

Go ahead and axolotl questions
https://axolotl-ai-cloud.github.io/axolotl/
Apache License 2.0

ORPO results in `Cannot flatten integer dtype tensors` #1838

Open maziyarpanahi opened 3 months ago

maziyarpanahi commented 3 months ago

Expected Behavior

I expect ORPO to work properly with FSDP and DeepSpeed on Qwen2 models.

Current behaviour

Currently, it's not possible to use ORPO with FSDP or DeepSpeed. Training fails with the following error:

  File "/workspace/axolotl/src/axolotl/cli/train.py", line 67, in do_train
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/transformers/trainer.py", line 1938, in train
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/accelerator.py", line 1181, in _prepare_one
    return self.prepare_model(obj, device_placement=device_placement)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/accelerator.py", line 1477, in prepare_model
    model = FSDP(model, **kwargs)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 463, in __init__
    _auto_wrap(
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
    _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 537, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 537, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 537, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  [Previous line repeated 2 more times]
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 555, in _recursive_wrap
    return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 484, in _wrap
    return wrapper_cls(module, **kwargs)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 487, in __init__
    _init_param_handle_from_module(
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 519, in _init_param_handle_from_module
    _init_param_handle_from_params(state, managed_params, fully_sharded_module)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 531, in _init_param_handle_from_params
    handle = FlatParamHandle(
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/flat_param.py", line 537, in __init__
    self._init_flat_param_and_metadata(
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/flat_param.py", line 585, in _init_flat_param_and_metadata
    ) = self._validate_tensors_to_flatten(params)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/flat_param.py", line 720, in _validate_tensors_to_flatten
    raise ValueError("Cannot flatten integer dtype tensors")
ValueError: Cannot flatten integer dtype tensors
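
The last frame shows FSDP refusing to flatten a parameter whose dtype is not floating point. A minimal sketch for finding such parameters before FSDP wraps the model is below. The helper name `find_non_float_params` is hypothetical, and the idea that the 4-bit quantized weights from `load_in_4bit: true` are the integer tensors involved is an assumption that nothing in this thread confirms.

```python
from typing import Iterable, List, Tuple


def find_non_float_params(
    named_params: Iterable[Tuple[str, object]]
) -> List[Tuple[str, str]]:
    """Return (name, dtype) for every parameter that is not floating point.

    Hypothetical helper: it applies the same kind of test that appears to
    trigger the ValueError above (a tensor whose is_floating_point() is False).
    Each item is expected to behave like a torch.Tensor, i.e. to expose
    .is_floating_point() and .dtype.
    """
    offenders = []
    for name, param in named_params:
        if not param.is_floating_point():
            offenders.append((name, str(param.dtype)))
    return offenders


# Usage (assumes `model` is the loaded torch.nn.Module, before FSDP wrapping):
#   for name, dtype in find_non_float_params(model.named_parameters()):
#       print(name, dtype)
```

If this lists the quantized linear layers, the storage dtype of the 4-bit weights is a reasonable first place to look.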

Steps to reproduce

rm -rf axolotl
git clone https://github.com/OpenAccess-AI-Collective/axolotl && \
cd axolotl && \
git checkout 608a2f3 && \
pip install setuptools && \
pip install -e .[flash-attn,deepspeed] && \
cd ..

Config yaml

base_model: arcee-ai/Arcee-Nova
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

save_safetensors: true

rl: orpo
orpo_alpha: 0.1
chat_template: chatml
datasets:
  - path: mlabonne/orpo-dpo-mix-40k
    type: chat_template.argilla
    chat_template: chatml

dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./models/Arcee-Nova-ORPO-v0.1

adapter: qlora
lora_model_dir:

sequence_len: 1800
sample_packing: false
pad_to_sequence_len: false

lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  - gate_proj
  - up_proj
  - down_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 5e-5
train_on_inputs: false
group_by_length: false

bf16: auto
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 50
evals_per_epoch: 1
eval_table_size:
eval_table_max_new_tokens: 128
save_steps: 100
debug:
weight_decay: 0.05
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
special_tokens:
  pad_token: "<|endoftext|>"
  eos_token: "<|im_end|>"

Possible solution

No response

Python Version

3.10

axolotl branch-commit

608a2f3

winglian commented 3 months ago

I think we'll want to change from our orpo implementation to the trl ORPOTrainer implementation.

maziyarpanahi commented 3 months ago

This is interesting! I'd love to help test it if you have any work in progress for ORPOTrainer.

maziyarpanahi commented 2 months ago

Hi @winglian, any updates? Let me know if you need me to test anything.

winglian commented 3 days ago

oh, hmm, we already use the ORPOTrainer (https://github.com/axolotl-ai-cloud/axolotl/pull/1551/files), will need to dig into this a bit deeper

maziyarpanahi commented 2 days ago

> oh, hmm, we already use the ORPOTrainer (https://github.com/axolotl-ai-cloud/axolotl/pull/1551/files), will need to dig into this a bit deeper

Thanks a lot @winglian - is this something new? Should I try again? (my message is a few months old)