pytorch / audio

Data manipulation and transformation for audio signal processing, powered by PyTorch
https://pytorch.org/audio
BSD 2-Clause "Simplified" License

`torch.compile` is not working with `torchaudio.functional.lfilter` #3709

Open jhauret opened 11 months ago

jhauret commented 11 months ago

🐛 Describe the bug

The new torch.compile feature does not work with torchaudio.functional.lfilter. My understanding is that torch.compile needs to know the shapes of the tensors, but these shapes are not fixed, since the filter is applied recursively and must handle waveforms of any length.
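To make "applied recursively" concrete, here is a minimal pure-Python sketch of the difference equation that an IIR filter such as `lfilter` evaluates. It is an illustration only, not torchaudio's implementation: the helper `reference_lfilter` is hypothetical, it ignores batching, and it omits the output clamping that `torchaudio.functional.lfilter` applies by default (`clamp=True`). Each output sample depends on previously computed outputs, so the number of sequential steps grows with the waveform length.

```python
def reference_lfilter(x, a, b):
    """Hypothetical direct-form IIR filter, evaluated sample by sample.

    a: denominator (feedback) coefficients, b: numerator (feedforward)
    coefficients, in the same (waveform, a_coeffs, b_coeffs) order as
    torchaudio.functional.lfilter. No clamping is applied.
    """
    if a[0] == 0:
        raise ValueError("a[0] must be non-zero")
    y = []
    for n in range(len(x)):
        # Feedforward part: depends only on the current and past inputs.
        acc = sum(b[k] * x[n - k] for k in range(len(b)) if n - k >= 0)
        # Feedback part: depends on outputs computed in earlier steps,
        # which is what makes the filter recursive.
        acc -= sum(a[k] * y[n - k] for k in range(1, len(a)) if n - k >= 0)
        y.append(acc / a[0])
    return y


# Impulse response of y[n] = x[n] + 0.5 * y[n - 1]
print(reference_lfilter([1.0, 0.0, 0.0, 0.0], a=[1.0, -0.5], b=[1.0, 0.0]))
# → [1.0, 0.5, 0.25, 0.125]
```

With `a = b = (1, 0)`, as in the reproduction below, the recursion reduces to the identity filter, so the output equals the input.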


Minimal code to reproduce the error:

import os

# Enable more detailed Dynamo error logs. Set these before importing torch:
# TORCH_LOGS is read when torch initializes its logging, so setting it
# after the import may have no effect.
os.environ["TORCH_LOGS"] = "+dynamo"
os.environ["TORCHDYNAMO_VERBOSE"] = "1"

import torch
import torchaudio

# Filter coefficients: a = denominator, b = numerator (an identity filter)
a = torch.tensor([1.0, 0.0])
b = torch.tensor([1.0, 0.0])
# Waveform of shape (1, 1, 16000)
waveform = torch.ones((1, 1, 16000))

# Define the function to compile
def my_lfilter(x):
    return torchaudio.functional.lfilter(x, a, b)

# Classic (non-compiled) filtering works
filtered_waveform = my_lfilter(waveform)

# Wrapping the function with torch.compile succeeds, because the actual
# compilation is deferred until the first call
my_lfilter_compiled = torch.compile(my_lfilter)

# The error is raised here, when the compiled function is first called
filtered_waveform_compiled = my_lfilter_compiled(waveform)

Error:

/home/julien/.pyenv/versions/pulse/bin/python3 /home/julien/Desktop/Pulse/working_repo/pulse_github/pulse_ai/training_recipes/minimal_compile_error.py 
Traceback (most recent call last):
  File "/home/julien/Desktop/Pulse/working_repo/pulse_github/pulse_ai/training_recipes/minimal_compile_error.py", line 27, in <module>
    filtered_waveform_compiled = my_lfilter_compiled(waveform)
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 641, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
    return _compile(
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
    tracer.run()
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2074, in run
    super().run()
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 261, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2179, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2286, in inline_call_
    tracer.run()
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/variables/torch.py", line 729, in call_function
    tensor_variable = wrap_fx_proxy(
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1187, in wrap_fx_proxy
    return wrap_fx_proxy_cls(
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1274, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx)
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1376, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1337, in get_fake_value
    return wrap_fake_exception(
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 916, in wrap_fake_exception
    return fn()
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1338, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1410, in run_node
    raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1397, in run_node
    return node.target(*args, **kwargs)
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torch/_ops.py", line 692, in __call__
    return self._op(*args, **kwargs or {})
torch._dynamo.exc.TorchRuntimeError: Failed running call_function torchaudio._lfilter(*(FakeTensor(..., size=(1, 1, 16000)), FakeTensor(..., size=(1, 2)), FakeTensor(..., size=(1, 2))), **{}):
The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.

from user code:
   File "/home/julien/Desktop/Pulse/working_repo/pulse_github/pulse_ai/training_recipes/minimal_compile_error.py", line 17, in my_lfilter
    return torchaudio.functional.lfilter(x, a, b)
  File "/home/julien/.pyenv/versions/pulse/lib/python3.10/site-packages/torchaudio/functional/filtering.py", line 1057, in lfilter
    output = _lfilter(waveform, a_coeffs, b_coeffs)

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Process finished with exit code 1

Versions

Collecting environment information...
PyTorch version: 2.1.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.27.4
Libc version: glibc-2.35

Python version: 3.10.6 (main, Sep 7 2023, 09:03:30) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.2.0-36-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce GTX 1650
Nvidia driver version: 535.129.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 39 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: GenuineIntel
Model name: Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
CPU family: 6
Model: 158
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 13
CPU max MHz: 5000,0000
CPU min MHz: 800,0000
BogoMIPS: 4800.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust sgx bmi1 avx2 smep bmi2 erms invpcid mpx rdseed adx smap clflushopt intel_pt xsaveopt xsavec xgetbv1 xsaves dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp sgx_lc md_clear flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 2 MiB (8 instances)
L3 cache: 16 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Mitigation; Microcode
Vulnerability Tsx async abort: Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.4
[pip3] onnx==1.14.1
[pip3] onnxruntime==1.15.1
[pip3] pytorch-lightning==2.0.9.post0
[pip3] pytorch-ranger==0.1.1
[pip3] pytorch-seed==0.2.0
[pip3] torch==2.1.0
[pip3] torch-audiomentations==0.11.0
[pip3] torch-optimizer==0.1.0
[pip3] torch-pitch-shift==1.2.4
[pip3] torch-stoi==0.1.2
[pip3] torch-utilities==1.1.2
[pip3] torchaudio==2.1.0
[pip3] torchmetrics==0.7.3
[pip3] triton==2.1.0
[conda] Could not collect

mthrok commented 11 months ago

Hi

Thanks for the report. However, this project no longer has an active maintainer.

jhauret commented 11 months ago

Hi, and thanks for your reply,

I am quite surprised to hear this! torchaudio is a great library that is very useful for the community! Is this only temporary?