Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

FAILED thunder/benchmarks/targets.py::test_nanogpt_cross_entropy[forward-thunder] - KeyError: 't12' #1013

Closed: wujingyue closed this issue 5 days ago

wujingyue commented 3 weeks ago

Looks like a bug in thunder/core/rematerialization.py.

$ git rev-parse HEAD
c92e8a895d3f5e1df3b3731e08171d66ab13634f
$ pytest thunder/benchmarks/targets.py -k test_nanogpt_cross_entropy[forward-thunder]
=============================== test session starts ===============================
platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.5.0
Test order randomisation NOT enabled. Enable with --random-order or --random-order-bucket=<bucket_type>
benchmark: 4.0.0 (defaults: timer=torch.utils.benchmark.utils.timer.timer disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=True warmup_iterations=100000)
rootdir: /opt/pytorch/lightning-thunder
configfile: pyproject.toml
plugins: timeout-2.3.1, xdist-3.6.1, cov-5.0.0, random-order-1.1.1, shard-0.1.2, benchmark-4.0.0, hypothesis-6.104.2, timestamper-0.0.10, typeguard-4.3.0
timeout: 900.0s
timeout method: signal
timeout func_only: False
collected 737 items / 736 deselected / 1 selected
Running 1 items in this shard

thunder/benchmarks/targets.py F                                              [100%]

==================================== FAILURES =====================================
_____________________ test_nanogpt_cross_entropy[forward-thunder] _____________________

benchmark = <pytest_benchmark.fixture.BenchmarkFixture object at 0x7f77c1640ca0>, executor = <function default_thunder_dynamic_strides_executor at 0x7f77c19ab6d0>, compute_type = <ComputeType.TRAINING_FORWARD: 2>

    @pytest.mark.parametrize(
        "executor,",
        (executors + apex_executors),
        ids=(executors_ids + apex_executors),
    )
    @parametrize_compute_type
    def test_nanogpt_cross_entropy(benchmark, executor: None | Callable, compute_type: ComputeType):
        if executor is None:
            pytest.skip("Executor is unavailable")

        bench: Benchmark = NanoGPTCrossEntropyBenchmark(
            config="gpt2-xl", device="cuda:0", dtype=thunder.bfloat16, requires_grad=is_requires_grad(compute_type)
        )

        args, kwargs = bench.make_batch()
        fn = executor(bench.fn())

>       benchmark_for_compute_type(compute_type, benchmark, fn, args, kwargs)

thunder/benchmarks/targets.py:274:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
thunder/benchmarks/targets.py:131: in benchmark_for_compute_type
    benchmark(fn, *args, **kwargs)
/usr/local/lib/python3.10/dist-packages/pytest_benchmark/fixture.py:125: in __call__
    return self._raw(function_to_benchmark, *args, **kwargs)
/usr/local/lib/python3.10/dist-packages/pytest_benchmark/fixture.py:147: in _raw
    duration, iterations, loops_range = self._calibrate_timer(runner)
/usr/local/lib/python3.10/dist-packages/pytest_benchmark/fixture.py:275: in _calibrate_timer
    duration = runner(loops_range)
/usr/local/lib/python3.10/dist-packages/pytest_benchmark/fixture.py:90: in runner
    function_to_benchmark(*args, **kwargs)
thunder/__init__.py:744: in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
thunder/core/langctxs.py:136: in _fn
    result = fn(*args, **kwargs)
thunder/__init__.py:229: in cache_info_wrapper
    res = fn(*args, **kwargs)
thunder/__init__.py:660: in get_computation_and_inputs
    computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
thunder/executors/torch_autograd.py:233: in split_forward_backward
    fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace)
thunder/core/rematerialization.py:602: in rematerialize_forward_and_backward
    joint_extrace = rematerialize(joint_extrace)
thunder/core/rematerialization.py:551: in rematerialize
    updated_consumer = apply_rematerialization_for_consumer(current_producer, current_consumer, cut)
thunder/core/rematerialization.py:185: in apply_rematerialization_for_consumer
    new_consumer_args = tuple(sorted(new_consumer_args, key=lambda x: proxy_order[x.name]))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

x = <TensorProxy(name="t12", dtype=thunder.dtypes.float32, shape=(2048, 1))>

>   new_consumer_args = tuple(sorted(new_consumer_args, key=lambda x: proxy_order[x.name]))
E   KeyError: 't12'

thunder/core/rematerialization.py:185: KeyError
============================= short test summary info =============================
FAILED thunder/benchmarks/targets.py::test_nanogpt_cross_entropy[forward-thunder] - KeyError: 't12'
================== 1 failed, 736 deselected, 6 warnings in 4.52s ==================
wujingyue commented 2 weeks ago

FYI, this is still happening. @IvanYashchuk and @t-vi

IvanYashchuk commented 2 weeks ago

The bug was introduced in https://github.com/Lightning-AI/lightning-thunder/pull/899. After reverting the changes to thunder/core/rematerialization.py, the benchmark works.

@riccardofelluga has hit the same problem in his work recently.

riccardofelluga commented 2 weeks ago

Yes, it looks like this logic fails when one of the new_consumer_args comes from the producer's args, and therefore it doesn't appear in new_subsymbols.
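
For reference, here is a minimal, self-contained sketch of that failure mode. The Proxy class and the proxy_order / new_consumer_args values below are simplified stand-ins rather than Thunder's real proxy and trace objects; only the sorting pattern from apply_rematerialization_for_consumer (rematerialization.py:185) is mirrored, together with one possible fallback ordering (hypothetical, not a proposed patch):

# Minimal sketch of the failure mode described above; names are simplified
# stand-ins, not Thunder's actual classes.
from collections import namedtuple

Proxy = namedtuple("Proxy", ["name"])

# Outputs produced by the consumer's new_subsymbols; proxy_order is keyed
# on these names only.
new_subsymbol_outputs = [Proxy("t13"), Proxy("t14"), Proxy("t15")]
proxy_order = {p.name: i for i, p in enumerate(new_subsymbol_outputs)}

# One of the new consumer args ("t12") comes from the producer's args, so it
# was never added to proxy_order.
new_consumer_args = [Proxy("t14"), Proxy("t12"), Proxy("t13")]

try:
    # Same pattern as the failing line: sort args by their position in proxy_order.
    tuple(sorted(new_consumer_args, key=lambda x: proxy_order[x.name]))
except KeyError as e:
    print(f"KeyError: {e}")  # -> KeyError: 't12'

# Hypothetical mitigation: give args not produced by new_subsymbols a stable
# fallback position, e.g. sort them first.
fixed = tuple(sorted(new_consumer_args, key=lambda x: proxy_order.get(x.name, -1)))
print([p.name for p in fixed])  # -> ['t12', 't13', 't14']

The proxy_order.get(x.name, -1) fallback simply keeps args the consumer does not recompute ahead of the recomputed ones; whether that ordering is what the rematerialization pass actually needs would have to be checked against the changes in #899.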