pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Error when running packed dataset with compile=True on nightlies #1649

Open ebsmothers opened 1 day ago

ebsmothers commented 1 day ago

I ran pretty much this exact command a couple of weeks ago while benchmarking, but it now fails with a stride mismatch error. I'm creating this issue so others can take a look as well. The repro and error message are given below.

tune run lora_finetune_single_device --config llama3/8B_qlora_single_device optimizer=bitsandbytes.optim.AdamW8bit \
~optimizer.fused ~optimizer.weight_decay log_peak_memory_stats=True dataset.packed=True compile=True \
tokenizer.max_seq_len=2048 enable_activation_checkpointing=False
...
assert_size_stride(select_1, (2, 2048, 1, 64), (262144, 128, 128, 2))
AssertionError: expected size 2==2, stride 131072==262144 at dim=0; expected size 2048==2048, stride 64==128 at dim=1; expected size 64==64, stride 1==2 at dim=3
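For anyone reading the assertion above: each pair in the message is actual==expected, so the tensor that arrived has strides (131072, 64, 64, 1) while the compiled code expected (262144, 128, 128, 2). The sketch below is only an illustration of how those two layouts can arise. It assumes (this is not confirmed anywhere in the thread) that the expected layout is a select along a trailing dimension of size 2 of a contiguous parent tensor. The helper names are hypothetical, and it uses plain Python rather than torch.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, in elements, for the given shape."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def select_last_dim_strides(parent_shape):
    """Strides of parent[..., i] for a contiguous parent: drop the last stride."""
    return contiguous_strides(parent_shape)[:-1]


shape = (2, 2048, 1, 64)

# Layout of a plain contiguous tensor of this shape (the "actual" side).
actual = contiguous_strides(shape)

# Assumed origin of the "expected" side: a view into a (2, 2048, 1, 64, 2) parent.
expected = select_last_dim_strides(shape + (2,))

# Dimensions of size 1 are skipped here because the error message does not report dim=2.
mismatched = [
    d for d, size in enumerate(shape)
    if size != 1 and actual[d] != expected[d]
]

print(actual)      # (131072, 64, 64, 1)
print(expected)    # (262144, 128, 128, 2)
print(mismatched)  # [0, 1, 3]
```

Under that assumption the mismatching dimensions are 0, 1, and 3, the same ones the AssertionError lists.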

SalmanMohammadi commented 1 day ago

I'm running into this while trying to reproduce:

INFO:torchtune.utils._logging:Running LoRAFinetuneRecipeSingleDevice with resolved config:

batch_size: 2
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/
  checkpoint_files:
  - model-00001-of-00004.safetensors
  - model-00002-of-00004.safetensors
  - model-00003-of-00004.safetensors
  - model-00004-of-00004.safetensors
  model_type: LLAMA3
  output_dir: /tmp/Meta-Llama-3-8B-Instruct/
  recipe_checkpoint: null
compile: true
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: true
device: cuda
dtype: bf16
enable_activation_checkpointing: false
enable_activation_offloading: true
epochs: 1
gradient_accumulation_steps: 16
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
low_cpu_ram: false
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
max_steps_per_epoch: null
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: /tmp/qlora_finetune_output/
model:
  _component_: torchtune.models.llama3.qlora_llama3_8b
  apply_lora_to_mlp: true
  apply_lora_to_output: false
  lora_alpha: 16
  lora_attn_modules:
  - q_proj
  - v_proj
  - k_proj
  - output_proj
  lora_dropout: 0.0
  lora_rank: 8
optimizer:
  _component_: bitsandbytes.optim.AdamW8bit
  lr: 0.0003
output_dir: /tmp/qlora_finetune_output/
profiler:
  _component_: torchtune.training.setup_torch_profiler
  active_steps: 2
  cpu: true
  cuda: true
  enabled: false
  num_cycles: 1
  output_dir: /tmp/qlora_finetune_output//profiling_outputs
  profile_memory: false
  record_shapes: true
  wait_steps: 5
  warmup_steps: 5
  with_flops: false
  with_stack: false
resume_from_checkpoint: false
save_adapter_weights_only: false
seed: null
shuffle: true
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: 2048
  path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model

DEBUG:torchtune.utils._logging:Setting manual seed to local seed 3803034570. Local seed is seed + rank = 3803034570 + 0
Writing logs to /tmp/qlora_finetune_output/log_1727042869.txt
INFO:torchtune.utils._logging:Compiling model layers with torch.compile...
INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils._logging:Memory stats after model init:
        GPU peak memory allocation: 6.32 GiB
        GPU peak memory reserved: 6.49 GiB
        GPU peak memory active: 6.32 GiB
INFO:torchtune.utils._logging:Tokenizer is initialized from file.
INFO:torchtune.utils._logging:Optimizer and loss are initialized.
INFO:torchtune.utils._logging:Compiling loss with torch.compile...
INFO:torchtune.utils._logging:Loss is initialized.
README.md: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 11.6k/11.6k [00:00<00:00, 39.8MB/s]
alpaca_data_cleaned.json: 100%|███████████████████████████████████████████████████████████████████████████████████| 44.3M/44.3M [00:02<00:00, 15.8MB/s]
Generating train split: 100%|█████████████████████████████████████████████████████████████████████████| 51760/51760 [00:00<00:00, 116158.04 examples/s]
Packing dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 51760/51760 [00:23<00:00, 2176.10it/s]
INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
INFO:torchtune.utils._logging:Learning rate scheduler is initialized.
WARNING:torchtune.utils._logging: Profiling disabled.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
INFO:torchtune.utils._logging:NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration.
  0%|                                                                                                                          | 0/166 [00:00<?, ?it/s]DEBUG:torchtune.utils._logging:Using flex attention for attention computation since a BlockMask was passed in.
Traceback (most recent call last):
  File "/usr/local/bin/tune", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torchtune/_cli/tune.py", line 49, in main
    parser.run(args)
  File "/usr/local/lib/python3.11/dist-packages/torchtune/_cli/tune.py", line 43, in run
    args.func(args)
  File "/usr/local/lib/python3.11/dist-packages/torchtune/_cli/run.py", line 185, in _run_cmd
    self._run_single_device(args)
  File "/usr/local/lib/python3.11/dist-packages/torchtune/_cli/run.py", line 94, in _run_single_device
    runpy.run_path(str(args.recipe), run_name="__main__")
  File "<frozen runpy>", line 291, in run_path
  File "<frozen runpy>", line 98, in _run_module_code
  File "<frozen runpy>", line 88, in _run_code
  File "/usr/local/lib/python3.11/dist-packages/recipes/lora_finetune_single_device.py", line 794, in <module>
    sys.exit(recipe_main())
             ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torchtune/config/_parse.py", line 99, in wrapper
    sys.exit(recipe_main(conf))
             ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/recipes/lora_finetune_single_device.py", line 789, in recipe_main
    recipe.train()
  File "/usr/local/lib/python3.11/dist-packages/recipes/lora_finetune_single_device.py", line 699, in train
    loss.backward()
  File "/usr/local/lib/python3.11/dist-packages/torch/_tensor.py", line 581, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.11/dist-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/usr/local/lib/python3.11/dist-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2048, in backward
    out = call_compiled_backward()
          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1954, in call_compiled_backward
    CompiledFunction.compiled_bw = aot_config.bw_compiler(
                                   ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/backends/common.py", line 51, in _wrapped_bw_compiler
    return disable(disable(bw_compiler)(*args, **kwargs))
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 1507, in bw_compiler
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 480, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 666, in _compile_fx_inner
    compiled_graph = FxGraphCache.load(
                     ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/codecache.py", line 1425, in load
    compiled_graph = compile_fx_fn(
                     ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 575, in codegen_and_compile
    compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 883, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
                  ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/graph.py", line 1948, in compile_to_fn
    return self.compile_to_module().call
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/graph.py", line 1874, in compile_to_module
    return self._compile_to_module()
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/graph.py", line 1902, in _compile_to_module
    mod = PyCodeCache.load_by_key_path(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/codecache.py", line 2949, in load_by_key_path
    mod = _reload_python_module(key, path)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_root/jr/cjrpoapjiov7wytf573hrnb7ex2pzqtm2lhdxvtsrejwhyqb44qg.py", line 837, in <module>
    async_compile.wait(globals())
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/async_compile.py", line 286, in wait
    scope[key] = result.result()
                 ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/codecache.py", line 3417, in result
    self.kernel.precompile()
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/runtime/triton_heuristics.py", line 250, in precompile
    raise e
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/runtime/triton_heuristics.py", line 244, in precompile
    compiled_binary, launcher = self._precompile_config(
                                ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/runtime/triton_heuristics.py", line 453, in _precompile_config
    binary._init_handles()
  File "/usr/local/lib/python3.11/dist-packages/triton/compiler/compiler.py", line 374, in _init_handles
    raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 131074, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
  0%|                                                                                                                          | 0/166 [00:54<?, ?it/s]

This is on a 4090.
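To put the two numbers from the Triton error in perspective, here is a small sketch that parses the message and compares them. The message string is copied from the traceback above; the variable names are just for illustration.

```python
import re

# Text copied from the OutOfResources error in the traceback.
message = (
    "out of resource: shared memory, Required: 131074, "
    "Hardware limit: 101376."
)

match = re.search(r"Required: (\d+), Hardware limit: (\d+)", message)
required, limit = (int(g) for g in match.groups())

excess = required - limit
ratio = required / limit

print(f"required: {required} bytes (128 KiB + {required - 128 * 1024} bytes)")
print(f"limit:    {limit} bytes ({limit // 1024} KiB)")
print(f"excess:   {excess} bytes ({ratio:.2f}x the limit)")
```

The kernel asks for just over 128 KiB of shared memory against a 99 KiB limit, roughly 1.29x what the hardware reports. That fits the hint in the error message that smaller block sizes or fewer `num_stages` may help, though nothing in this thread confirms which setting would fix it.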

ebsmothers commented 13 hours ago

@SalmanMohammadi at least on my end this was user error. Now that we've enabled activation offloading, I should have set both enable_activation_checkpointing=False and enable_activation_offloading=False. But given that you're running into an issue on a 4090, maybe we should leave this open for debugging that.

felipemello1 commented 8 hours ago

@ebsmothers, what do you mean? Was enable_activation_offloading set to True? AFAIK, this shouldn't error.