pytorch / audio

Data manipulation and transformation for audio signal processing, powered by PyTorch
https://pytorch.org/audio
BSD 2-Clause "Simplified" License

`Spectrogram` crash with `torch.compile` #3601

Closed tarepan closed 1 year ago

tarepan commented 1 year ago

🐛 Describe the bug

`torchaudio.transforms.Spectrogram` raises an error in its forward pass when the module is wrapped with `torch.compile`.
The failure reproduces on both CPU and GPU (NVIDIA T4).

import torch, torchaudio

ipt = torch.tensor([1. for _ in range(1600)])
spec = torchaudio.transforms.Spectrogram()
spec_compiled = torch.compile(spec)

print(         spec(ipt).size()) # torch.Size([201, 9])
print(spec_compiled(ipt).size())
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py](https://localhost:8080/#) in call_user_compiler(self, gm)
    669             else:
--> 670                 compiled_fn = compiler_fn(gm, self.fake_example_inputs())
    671             _step_logger()(logging.INFO, f"done compiler function {name}")

46 frames
[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/debug_utils.py](https://localhost:8080/#) in debug_wrapper(gm, example_inputs, **kwargs)
   1054         else:
-> 1055             compiled_gm = compiler_fn(gm, example_inputs)
   1056 

[/usr/local/lib/python3.10/dist-packages/torch/__init__.py](https://localhost:8080/#) in __call__(self, model_, inputs_)
   1389 
-> 1390         return compile_fx(model_, inputs_, config_patches=self.config)
   1391 

[/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in compile_fx(model_, example_inputs_, inner_compile, config_patches)
    454         # once torchdynamo is merged into pytorch
--> 455         return aot_autograd(
    456             fw_compiler=fw_compiler,

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/backends/common.py](https://localhost:8080/#) in compiler_fn(gm, example_inputs)
     47             with enable_aot_logging():
---> 48                 cg = aot_module_simplified(gm, example_inputs, **kwargs)
     49                 counters["aot_autograd"]["ok"] += 1

[/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py](https://localhost:8080/#) in aot_module_simplified(mod, args, fw_compiler, bw_compiler, partition_fn, decompositions, hasher_type, static_argnums, keep_inference_input_mutations)
   2821 
-> 2822     compiled_fn = create_aot_dispatcher_function(
   2823         functional_call,

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py](https://localhost:8080/#) in time_wrapper(*args, **kwargs)
    162             t0 = time.time()
--> 163             r = func(*args, **kwargs)
    164             time_spent = time.time() - t0

[/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py](https://localhost:8080/#) in create_aot_dispatcher_function(flat_fn, flat_args, aot_config)
   2514 
-> 2515         compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
   2516 

[/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py](https://localhost:8080/#) in aot_wrapper_dedupe(flat_fn, flat_args, aot_config, compiler_fn)
   1714         if ok:
-> 1715             return compiler_fn(flat_fn, leaf_flat_args, aot_config)
   1716 

[/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py](https://localhost:8080/#) in aot_dispatch_base(flat_fn, flat_args, aot_config)
   1327     with context(), track_graph_compiling(aot_config, "inference"):
-> 1328         compiled_fw = aot_config.fw_compiler(fw_module, flat_args_with_views_handled)
   1329 

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py](https://localhost:8080/#) in time_wrapper(*args, **kwargs)
    162             t0 = time.time()
--> 163             r = func(*args, **kwargs)
    164             time_spent = time.time() - t0

[/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in fw_compiler(model, example_inputs)
    429         model = convert_outplace_to_inplace(model)
--> 430         return inner_compile(
    431             model,

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/debug_utils.py](https://localhost:8080/#) in debug_wrapper(gm, example_inputs, **kwargs)
    594         else:
--> 595             compiled_fn = compiler_fn(gm, example_inputs)
    596 

[/usr/local/lib/python3.10/dist-packages/torch/_inductor/debug.py](https://localhost:8080/#) in inner(*args, **kwargs)
    238             with DebugContext():
--> 239                 return fn(*args, **kwargs)
    240 

[/usr/lib/python3.10/contextlib.py](https://localhost:8080/#) in inner(*args, **kwds)
     78             with self._recreate_cm():
---> 79                 return func(*args, **kwds)
     80         return inner

[/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in compile_fx_inner(gm, example_inputs, cudagraphs, num_fixed, is_backward, graph_id)
    176             graph.run(*example_inputs)
--> 177             compiled_fn = graph.compile_to_fn()
    178 

[/usr/local/lib/python3.10/dist-packages/torch/_inductor/graph.py](https://localhost:8080/#) in compile_to_fn(self)
    585     def compile_to_fn(self):
--> 586         return self.compile_to_module().call
    587 

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py](https://localhost:8080/#) in time_wrapper(*args, **kwargs)
    162             t0 = time.time()
--> 163             r = func(*args, **kwargs)
    164             time_spent = time.time() - t0

[/usr/local/lib/python3.10/dist-packages/torch/_inductor/graph.py](https://localhost:8080/#) in compile_to_module(self)
    570 
--> 571         code = self.codegen()
    572         if config.debug:

[/usr/local/lib/python3.10/dist-packages/torch/_inductor/graph.py](https://localhost:8080/#) in codegen(self)
    521         assert self.scheduler is not None  # mypy can't figure this out
--> 522         self.scheduler.codegen()
    523         assert self.wrapper_code is not None

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py](https://localhost:8080/#) in time_wrapper(*args, **kwargs)
    162             t0 = time.time()
--> 163             r = func(*args, **kwargs)
    164             time_spent = time.time() - t0

[/usr/local/lib/python3.10/dist-packages/torch/_inductor/scheduler.py](https://localhost:8080/#) in codegen(self)
   1176 
-> 1177         self.flush()

[/usr/local/lib/python3.10/dist-packages/torch/_inductor/scheduler.py](https://localhost:8080/#) in flush(self)
   1094         for backend in self.backends.values():
-> 1095             backend.flush()
   1096         self.free_buffers()

[/usr/local/lib/python3.10/dist-packages/torch/_inductor/codegen/cpp.py](https://localhost:8080/#) in flush(self)
   1974     def flush(self):
-> 1975         self.kernel_group.codegen_define_and_call(V.graph.wrapper_code)
   1976         self.get_kernel_group()

[/usr/local/lib/python3.10/dist-packages/torch/_inductor/codegen/cpp.py](https://localhost:8080/#) in codegen_define_and_call(self, wrapper)
   2003         kernel_name = "kernel_cpp_" + wrapper.next_kernel_suffix()
-> 2004         arg_defs, call_args, arg_types = self.args.cpp_argdefs()
   2005         arg_defs = ",\n".ljust(25).join(arg_defs)

[/usr/local/lib/python3.10/dist-packages/torch/_inductor/codegen/common.py](https://localhost:8080/#) in cpp_argdefs(self)
    321             dtype = buffer_types[outer]
--> 322             cpp_dtype = DTYPE_TO_CPP[dtype]
    323             arg_defs.append(f"{cpp_dtype}* __restrict__ {inner}")

KeyError: torch.complex64

The above exception was the direct cause of the following exception:

BackendCompilerFailed                     Traceback (most recent call last)
[<ipython-input-2-edc57f795be6>](https://localhost:8080/#) in <cell line: 8>()
      6 
      7 print(         spec(ipt).size()) # torch.Size([201, 9])
----> 8 print(spec_compiled(ipt).size())

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py](https://localhost:8080/#) in forward(self, *args, **kwargs)
     80 
     81     def forward(self, *args, **kwargs):
---> 82         return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
     83 
     84 

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py](https://localhost:8080/#) in _fn(*args, **kwargs)
    207             dynamic_ctx.__enter__()
    208             try:
--> 209                 return fn(*args, **kwargs)
    210             finally:
    211                 set_eval_frame(prior)

[/usr/local/lib/python3.10/dist-packages/torchaudio/transforms/_transforms.py](https://localhost:8080/#) in forward(self, waveform)
    108             Fourier bins, and time is the number of window hops (n_frame).
    109         """
--> 110         return F.spectrogram(
    111             waveform,
    112             self.pad,

[/usr/local/lib/python3.10/dist-packages/torchaudio/functional/functional.py](https://localhost:8080/#) in spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized, center, pad_mode, onesided, return_complex)
    117         waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")
    118 
--> 119     frame_length_norm, window_norm = _get_spec_norms(normalized)
    120 
    121     # pack batch

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py](https://localhost:8080/#) in catch_errors(frame, cache_size)
    335 
    336         with compile_lock:
--> 337             return callback(frame, cache_size, hooks)
    338 
    339     catch_errors._torchdynamo_orig_callable = callback  # type: ignore[attr-defined]

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in _convert_frame(frame, cache_size, hooks)
    402         counters["frames"]["total"] += 1
    403         try:
--> 404             result = inner_convert(frame, cache_size, hooks)
    405             counters["frames"]["ok"] += 1
    406             return result

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in _fn(*args, **kwargs)
    102         torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
    103         try:
--> 104             return fn(*args, **kwargs)
    105         finally:
    106             torch._C._set_grad_enabled(prior_grad_mode)

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in _convert_frame_assert(frame, cache_size, hooks)
    260         initial_grad_state = torch.is_grad_enabled()
    261 
--> 262         return _compile(
    263             frame.f_code,
    264             frame.f_globals,

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py](https://localhost:8080/#) in time_wrapper(*args, **kwargs)
    161                 compilation_metrics[key] = []
    162             t0 = time.time()
--> 163             r = func(*args, **kwargs)
    164             time_spent = time.time() - t0
    165             # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)
    322         for attempt in itertools.count():
    323             try:
--> 324                 out_code = transform_code_object(code, transform)
    325                 orig_code_map[out_code] = code
    326                 break

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py](https://localhost:8080/#) in transform_code_object(code, transformations, safe)
    443     propagate_line_nums(instructions)
    444 
--> 445     transformations(instructions, code_options)
    446     return clean_and_assemble_instructions(instructions, keys, code_options)[1]
    447 

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in transform(instructions, code_options)
    309             mutated_closure_cell_contents,
    310         )
--> 311         tracer.run()
    312         output = tracer.output
    313         assert output is not None

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py](https://localhost:8080/#) in run(self)
   1724     def run(self):
   1725         _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}")
-> 1726         super().run()
   1727 
   1728     def match_nested_cell(self, name, cell):

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py](https://localhost:8080/#) in run(self)
    574                 self.instruction_pointer is not None
    575                 and not self.output.should_exit
--> 576                 and self.step()
    577             ):
    578                 pass

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py](https://localhost:8080/#) in step(self)
    538             if not hasattr(self, inst.opname):
    539                 unimplemented(f"missing: {inst.opname}")
--> 540             getattr(self, inst.opname)(inst)
    541 
    542             return inst.opname != "RETURN_VALUE"

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py](https://localhost:8080/#) in RETURN_VALUE(self, inst)
   1790         )
   1791         log.debug("RETURN_VALUE triggered compile")
-> 1792         self.output.compile_subgraph(
   1793             self, reason=GraphCompileReason("return_value", [self.frame_summary()])
   1794         )

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py](https://localhost:8080/#) in compile_subgraph(self, tx, partial_convert, reason)
    515             # optimization to generate better code in a common case
    516             self.add_output_instructions(
--> 517                 self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
    518                 + [create_instruction("UNPACK_SEQUENCE", len(stack_values))]
    519             )

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py](https://localhost:8080/#) in compile_and_call_fx_graph(self, tx, rv, root)
    586         assert_no_fake_params_or_buffers(gm)
    587         with tracing(self.tracing_context):
--> 588             compiled_fn = self.call_user_compiler(gm)
    589         compiled_fn = disable(compiled_fn)
    590 

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py](https://localhost:8080/#) in time_wrapper(*args, **kwargs)
    161                 compilation_metrics[key] = []
    162             t0 = time.time()
--> 163             r = func(*args, **kwargs)
    164             time_spent = time.time() - t0
    165             # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py](https://localhost:8080/#) in call_user_compiler(self, gm)
    673         except Exception as e:
    674             compiled_fn = gm.forward
--> 675             raise BackendCompilerFailed(self.compiler_fn, e) from e
    676         return compiled_fn
    677 

BackendCompilerFailed: debug_wrapper raised KeyError: torch.complex64

Set torch._dynamo.config.verbose=True for more information

You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
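
The eager fallback suggested in the message above can also be applied to a single module instead of globally. The following is a minimal sketch, assuming a hypothetical helper named `with_eager_fallback` (it is not part of PyTorch or torchaudio). It relies on what the traceback shows: the failure surfaces when the compiled module is first called, not when `torch.compile` is invoked, so the wrapper has to guard the call itself.

```python
import warnings
from typing import Any, Callable


def with_eager_fallback(
    compiled_fn: Callable[..., Any], eager_fn: Callable[..., Any]
) -> Callable[..., Any]:
    """Hypothetical helper: try compiled_fn, switch to eager_fn after the first failure."""
    state = {"use_eager": False}

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if not state["use_eager"]:
            try:
                return compiled_fn(*args, **kwargs)
            except Exception as exc:
                # Remember the failure so later calls skip the compiled path.
                warnings.warn(f"compiled call failed ({exc!r}); falling back to eager")
                state["use_eager"] = True
        return eager_fn(*args, **kwargs)

    return wrapper


# Intended use with the reproduction script (not executed here):
#   spec = torchaudio.transforms.Spectrogram()
#   safe_spec = with_eager_fallback(torch.compile(spec), spec)
#   safe_spec(ipt)
```

Catching a broad `Exception` keeps the sketch short, but it would also hide genuine runtime errors raised by the module, so a narrower exception type is probably preferable in real code.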

Versions

PyTorch version: 2.0.1+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.27.4
Libc version: glibc-2.35

Python version: 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.109+-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 2
On-line CPU(s) list: 0,1
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) CPU @ 2.20GHz
CPU family: 6
Model: 79
Thread(s) per core: 2
Core(s) per socket: 1
Socket(s): 1
Stepping: 0
BogoMIPS: 4399.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 32 KiB (1 instance)
L1i cache: 32 KiB (1 instance)
L2 cache: 256 KiB (1 instance)
L3 cache: 55 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0,1
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable; SMT Host state unknown
Vulnerability Meltdown: Vulnerable
Vulnerability Mmio stale data: Vulnerable
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2: Vulnerable, IBPB: disabled, STIBP: disabled, PBRSB-eIBRS: Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Vulnerable

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==2.0.1+cu118
[pip3] torchaudio==2.0.2+cu118
[pip3] torchdata==0.6.1
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.15.2
[pip3] torchvision==0.15.2+cu118
[pip3] triton==2.0.0
[conda] Could not collect

mthrok commented 1 year ago

Hi @tarepan

Can you try the nightly build? I tried it on nightly and it worked fine, though it emits a warning that complex operators may be slower than in eager mode.

torch.Size([201, 9])
/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_inductor/lowering.py:1159: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
  warnings.warn(
Using FallbackKernel: aten.reshape
Using FallbackKernel: aten.abs
torch.Size([201, 9])

tarepan commented 1 year ago

As you suggested, it works with the nightly build.
Thank you very much!

In my opinion, we can close this issue once the next stable (non-nightly) version of torchaudio is released.
I will watch for the release.

Reproduction

# !pip uninstall -y torch torchvision torchaudio
# !pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu

import torch, torchaudio

ipt = torch.tensor([1. for _ in range(1600)])
spec = torchaudio.transforms.Spectrogram()
spec_compiled = torch.compile(spec)

print(         spec(ipt).size()) # torch.Size([201, 9])
print(spec_compiled(ipt).size())
# /usr/local/lib/python3.10/dist-packages/torch/_inductor/lowering.py:1484: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
#   warnings.warn(
# No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
# torch.Size([201, 9])