OpenAccess-AI-Collective / axolotl

https://openaccess-ai-collective.github.io/axolotl/
Apache License 2.0

Llama3-8b: LlamaForCausalLM.forward() got an unexpected keyword argument 'length' #1700

Open DMR92 opened 3 weeks ago

DMR92 commented 3 weeks ago

Please check that this issue hasn't been reported before.

Expected Behavior

I expect the training run to finish and save the weights for the fine-tuned model.

Current behaviour

Hi,

After starting the training process for a Llama-3 8B LoRA fine-tune on Jarvis Labs using Axolotl, I get the following error: TypeError: LlamaForCausalLM.forward() got an unexpected keyword argument 'length'.

Steps to reproduce

I picked 1x RTX 6000 Ada on Jarvis Labs with the Axolotl template. After starting a Jupyter notebook, I slightly modified the default lora-8b example config: I only added hub_model_id, wandb_entity, wandb_project, and eval_sample_packing: false, and left the rest unchanged.

Config yaml

hub_model_id: d4niel92/llama3-8B_alpaca_2k_lora

wandb_entity: d4nielmeyer
wandb_project: llama3-8b_alpaca_2k_lora

base_model: meta-llama/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca

dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/lora-out

sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_modules_to_save:
  - embed_tokens
  - lm_head

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: <|end_of_text|>

Possible solution

Similar unexpected-keyword-argument errors in this forward method have been solved by setting flash_attention to false, but that didn't resolve the issue in my case.

I also ran a training job with exactly the same dataset on tiny-llama, and everything worked perfectly. Checking whether any sample in the input data contains the string "length" returned no results.
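
For reference, the check was along these lines (a minimal sketch assuming the Hugging Face datasets library and the standard alpaca fields, not the exact script I used):

```python
# Minimal sketch: scan the alpaca_2k_test dataset for a literal "length"
# column or for the string "length" inside any sample field.
from datasets import load_dataset

ds = load_dataset("mhenrichsen/alpaca_2k_test", split="train")
print(ds.column_names)  # verify there is no "length" column

hits = [
    i
    for i, sample in enumerate(ds)
    if any("length" in str(value) for value in sample.values())
]
print(hits)  # no matches came up in my check
```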

Any ideas what might cause the issue and how I can solve it?

Thank you 🙏

Which Operating Systems are you using?

Python Version

3.10

axolotl branch-commit

main/851ccb1

Acknowledgements

alugowski commented 3 weeks ago

Seeing the same issue with a fresh pip install on Ubuntu, using an H100.

alugowski commented 3 weeks ago

My traceback:

[rank6]: Traceback (most recent call last):
[rank6]:   File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank6]:     return _run_code(code, main_globals, None,
[rank6]:   File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
[rank6]:     exec(code, run_globals)
[rank6]:   File "/home/adam/train/axolotl/src/axolotl/cli/train.py", line 70, in <module>
[rank6]:     fire.Fire(do_cli)
[rank6]:   File "/home/adam/venvs/axolotl/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
[rank6]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank6]:   File "/home/adam/venvs/axolotl/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
[rank6]:     component, remaining_args = _CallAndUpdateTrace(
[rank6]:   File "/home/adam/venvs/axolotl/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
[rank6]:     component = fn(*varargs, **kwargs)
[rank6]:   File "/home/adam/train/axolotl/src/axolotl/cli/train.py", line 38, in do_cli
[rank6]:     return do_train(parsed_cfg, parsed_cli_args)
[rank6]:   File "/home/adam/train/axolotl/src/axolotl/cli/train.py", line 66, in do_train
[rank6]:     return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
[rank6]:   File "/home/adam/train/axolotl/src/axolotl/train.py", line 170, in train
[rank6]:     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
[rank6]:   File "/home/adam/venvs/axolotl/lib/python3.10/site-packages/transformers/trainer.py", line 1885, in train
[rank6]:     return inner_training_loop(
[rank6]:   File "/home/adam/venvs/axolotl/lib/python3.10/site-packages/transformers/trainer.py", line 2291, in _inner_training_loop
[rank6]:     self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
[rank6]:   File "/home/adam/venvs/axolotl/lib/python3.10/site-packages/transformers/trainer.py", line 2721, in _maybe_log_save_evaluate
[rank6]:     metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
[rank6]:   File "/home/adam/venvs/axolotl/lib/python3.10/site-packages/transformers/trainer.py", line 3572, in evaluate
[rank6]:     output = eval_loop(
[rank6]:   File "/home/adam/venvs/axolotl/lib/python3.10/site-packages/transformers/trainer.py", line 3757, in evaluation_loop
[rank6]:     loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
[rank6]:   File "/home/adam/venvs/axolotl/lib/python3.10/site-packages/transformers/trainer.py", line 3971, in prediction_step
[rank6]:     loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
[rank6]:   File "/home/adam/train/axolotl/src/axolotl/core/trainer_builder.py", line 537, in compute_loss
[rank6]:     return super().compute_loss(model, inputs, return_outputs=return_outputs)
[rank6]:   File "/home/adam/venvs/axolotl/lib/python3.10/site-packages/transformers/trainer.py", line 3264, in compute_loss
[rank6]:     outputs = model(**inputs)
[rank6]:   File "/home/adam/venvs/axolotl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank6]:     return self._call_impl(*args, **kwargs)
[rank6]:   File "/home/adam/venvs/axolotl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank6]:     return forward_call(*args, **kwargs)
[rank6]:   File "/home/adam/venvs/axolotl/lib/python3.10/site-packages/accelerate/utils/operations.py", line 822, in forward
[rank6]:     return model_forward(*args, **kwargs)
[rank6]:   File "/home/adam/venvs/axolotl/lib/python3.10/site-packages/accelerate/utils/operations.py", line 810, in __call__
[rank6]:     return convert_to_fp32(self.model_forward(*args, **kwargs))
[rank6]:   File "/home/adam/venvs/axolotl/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
[rank6]:     return func(*args, **kwargs)
[rank6]: TypeError: LlamaForCausalLM.forward() got an unexpected keyword argument 'length'

Using torch 2.3.0
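
For what it's worth, the failure itself is just a keyword-argument mismatch: during evaluation the Trainer calls model(**inputs), and something leaves a "length" key in that inputs dict which LlamaForCausalLM.forward() does not accept. A toy illustration of the mechanism (hypothetical batch contents, not the actual axolotl or transformers code):

```python
# Stand-in for a model forward() that only accepts the usual arguments.
def forward(input_ids=None, attention_mask=None, labels=None):
    return "ok"

# Hypothetical eval batch: everything the model expects, plus a stray "length" key.
inputs = {
    "input_ids": [[1, 2, 3]],
    "attention_mask": [[1, 1, 1]],
    "labels": [[1, 2, 3]],
    "length": 3,
}

forward(**inputs)
# TypeError: forward() got an unexpected keyword argument 'length'
```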

DMR92 commented 3 weeks ago

It might be related to the evaluation step. Setting val_set_size: 0.0 avoids the error. Still, I would like to understand what's happening here, and whether finishing the training job without a validation set has any effect / comes at a cost.

alugowski commented 3 weeks ago

I narrowed down the timeframe of the bug.

The docker image winglian/axolotl:main-20240608-py3.10-cu118-2.1.2 works, but the next day's image, winglian/axolotl:main-20240609-py3.10-cu118-2.1.2 shows the error.

So use a version of Axolotl from June 8 or earlier.

ann-brown commented 3 weeks ago

Just a quick note that I can reproduce this on an (unusual) Mixtral architecture, and that removing val_set_size likewise allows the training to run.

alugowski commented 3 weeks ago

It looks like this PR caused it: https://github.com/OpenAccess-AI-Collective/axolotl/pull/1695 (commit 18cabc0c461c9178c90fcb080e40e7daa9c6c6f8)

Using the commit right before works:

git checkout ed8ef6537182fe516a2940355f7e34a397b22fdc

ganler commented 2 weeks ago

Met the same issue when running an evaluation.

mkaesler44 commented 2 weeks ago

I currently get the same issue when running examples/tiny-llama/lora.yml with the latest image pulled on RunPod.