Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

NotImplementedError: aten::nonzero: attempted to run this operator with Meta tensors #673

Closed · wprazuch closed this issue 3 hours ago

wprazuch commented 6 days ago

🐛 Bug

An unsupported-operator error is raised when running models:

For the thunder_inductor_cat_cudnn and thunder_cudnn compile options, both with FSDP ZeRO-3, running on 8 nodes with 8 GPUs each.

To Reproduce

Steps to reproduce the behavior:

mkdir -p output
docker run --pull=always --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864  -v $PWD/output:/output -it INTERNAL_IMAGE:pjnl-20240621

Run in the container:

torchrun --nproc-per-node=8 /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py --model_name Mixtral-8x7B-v0.1 --compile thunder_inductor_cat_cudnn --distributed_mode fsdp --shard_mode zero3 

Expected behavior

The model should run, or we should get an OOM error.

Environment

As in the Docker image

Additional context

Traceback:

An error occurred: NotImplementedError – aten::nonzero: attempted to run this operator with Meta tensors, but there was no fake impl or Meta kernel registered. You may have run into this message while using an operator with PT2 compilation APIs (torch.compile/torch.export); in order to use this operator with those APIs you'll need to add a fake impl. Please see the following for next steps:  https://pytorch.org/docs/main/notes/custom_operators.html
[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 639, in <module>
[rank0]:     CLI(benchmark_main)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/jsonargparse/_cli.py", line 96, in CLI
[rank0]:     return _run_component(components, cfg_init)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/jsonargparse/_cli.py", line 196, in _run_component
[rank0]:     return component(**cfg)
[rank0]:   File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 584, in benchmark_main
[rank0]:     benchmark.train()
[rank0]:   File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 448, in train
[rank0]:     self.calculate_model_flops()
[rank0]:   File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 440, in calculate_model_flops
[rank0]:     self.perf_metrics["model_flops"] = measure_flops(meta_model, model_fwd, model_loss)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/fabric/utilities/throughput.py", line 304, in measure_flops
[rank0]:     loss_fn(forward_fn()).backward()
[rank0]:   File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 436, in <lambda>
[rank0]:     model_fwd = lambda: meta_model(x)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1657, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1722, in _call_impl
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/litgpt/model.py", line 94, in forward
[rank0]:     x = block(x, cos, sin, mask, input_pos)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1657, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1722, in _call_impl
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/litgpt/model.py", line 187, in forward
[rank0]:     x = self.mlp(self.norm_2(x)) + x
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1657, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1722, in _call_impl
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/litgpt/model.py", line 347, in forward
[rank0]:     token_idx, expert_idx = torch.where(mask)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/flop_counter.py", line 693, in __torch_dispatch__
[rank0]:     out = func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 670, in __call__
[rank0]:     return self_._op(*args, **kwargs)
[rank0]: NotImplementedError: aten::nonzero: attempted to run this operator with Meta tensors, but there was no fake impl or Meta kernel registered. You may have run into this message while using an operator with PT2 compilation APIs (torch.compile/torch.export); in order to use this operator with those APIs you'll need to add a fake impl. Please see the following for next steps:  https://pytorch.org/docs/main/notes/custom_operators.html

cc @crcrpar

IvanYashchuk commented 5 days ago

~~Mixture-of-Experts models, and Mixtral in particular, are not currently supported. There's a tracking issue: https://github.com/Lightning-AI/lightning-thunder/issues/194.~~ It's a problem with the benchmark script itself; see the next comment.

IvanYashchuk commented 5 days ago

measure_flops(meta_model, model_fwd, model_loss) inside benchmarks/benchmark_litgpt.py uses meta tensors, and PyTorch rightfully errors out in this case (the call was added in https://github.com/Lightning-AI/lightning-thunder/commit/348597fd045903aa232b6811bf6bffa392edbd65).
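
For context, the failure reproduces without Thunder or the benchmark at all. Below is a minimal sketch (assuming only plain PyTorch) showing why the meta-tensor FLOPs pass cannot work here: torch.where(condition) with a single argument dispatches to aten::nonzero, whose output shape depends on the tensor's values, so no meta kernel can exist for it.

import torch

# mask mirrors the MoE routing mask used in litgpt/model.py (torch.where(mask))
mask = torch.zeros(8, 2, dtype=torch.bool, device="meta")
try:
    # torch.where(condition) lowers to aten::nonzero; its output shape is
    # data-dependent, so it cannot be computed on value-less meta tensors
    token_idx, expert_idx = torch.where(mask)
except NotImplementedError as e:
    print(f"Fails on the meta device: {e}")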

How useful is this approach to computing model flops? https://github.com/Lightning-AI/lightning-thunder/blob/72e033a0e0dfe44d4770dec2399a9058971003ec/thunder/benchmarks/benchmark_litgpt.py#L387

Can we remove it from the benchmark script? Can we skip it for unsupported models? @parthmannan, do you have opinions here?

tfogal commented 1 day ago

triage review:

parthmannan commented 20 hours ago

Yes, we can make it optional and enable it only via a script argument. I'll make those changes and submit a PR. cc @carmocca: Carlos, do you think we should file a bug in Lightning, since the error comes from the throughput measurement code in Lightning Fabric?
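
A minimal sketch of what "optional, skip-on-unsupported" could look like; the maybe_measure_flops helper and the enabled flag are hypothetical illustrations, not the actual PR:

import warnings

def maybe_measure_flops(benchmark, enabled: bool = False):
    # Hypothetical helper: FLOPs measurement is off by default and only
    # runs when explicitly requested via a script argument.
    if not enabled:
        return None
    try:
        return benchmark.calculate_model_flops()
    except NotImplementedError as e:
        # Ops without meta kernels (e.g. aten::nonzero in MoE routing)
        # make the meta-tensor pass impossible; skip instead of crashing.
        warnings.warn(f"Skipping model-FLOPs measurement: {e}")
        return None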