pytorch-labs / tritonbench

Tritonbench is a collection of PyTorch custom operators with example inputs to measure their performance.
BSD 3-Clause "New" or "Revised" License

layer_norm backward problem #40

Open FindHao opened 1 week ago

FindHao commented 1 week ago

The bwd and fwd_bwd tests for layer_norm failed.

Error string is RuntimeError: This backward function was compiled with non-empty donated buffers which requires create_graph=False and retain_graph=False. Please keep backward(create_graph=False, retain_graph=False) across all backward() function calls, or set torch._functorch.config.donated_buffer=False to disable donated buffer.

Test Plan:

% python run.py --op layer_norm --precision fp32 --metrics latency,accuracy,speedup,gpu_peak_mem,mem_footprint --mode fwd_bwd

  3%|████████▏                                                                                                                                                                                                                                            | 1/30 [00:01<00:55,  1.91s/it]
Caught exception, terminating early with partial results
Traceback (most recent call last):
  File "/home/yhao/ptd/tritonbench/tritonbench/utils/triton_op.py", line 716, in run
    y_vals: Dict[str, BenchmarkOperatorMetrics] = functools.reduce(
                                                  ^^^^^^^^^^^^^^^^^
  File "/home/yhao/ptd/tritonbench/tritonbench/utils/triton_op.py", line 704, in _reduce_benchmarks
    acc[bm_name] = self._do_bench(
                   ^^^^^^^^^^^^^^^
  File "/home/yhao/ptd/tritonbench/tritonbench/utils/triton_op.py", line 923, in _do_bench
    metrics.latency = triton.testing.do_bench(
                      ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yhao/.conda/envs/ptd/lib/python3.11/site-packages/triton/testing.py", line 106, in do_bench
    fn()
  File "/home/yhao/ptd/tritonbench/tritonbench/utils/triton_op.py", line 627, in <lambda>
    fwd_bwd_fn = lambda: (fwd_fn(), bwd_fn())
                                    ^^^^^^^^
  File "/home/yhao/ptd/tritonbench/tritonbench/operators/layer_norm/operator.py", line 50, in <lambda>
    return lambda: y.backward(dy, retain_graph=True)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yhao/.conda/envs/ptd/lib/python3.11/site-packages/torch/_tensor.py", line 624, in backward
    torch.autograd.backward(
  File "/home/yhao/.conda/envs/ptd/lib/python3.11/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/home/yhao/.conda/envs/ptd/lib/python3.11/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yhao/.conda/envs/ptd/lib/python3.11/site-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/yhao/.conda/envs/ptd/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1705, in backward
    return impl_fn()
           ^^^^^^^^^
  File "/home/yhao/.conda/envs/ptd/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1695, in impl_fn
    out = CompiledFunction._backward_impl(ctx, all_args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yhao/.conda/envs/ptd/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2022, in _backward_impl
    torch._check(
  File "/home/yhao/.conda/envs/ptd/lib/python3.11/site-packages/torch/__init__.py", line 1615, in _check
    _check_with(RuntimeError, cond, message)
  File "/home/yhao/.conda/envs/ptd/lib/python3.11/site-packages/torch/__init__.py", line 1597, in _check_with
    raise error_type(message_evaluated)
RuntimeError: This backward function was compiled with non-empty donated buffers which requires create_graph=False and retain_graph=False. Please keep backward(create_graph=False, retain_graph=False) across all backward() function calls, or set torch._functorch.config.donated_buffer=False to disable donated buffer.
FindHao commented 1 week ago

Not sure if setting torch._functorch.config.donated_buffer=False is the correct way to solve it.

xuzhao9 commented 21 hours ago

This error only happens when the number of inputs (--num_input) is greater than 1.
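An alternative sketch that avoids touching the config: rebuild the graph inside the backward closure so `retain_graph` can stay at its default of `False`, which is what the compiled backward requires. This assumes the benchmark can tolerate re-running the forward on each backward call; the closure shape loosely mirrors the `bwd_fn` lambda in the traceback, but everything below is hypothetical, not the operator's actual code.

```python
import torch

def f(x):
    # Stand-in for the compiled layer_norm forward.
    return x * x

compiled = torch.compile(f, backend="aot_eager")  # AOT Autograd without Inductor

x = torch.randn(4, requires_grad=True)
dy = torch.ones(4)

def make_bwd_fn():
    def bwd_fn():
        # Re-run the forward so each backward() has a fresh graph and
        # backward(create_graph=False, retain_graph=False) holds.
        y = compiled(x)
        y.backward(dy)
    return bwd_fn

bwd_fn = make_bwd_fn()
bwd_fn()
bwd_fn()  # safe to call repeatedly: no retained graph is needed
```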