pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org
Other
83.97k stars 22.63k forks source link

compiled LSTM performance is worse than not compiled #140845

Open mstebelev opened 5 hours ago

mstebelev commented 5 hours ago

🐛 Describe the bug

I tried to export and compile an LSTM model, and its performance after compilation is much worse than the uncompiled (cuDNN) version, both in total kernel time and in the number of operations.

import os

import torch

class LSTM(torch.nn.LSTM):
    def _update_flat_weights(self):
        return

    @torch.compiler.disable
    def forward(self, *args, **kwargs):
        return super().forward(*args, **kwargs)

class M(torch.nn.Module):
    """Minimal model wrapping a single batch-first ``LSTM`` (64 -> 64)."""

    def __init__(self):
        super().__init__()
        self.lstm = LSTM(
            input_size=64,
            hidden_size=64,
            batch_first=True,
        )

    def forward(self, x):
        """Run the wrapped LSTM on ``x`` and return its output unchanged."""
        lstm_out = self.lstm(x)
        return lstm_out

m = M().cuda()
inputs = (torch.randn(32, 128, 64).cuda(),)
exported_program = torch.export.export(
    m,
    inputs,
    strict=False,
    # Uncomment to make dim 1 (the sequence length, since the LSTM is
    # batch_first) dynamic; per this report, the compiled LSTM then fails.
    # dynamic_shapes=({1: torch.export.Dim('history')},)
)
print(exported_program)
exported_m = exported_program.module()

# Output location for the Chrome traces. Created up front so that writing the
# trace files does not fail when the directory does not exist yet.
TRACE_DIR = '/tmp/test_compiled'
os.makedirs(TRACE_DIR, exist_ok=True)

profiler_kwargs = {
    'activities': [torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
    'record_shapes': False,
    'profile_memory': False,
    'with_stack': True,
    'with_flops': False,
    'with_modules': True,
}


def _profile_to_trace(module, args, trace_path):
    """Warm up ``module`` with one call, then profile a second call.

    The un-profiled warm-up keeps one-time costs (e.g. the lazy compilation
    that happens on the first call after ``.compile()``) out of the trace.

    Args:
        module: Callable to profile.
        args: Tuple of positional arguments passed to ``module`` on each call.
        trace_path: File the Chrome trace JSON is written to.
    """
    module(*args)
    profiler = torch.profiler.profile(**profiler_kwargs)
    with profiler:
        module(*args)
    profiler.export_chrome_trace(trace_path)


# Baseline: exported but not compiled.
_profile_to_trace(exported_m, inputs, os.path.join(TRACE_DIR, 'trace_exported.json'))

# Compile the exported module in place, then profile it again.
exported_m.compile()
_profile_to_trace(exported_m, inputs, os.path.join(TRACE_DIR, 'trace_compiled.json'))

The printed program is:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_lstm_weight_ih_l0: "f32[256, 64]", p_lstm_weight_hh_l0: "f32[256, 64]", p_lstm_bias_ih_l0: "f32[256]", p_lstm_bias_hh_l0: "f32[256]", c_lstm_lifted_tensor_0: "f32[256, 64]", c_lstm_lifted_tensor_1: "f32[256, 64]", c_lstm_lifted_tensor_2: "f32[256]", c_lstm_lifted_tensor_3: "f32[256]", x: "f32[32, 128, 64]"):
             # File: [/home/vscode/.cache/bazel/_bazel_vscode/93fd2cd9b3c5d87ae416561bff883334/execroot/__main__/bazel-out/k8-opt/bin/prediction/e.jupyter.runfiles/pytorch/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1085](https://vscode-remote+ssh-002dremote-002bvscode-0040ws-002dmstebelev-002dmstebelev-002dtrain-002ecoder-002epods-002emax-002eavride-002eai.vscode-resource.vscode-cdn.net/home/vscode/.cache/bazel/_bazel_vscode/93fd2cd9b3c5d87ae416561bff883334/execroot/__main__/bazel-out/k8-opt/bin/prediction/e.jupyter.runfiles/pytorch/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1085) in forward, code: h_zeros = torch.zeros(
            zeros: "f32[1, 32, 64]" = torch.ops.aten.zeros.default([1, 32, 64], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False)

             # File: [/home/vscode/.cache/bazel/_bazel_vscode/93fd2cd9b3c5d87ae416561bff883334/execroot/__main__/bazel-out/k8-opt/bin/prediction/e.jupyter.runfiles/pytorch/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1092](https://vscode-remote+ssh-002dremote-002bvscode-0040ws-002dmstebelev-002dmstebelev-002dtrain-002ecoder-002epods-002emax-002eavride-002eai.vscode-resource.vscode-cdn.net/home/vscode/.cache/bazel/_bazel_vscode/93fd2cd9b3c5d87ae416561bff883334/execroot/__main__/bazel-out/k8-opt/bin/prediction/e.jupyter.runfiles/pytorch/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1092) in forward, code: c_zeros = torch.zeros(
            zeros_1: "f32[1, 32, 64]" = torch.ops.aten.zeros.default([1, 32, 64], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False)

             # File: [/home/vscode/.cache/bazel/_bazel_vscode/93fd2cd9b3c5d87ae416561bff883334/execroot/__main__/bazel-out/k8-opt/bin/prediction/e.jupyter.runfiles/pytorch/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123](https://vscode-remote+ssh-002dremote-002bvscode-0040ws-002dmstebelev-002dmstebelev-002dtrain-002ecoder-002epods-002emax-002eavride-002eai.vscode-resource.vscode-cdn.net/home/vscode/.cache/bazel/_bazel_vscode/93fd2cd9b3c5d87ae416561bff883334/execroot/__main__/bazel-out/k8-opt/bin/prediction/e.jupyter.runfiles/pytorch/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123) in forward, code: result = _VF.lstm(
            lstm = torch.ops.aten.lstm.input(x, [zeros, zeros_1], [c_lstm_lifted_tensor_0, c_lstm_lifted_tensor_1, c_lstm_lifted_tensor_2, c_lstm_lifted_tensor_3], True, 1, 0.0, True, False, True);  x = zeros = zeros_1 = c_lstm_lifted_tensor_0 = c_lstm_lifted_tensor_1 = c_lstm_lifted_tensor_2 = c_lstm_lifted_tensor_3 = None
            getitem: "f32[32, 128, 64]" = lstm[0]
            getitem_1: "f32[1, 1, 32, 64]" = lstm[1]
            getitem_2: "f32[1, 1, 32, 64]" = lstm[2];  lstm = None
            return (getitem, getitem_1, getitem_2)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_lstm_weight_ih_l0'), target='lstm.weight_ih_l0', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_lstm_weight_hh_l0'), target='lstm.weight_hh_l0', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_lstm_bias_ih_l0'), target='lstm.bias_ih_l0', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_lstm_bias_hh_l0'), target='lstm.bias_hh_l0', persistent=None), InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='c_lstm_lifted_tensor_0'), target='lstm.lifted_tensor_0', persistent=None), InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='c_lstm_lifted_tensor_1'), target='lstm.lifted_tensor_1', persistent=None), InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='c_lstm_lifted_tensor_2'), target='lstm.lifted_tensor_2', persistent=None), InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='c_lstm_lifted_tensor_3'), target='lstm.lifted_tensor_3', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem_1'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem_2'), target=None)])
Range constraints: {}

The problems:

  1. It does some work with data pointers; that's why I overrode `_update_flat_weights` in my custom subclass.
  2. After the fix above it works, but it looks like compilation unrolls the LSTM using this implementation, which is why it doesn't work with dynamic shapes.
  3. With this implementation its performance is much worse: it launches many kernels instead of the few cuDNN kernels used before calling `exported_m.compile()`. I'm attaching screenshots and traces.

My question is: is there any way to fall back to the cuDNN implementation for the LSTM after calling `.compile()`, while still compiling the other modules in the model with Triton?

Exported compiled model: trace_compiled.json

Screenshot 2024-11-15 at 20 31 02

Exported, but not compiled model: trace_exported.json

Screenshot 2024-11-15 at 20 33 36

Error logs

No response

Versions

Unfortunately, the environment-collection script failed (traceback below), but my torch version is 2.5.1.

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 24366  100 24366    0     0  75349      0 --:--:-- --:--:-- --:--:-- 75436
Collecting environment information...
Traceback (most recent call last):
  File "/home/vscode/.cache/bazel/_bazel_vscode/93fd2cd9b3c5d87ae416561bff883334/execroot/__main__/bazel-out/k8-opt/bin/prediction/e.jupyter.runfiles/__main__/collect_env.py", line 693, in <module>
    main()
  File "/home/vscode/.cache/bazel/_bazel_vscode/93fd2cd9b3c5d87ae416561bff883334/execroot/__main__/bazel-out/k8-opt/bin/prediction/e.jupyter.runfiles/__main__/collect_env.py", line 676, in main
    output = get_pretty_env_info()
  File "/home/vscode/.cache/bazel/_bazel_vscode/93fd2cd9b3c5d87ae416561bff883334/execroot/__main__/bazel-out/k8-opt/bin/prediction/e.jupyter.runfiles/__main__/collect_env.py", line 671, in get_pretty_env_info
    return pretty_str(get_env_info())
  File "/home/vscode/.cache/bazel/_bazel_vscode/93fd2cd9b3c5d87ae416561bff883334/execroot/__main__/bazel-out/k8-opt/bin/prediction/e.jupyter.runfiles/__main__/collect_env.py", line 496, in get_env_info
    pip_version, pip_list_output = get_pip_packages(run_lambda)
  File "/home/vscode/.cache/bazel/_bazel_vscode/93fd2cd9b3c5d87ae416561bff883334/execroot/__main__/bazel-out/k8-opt/bin/prediction/e.jupyter.runfiles/__main__/collect_env.py", line 453, in get_pip_packages
    out = run_with_pip([sys.executable, '-mpip'])
  File "/home/vscode/.cache/bazel/_bazel_vscode/93fd2cd9b3c5d87ae416561bff883334/execroot/__main__/bazel-out/k8-opt/bin/prediction/e.jupyter.runfiles/__main__/collect_env.py", line 448, in run_with_pip
    for line in out.splitlines()
AttributeError: 'NoneType' object has no attribute 'splitlines'

cc @ezyang @chauhang @penguinwu

mstebelev commented 4 hours ago

Related issues: https://github.com/pytorch/pytorch/issues/91439 https://github.com/pytorch/pytorch/issues/115092