Hi, I attempted to fine-tune Meta-Llama-3.1-8B-Instruct using llama3_1/8B_qlora_single_device with device set to mps on an M2 MBP, but hit the following error:

RuntimeError: Only float and half are supported
Is this currently supported?
Full trace:
tune run lora_finetune_single_device --config ./my_custom_config.yaml
W1102 09:40:41.997000 22803 site-packages/torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
INFO:torchtune.utils._logging:Running LoRAFinetuneRecipeSingleDevice with resolved config:

batch_size: 2
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files:
  - model-00001-of-00004.safetensors
  - model-00002-of-00004.safetensors
  - model-00003-of-00004.safetensors
  - model-00004-of-00004.safetensors
  model_type: LLAMA3
  output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  recipe_checkpoint: null
compile: false
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
device: mps
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
epochs: 1
gradient_accumulation_steps: 16
log_every_n_steps: 1
log_peak_memory_stats: false
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
low_cpu_ram: false
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
max_steps_per_epoch: null
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: /tmp/qlora_finetune_output/
model:
  _component_: torchtune.models.llama3_1.qlora_llama3_1_8b
  apply_lora_to_mlp: true
  apply_lora_to_output: false
  lora_alpha: 16
  lora_attn_modules:
  - q_proj
  - v_proj
  - k_proj
  - output_proj
  lora_dropout: 0.0
  lora_rank: 8
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  lr: 0.0003
  weight_decay: 0.01
output_dir: /tmp/qlora_finetune_output/
profiler:
  _component_: torchtune.training.setup_torch_profiler
  active_steps: 2
  cpu: true
  cuda: true
  enabled: false
  num_cycles: 1
  output_dir: /tmp/qlora_finetune_output//profiling_outputs
  profile_memory: false
  record_shapes: true
  wait_steps: 5
  warmup_steps: 5
  with_flops: false
  with_stack: false
resume_from_checkpoint: false
save_adapter_weights_only: false
seed: null
shuffle: true
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: null
  path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model

DEBUG:torchtune.utils._logging:Setting manual seed to local seed 2927239804. Local seed is seed + rank = 2927239804 + 0
Writing logs to /tmp/qlora_finetune_output/log_1730554842.txt
INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils._logging:Tokenizer is initialized from file.
INFO:torchtune.utils._logging:Optimizer and loss are initialized.
INFO:torchtune.utils._logging:Loss is initialized.
README.md: 100%|████████████████████████████████████████████████████████████████████████████████████| 11.6k/11.6k [00:00<00:00, 15.1MB/s]
alpaca_data_cleaned.json: 100%|█████████████████████████████████████████████████████████████████████| 44.3M/44.3M [00:00<00:00, 74.8MB/s]
Generating train split: 100%|███████████████████████████████████████████████████████████| 51760/51760 [00:00<00:00, 205775.61 examples/s]
INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
INFO:torchtune.utils._logging:Learning rate scheduler is initialized.
WARNING:torchtune.utils._logging: Profiling disabled.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
  0%|          | 0/1617 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/Users/austinwelch/mambaforge/envs/py312/bin/tune", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torchtune/_cli/tune.py", line 49, in main
    parser.run(args)
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torchtune/_cli/tune.py", line 43, in run
    args.func(args)
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torchtune/_cli/run.py", line 196, in _run_cmd
    self._run_single_device(args, is_builtin=is_builtin)
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torchtune/_cli/run.py", line 102, in _run_single_device
    runpy.run_path(str(args.recipe), run_name="__main__")
  File "<frozen runpy>", line 287, in run_path
  File "<frozen runpy>", line 98, in _run_module_code
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/recipes/lora_finetune_single_device.py", line 793, in <module>
    sys.exit(recipe_main())
             ^^^^^^^^^^^^^
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torchtune/config/_parse.py", line 99, in wrapper
    sys.exit(recipe_main(conf))
             ^^^^^^^^^^^^^^^^^
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/recipes/lora_finetune_single_device.py", line 788, in recipe_main
    recipe.train()
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/recipes/lora_finetune_single_device.py", line 707, in train
    self._optimizer.step()
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torch/optim/lr_scheduler.py", line 137, in wrapper
    return func.__get__(opt, opt.__class__)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torch/optim/optimizer.py", line 487, in wrapper
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torch/optim/optimizer.py", line 91, in _use_grad
    ret = func(self, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torch/optim/adamw.py", line 220, in step
    adamw(
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torch/optim/optimizer.py", line 154, in maybe_fallback
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torch/optim/adamw.py", line 782, in adamw
    func(
  File "/Users/austinwelch/mambaforge/envs/py312/lib/python3.12/site-packages/torch/optim/adamw.py", line 694, in _fused_adamw
    torch._fused_adamw_(
RuntimeError: Only float and half are supported
  0%|          | 0/1617 [02:56<?, ?it/s]
Hey @austinmw! I think the fused AdamW implementation (fused=True) doesn't support bf16 kernels on MPS. Try setting fused=False in your config, or pass optimizer.fused=False through the CLI.
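For reference, going by the resolved config in your trace above, the only section that needs to change is the optimizer block; the rest of the config stays the same:

optimizer:
  _component_: torch.optim.AdamW
  fused: false
  lr: 0.0003
  weight_decay: 0.01

Appending optimizer.fused=False to the end of your existing tune run command should have the same effect without editing the file.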
Thanks!