axolotl-ai-cloud / axolotl

Go ahead and axolotl questions
https://axolotl-ai-cloud.github.io/axolotl/
Apache License 2.0

Can't preprocess dataset using meta-llama/Meta-Llama-3.1-8B model #1800

Closed ohmeow closed 1 month ago

ohmeow commented 1 month ago

Please check that this issue hasn't been reported before.

Expected Behavior

I expected to have a pre-processed dataset after running python -m axolotl.cli.preprocess

Current behaviour

I get this error:

NotImplementedError: aten::_local_scalar_dense: attempted to run this operator with Meta tensors, but there was no abstract impl or Meta kernel registered. You may have run into this message while using an operator with PT2 compilation APIs (torch.compile/torch.export); in order to use this operator with those APIs you'll need to add an abstract impl.Please see the following doc for next steps: https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit

Full trace:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/wgilliam/development/projects/axolotl/src/axolotl/cli/preprocess.py", line 96, in <module>
    fire.Fire(do_cli)
  File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wgilliam/development/projects/axolotl/src/axolotl/cli/preprocess.py", line 85, in do_cli
    AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
  File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 564, in from_pretrained
    return model_class.from_pretrained(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/transformers/modeling_utils.py", line 3788, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1068, in __init__
    self.model = LlamaModel(config)
                 ^^^^^^^^^^^^^^^^^^
  File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 845, in __init__
    [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 845, in <listcomp>
    [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 632, in __init__
    self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 306, in __init__
    self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 119, in __init__
    inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/transformers/modeling_rope_utils.py", line 330, in _compute_llama3_parameters
    if wavelen < high_freq_wavelen:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/torch/utils/_device.py", line 78, in __torch_function__
    return func(*args, **kwargs)

Steps to reproduce

Run > python -m axolotl.cli.preprocess

Config yaml

base_model: "meta-llama/Meta-Llama-3.1-8B"
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

seed: 9
data_seed: 9

hub_model_id:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

datasets:
dataset_prepared_path: data/last_run_prepared
val_set_size: 0.05
output_dir: outputs

sequence_len: 3072
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: false

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

adapter: qlora 
load_in_8bit: false 
load_in_4bit: true
strict: false

lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

# NOTE: If you add tokens, you will need to save these LoRA modules
# lora_modules_to_save:
#   - embed_tokens
#   - lm_head

gradient_accumulation_steps: 4 # 8
micro_batch_size: 2 #1
eval_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 3e-5 #0.0002
# max_grad_norm: 1.0
# adam_beta2: 0.95
# adam_epsilon: 0.00001
# save_total_limit: 12

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

# max_steps: 100
# warmup_steps: 100 #10
warmup_ratio: 0.2
evals_per_epoch: 4 #8
# eval_steps: 10
eval_table_size:
eval_max_new_tokens: 512 #128
save_total_limit: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: <|end_of_text|>

# NOTE: If you add tokens, in addition to updating the lora config (see above), you'll
#       likely need to reduce your overall batch size if training on a GPU poor rig
# tokens:
#   - <function-definitions>
#   - </function-definitions>
#   - <function-thoughts>
#   - </function-thoughts>
#   - <function-calls>
#   - </function-calls>
#   - function_call

save_safetensors: true


Possible solution

N/A

Which Operating Systems are you using?

- [X] Linux
- [ ] macOS
- [ ] Windows

Python Version

3.11

axolotl branch-commit

main

Acknowledgements

- [X] My issue title is concise, descriptive, and in title casing.
- [X] I have searched the existing issues to make sure this bug has not been reported yet.
- [X] I am using the latest version of axolotl.
- [X] I have provided enough information for the maintainers to reproduce and diagnose the issue.

winglian commented 1 month ago

I don't see a dataset in your configuration YAML. Did you redact it? Can you provide some info on the dataset/prompt type you're trying to preprocess?

ohmeow commented 1 month ago

UPDATE: Looks like it's an issue in Transformers, and the recommendation is to install it directly from GitHub. See: https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct/discussions/54

Sorry, I'm adding that in dynamically (running in Jupyter). I'm using the template-free format, which works fine with the older Llama 3 ...

# axo_config_fpath = "configs/axolotl_configs/llama3-8b-qlora.yaml"
axo_config_fpath = "configs/axolotl_configs/llama3.1-8b-qlora.yaml"

train_data_fpath = "data/train_reviewed_template_free_1000.jsonl"
# JSON list of dataset entries, passed to the CLI via --datasets
train_data_config = f'[{{"path": "{train_data_fpath}", "type": "input_output"}}]'

# Jupyter shell escape: {name} is interpolated from the notebook namespace
!python -m axolotl.cli.preprocess {axo_config_fpath} --datasets '{train_data_config}'
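
For anyone building the `--datasets` override the same way, hand-escaping braces in an f-string is easy to get wrong. Here is a sketch of the same idea using `json.dumps` and `shlex.quote` from the standard library. The helper name `build_preprocess_command` is hypothetical; only the module path, the `--datasets` flag, and the `path`/`type` keys come from the snippet above.

```python
import json
import shlex


def build_preprocess_command(
    config_path: str,
    dataset_path: str,
    dataset_type: str = "input_output",
) -> str:
    """Build the preprocess command line with a JSON-encoded --datasets value."""
    # json.dumps handles quoting and escaping inside the JSON payload.
    datasets_arg = json.dumps([{"path": dataset_path, "type": dataset_type}])
    parts = [
        "python",
        "-m",
        "axolotl.cli.preprocess",
        shlex.quote(config_path),
        "--datasets",
        shlex.quote(datasets_arg),  # wraps the JSON so the shell passes it as one argument
    ]
    return " ".join(parts)


print(
    build_preprocess_command(
        "configs/axolotl_configs/llama3.1-8b-qlora.yaml",
        "data/train_reviewed_template_free_1000.jsonl",
    )
)
```

The returned string can be pasted into a shell or passed to a notebook shell escape.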
ohmeow commented 1 month ago

Closing this out. I can confirm that installing transformers from the main branch with pip fixes the problem.
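
Installing from the main branch is typically done with `pip install git+https://github.com/huggingface/transformers.git`. To check which build ended up in the environment, something like the following sketch may help. It assumes that builds from main carry a `.dev` suffix in their version string, which is the usual convention in the transformers repository but is not stated in this thread. The helper names are illustrative.

```python
from importlib.metadata import PackageNotFoundError, version


def is_dev_build(version_string: str) -> bool:
    """Heuristic: source builds from a main branch usually contain '.dev'."""
    return ".dev" in version_string


def describe_install(package: str = "transformers") -> str:
    """Report the installed version of a package and whether it looks like a dev build."""
    try:
        installed = version(package)
    except PackageNotFoundError:
        return f"{package} is not installed"
    kind = "development build" if is_dev_build(installed) else "release build"
    return f"{package} {installed} ({kind})"


print(describe_install())
```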