pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

[Question] QLora on MPS? #1942

Closed: austinmw closed this issue 1 hour ago

austinmw commented 2 hours ago

Hi, I attempted to fine-tune Meta-Llama-3.1-8B-Instruct using llama3_1/8B_qlora_single_device with device set to mps on an M2 MBP, but hit the following error:

RuntimeError: Only float and half are supported

Is this currently supported?

Full trace:

tune run lora_finetune_single_device --config ./my_custom_config.yaml
W1102 09:40:41.997000 22803 site-packages/torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
INFO:torchtune.utils._logging:Running LoRAFinetuneRecipeSingleDevice with resolved config:

batch_size: 2
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files:
  - model-00001-of-00004.safetensors
  - model-00002-of-00004.safetensors
  - model-00003-of-00004.safetensors
  - model-00004-of-00004.safetensors
  model_type: LLAMA3
  output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  recipe_checkpoint: null
compile: false
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
device: mps
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
epochs: 1
gradient_accumulation_steps: 16
log_every_n_steps: 1
log_peak_memory_stats: false
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
low_cpu_ram: false
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
max_steps_per_epoch: null
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: /tmp/qlora_finetune_output/
model:
  _component_: torchtune.models.llama3_1.qlora_llama3_1_8b
  apply_lora_to_mlp: true
  apply_lora_to_output: false
  lora_alpha: 16
  lora_attn_modules:
  - q_proj
  - v_proj
  - k_proj
  - output_proj
  lora_dropout: 0.0
  lora_rank: 8
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  lr: 0.0003
  weight_decay: 0.01
output_dir: /tmp/qlora_finetune_output/
profiler:
  _component_: torchtune.training.setup_torch_profiler
  active_steps: 2
  cpu: true
  cuda: true
  enabled: false
  num_cycles: 1
  output_dir: /tmp/qlora_finetune_output//profiling_outputs
  profile_memory: false
  record_shapes: true
  wait_steps: 5
  warmup_steps: 5
  with_flops: false
  with_stack: false
resume_from_checkpoint: false
save_adapter_weights_only: false
seed: null
shuffle: true
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: null
  path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model

DEBUG:torchtune.utils._logging:Setting manual seed to local seed 2927239804. Local seed is seed + rank = 2927239804 + 0
Writing logs to /tmp/qlora_finetune_output/log_1730554842.txt
INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils._logging:Tokenizer is initialized from file.
INFO:torchtune.utils._logging:Optimizer and loss are initialized.
INFO:torchtune.utils._logging:Loss is initialized.
README.md: 100%|████████████████████████████████████████████████████████████████████████████████████| 11.6k/11.6k [00:00<00:00, 15.1MB/s]
alpaca_data_cleaned.json: 100%|█████████████████████████████████████████████████████████████████████| 44.3M/44.3M [00:00<00:00, 74.8MB/s]
Generating train split: 100%|███████████████████████████████████████████████████████████| 51760/51760 [00:00<00:00, 205775.61 examples/s]
INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
INFO:torchtune.utils._logging:Learning rate scheduler is initialized.
WARNING:torchtune.utils._logging: Profiling disabled.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
  0%|                                                                                                           | 0/1617 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/Users/austinwelch/mambaforge/envs/py312/bin/tune", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torchtune/_cli/tune.py", line 49, in main
    parser.run(args)
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torchtune/_cli/tune.py", line 43, in run
    args.func(args)
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torchtune/_cli/run.py", line 196, in _run_cmd
    self._run_single_device(args, is_builtin=is_builtin)
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torchtune/_cli/run.py", line 102, in _run_single_device
    runpy.run_path(str(args.recipe), run_name="__main__")
  File "<frozen runpy>", line 287, in run_path
  File "<frozen runpy>", line 98, in _run_module_code
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/recipes/lora_finetune_single_device.py", line 793, in <module>
    sys.exit(recipe_main())
             ^^^^^^^^^^^^^
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torchtune/config/_parse.py", line 99, in wrapper
    sys.exit(recipe_main(conf))
             ^^^^^^^^^^^^^^^^^
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/recipes/lora_finetune_single_device.py", line 788, in recipe_main
    recipe.train()
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/recipes/lora_finetune_single_device.py", line 707, in train
    self._optimizer.step()
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torch/optim/lr_scheduler.py", line 137, in wrapper
    return func.__get__(opt, opt.__class__)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torch/optim/optimizer.py", line 487, in wrapper
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torch/optim/optimizer.py", line 91, in _use_grad
    ret = func(self, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torch/optim/adamw.py", line 220, in step
    adamw(
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torch/optim/optimizer.py", line 154, in maybe_fallback
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torch/optim/adamw.py", line 782, in adamw
    func(
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torch/optim/adamw.py", line 694, in _fused_adamw
    torch._fused_adamw_(
RuntimeError: Only float and half are supported
  0%|                                                                                                           | 0/1617 [02:56<?, ?it/s]
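The traceback ends in `torch._fused_adamw_`, which `torch.optim.AdamW.step()` reaches because the config sets `fused: true`. The following is a minimal sketch that reproduces only that optimizer step, without torchtune. `adamw_step` is a hypothetical helper. The failure it shows assumes a PyTorch build like the one in the trace, whose fused MPS kernel rejected bfloat16; other releases may behave differently.

```python
import torch


def adamw_step(device: str, dtype: torch.dtype, fused: bool) -> torch.nn.Parameter:
    """Hypothetical helper: run one AdamW step on a small parameter.

    Uses the config's optimizer settings (lr=0.0003, weight_decay=0.01);
    only the device, dtype, and `fused` flag vary.
    """
    param = torch.nn.Parameter(torch.randn(4, 4, device=device, dtype=dtype))
    opt = torch.optim.AdamW([param], lr=3e-4, weight_decay=0.01, fused=fused)
    param.sum().backward()  # populate param.grad
    opt.step()  # with fused=True this goes through torch._fused_adamw_
    return param


if torch.backends.mps.is_available():
    try:
        adamw_step("mps", torch.bfloat16, fused=True)
        print("fused bf16 AdamW step succeeded on MPS")
    except RuntimeError as err:
        # On the PyTorch build from the trace, the message was
        # "Only float and half are supported".
        print(f"fused bf16 AdamW step failed on MPS: {err}")
    # The workaround suggested in this thread: the non-fused implementation.
    adamw_step("mps", torch.bfloat16, fused=False)
else:
    print("MPS backend not available; skipping the reproduction")
```

Because the repro runs only the optimizer, it fails in seconds. The full run above took nearly three minutes to reach the same error.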
SalmanMohammadi commented 2 hours ago

Hey @austinmw! I think the fused AdamW implementation (fused=True) doesn't have bf16 kernels for MPS. Try setting fused=False in your config, or pass optimizer.fused=False through the CLI.
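The simplest form of that fix is to set `fused: false` under `optimizer` in the YAML config. The same override can be passed on the command line by appending `optimizer.fused=False` to the `tune run lora_finetune_single_device --config ./my_custom_config.yaml` command. If one config is shared across machines, you could instead choose the flag from the device and dtype rather than hard-coding it. The sketch below does this in plain PyTorch. `wants_fused_adamw` and `build_adamw` are hypothetical helpers, not torchtune API. The rule they encode is an assumption: use fused kernels on CUDA, and on MPS only for float32/float16, inferred from the error message "Only float and half are supported".

```python
import torch


def wants_fused_adamw(device_type: str, dtype: torch.dtype) -> bool:
    """Hypothetical helper: decide whether to request fused AdamW.

    Assumed rule, not a torchtune or PyTorch API: use fused kernels on CUDA,
    and on MPS only for float32/float16, since the MPS kernel in the trace
    rejected bfloat16. Any other device uses the non-fused implementation.
    """
    if device_type == "cuda":
        return True
    if device_type == "mps":
        return dtype in (torch.float32, torch.float16)
    return False


def build_adamw(params, device_type: str, dtype: torch.dtype,
                lr: float = 3e-4, weight_decay: float = 0.01) -> torch.optim.AdamW:
    # Same hyperparameters as the optimizer section of the config,
    # but `fused` is derived instead of being hard-coded to True.
    return torch.optim.AdamW(
        params,
        lr=lr,
        weight_decay=weight_decay,
        fused=wants_fused_adamw(device_type, dtype),
    )


# Usage: bf16 parameters (matching `dtype: bf16`) on MPS if present, else CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"
dtype = torch.bfloat16
weight = torch.nn.Parameter(torch.randn(8, 8, device=device, dtype=dtype))
opt = build_adamw([weight], device, dtype)
print("fused:", opt.param_groups[0]["fused"])
weight.sum().backward()
opt.step()
```

Falling back to the non-fused implementation gives up some optimizer-step speed in exchange for dtype coverage. With `gradient_accumulation_steps: 16`, the optimizer steps only once every 16 micro-batches, so the cost is likely small compared with the forward and backward passes.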

austinmw commented 1 hour ago

Thanks!