pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

[bug] Compile error on stable PyTorch #1498

Closed · ebsmothers closed 1 week ago

ebsmothers commented 1 week ago

Repro:

tune run full_finetune_single_device --config llama2/7B_full_low_memory \
optimizer=torch.optim.AdamW optimizer.fused=True optimizer_in_bwd=False \
compile=True metric_logger=torchtune.training.metric_logging.WandBLogger \
log_peak_memory_stats=True batch_size=16 log_every_n_steps=10 epochs=1 seed=2024

results in the following error:

  File "/home/ebs/.conda/envs/repro-compile-error/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1822, in validate
    raise AssertionError(
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.clone.default(tensor([...], size=(16,), dtype=torch.uint8), memory_format=torch.contiguous_format)

Full stack trace

ebsmothers commented 1 week ago

cc @gau-nernst @felipemello1

felipemello1 commented 1 week ago

@yf225 have you seen this error before? It doesn't happen with torch nightlies, or in CI where we use backend="aot_eager".

We are compiling it per layer here: https://github.com/pytorch/torchtune/blob/82c232d0679ddef3fc419cdc18af758b98b4da05/recipes/full_finetune_single_device.py#L364
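
For reference, a rough sketch of the per-layer compile pattern (illustrative names, not the exact code at the linked line):

```python
import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    # Stand-in for the real TransformerDecoder, just to keep the sketch runnable.
    def __init__(self, dim: int = 64, n_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x

model = TinyDecoder()
# Compile each transformer block individually instead of the whole model,
# which keeps each compiled graph small. This is the pattern referenced above.
for idx, layer in enumerate(model.layers):
    model.layers[idx] = torch.compile(layer, backend="inductor")
```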

yf225 commented 1 week ago

I wonder, is it on a specific PyTorch version?

I tried to reproduce it on the latest torchtune main and PyTorch main, but couldn't hit this error 🤔

felipemello1 commented 1 week ago

This is PyTorch 2.4. Sorry I didn't make that clear, @yf225.

yf225 commented 1 week ago

Hmm, I wonder whether it would be okay to require running on PyTorch nightly? I might need to look into this, but my worry is that if there is a fix we won't be able to retroactively add it to the PyTorch 2.4 release 😞

gau-nernst commented 1 week ago

Perhaps we can investigate whether the old ways of doing compile (compiling the whole model via model.compile(), and compiling the loss step via torch.compile(_loss_step)) work for PyTorch 2.4? There is a rough sketch of both after this comment. For the compiled loss step, I think last time I also only tested it with torch nightly...

Not supporting the latest stable PyTorch seems like a big deal. At least in my experience, beyond use cases that require stability, stable versions are also needed for reproducible experiments, since specific nightly builds eventually disappear.
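
For concreteness, here is a minimal sketch of those two older compile strategies; the model and loss names below are placeholders, not torchtune's actual symbols:

```python
import torch
import torch.nn as nn

model = nn.Linear(128, 32000)   # stand-in for the real TransformerDecoder
loss_fn = nn.CrossEntropyLoss()

# (1) Compile the whole model in place.
model.compile()

# (2) Compile a single loss step so forward + loss are captured in one graph.
def _loss_step(tokens: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    logits = model(tokens)
    return loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))

compiled_loss_step = torch.compile(_loss_step)
```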

ebsmothers commented 1 week ago

Yeah, we can definitely just version-gate if the new way we're compiling breaks things on 2.4. It's a bit of a UX hit, but I agree that we always want to at least support the latest stable version. Also, we missed this because the compile backend for our tests is aot_eager (ref). Big thanks to @gau-nernst for catching both of these issues; I am (slowly) chipping away at debugging the CI coverage one in #1508.
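
A minimal sketch of what that version gate could look like, assuming per-layer compile is only enabled on builds newer than 2.4 and whole-model compile is kept as the stable fallback (helper names are illustrative, not the actual recipe code):

```python
import torch
import torch.nn as nn

def _torch_version_at_least(major: int, minor: int) -> bool:
    # Works for stable ("2.4.1") and nightly ("2.5.0.dev20240901+cu121") strings.
    ver = tuple(int(x) for x in torch.__version__.split(".")[:2])
    return ver >= (major, minor)

def compile_model(model: nn.Module) -> nn.Module:
    # Assumes `model.layers` is an nn.ModuleList of transformer blocks.
    if _torch_version_at_least(2, 5):
        # Per-layer compile, only enabled on builds newer than stable 2.4.
        for idx, layer in enumerate(model.layers):
            model.layers[idx] = torch.compile(layer)
    else:
        # Whole-model compile as the fallback for the latest stable release.
        model.compile()
    return model
```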