pytorch / audio

Data manipulation and transformation for audio signal processing, powered by PyTorch
https://pytorch.org/audio
BSD 2-Clause "Simplified" License

Add support for Modified Discrete Cosine Transform (MDCT) #2696

Open Kinyugo opened 1 year ago

Kinyugo commented 1 year ago

🚀 The feature

The Modified Discrete Cosine Transform (MDCT) is a perfectly invertible transform that can be used for feature extraction. It is an alternative to MelSpectrograms, especially where an invertible transform is desired, such as audio synthesis in a more compressed space.
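
For reference, the standard forward MDCT maps a frame of 2N samples to N coefficients:

X_k = \sum_{n=0}^{2N-1} x_n \cos\left[\frac{\pi}{N}\left(n + \frac{1}{2} + \frac{N}{2}\right)\left(k + \frac{1}{2}\right)\right], \quad k = 0, \dots, N-1

With 50% overlapping frames (hop N) and a window satisfying the Princen-Bradley condition (w_n^2 + w_{n+N}^2 = 1), overlap-adding the inverse transforms reconstructs the signal exactly via time-domain aliasing cancellation.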

Motivation, pitch

I am working on an audio synthesis project where an invertible transformation is desirable, as opposed to something like a MelSpectrogram. The MDCT is a viable alternative; however, there is no implementation of it in torchaudio.

Alternatives

I have tried working with the MelSpectrogram transform. However, pitch reconstruction is cumbersome and requires implementing complex neural vocoders, which is undesirable in my case.

Additional context

There is a NumPy implementation, Zaf, as well as a PyPI package, mdct. I would like to assist in porting these implementations to torchaudio; however, I need some guidance on how to go about it. Any help would be appreciated.

Currently I have a naive 1-to-1 port of the Zaf implementation in PyTorch. I think there is a lot of room for optimization.

import math
from typing import Callable, Optional, Union

import torch
from torch import Tensor
from torch.nn import functional as F

def mdct(x: Tensor,
         window: Union[Callable[..., Tensor], Tensor],
         window_length: int,
         hop_length: Optional[int] = None,
         center: bool = True,
         pad_mode: str = "reflect") -> Tensor:
    # Initialize the window
    if callable(window):
        window = window(window_length)

    if hop_length is None:
        hop_length = window_length // 2

    # Flatten the input tensor
    shape = x.shape
    x = x.reshape((-1, shape[-1]))

    # Derive the number of frequencies and frames
    n_freqs = window_length // 2
    n_frames = int(math.ceil(shape[-1] / hop_length)) + 1

    # Center pad the signal
    if center:
        x = F.pad(x, (hop_length, (n_frames + 1) * hop_length - shape[-1]),
                  mode=pad_mode)

    # Initialize the mdct
    x_mdct = torch.zeros((x.shape[0], n_freqs, n_frames), device=x.device)

    # Prepare the pre&post processing
    preprocess_arr = torch.exp(
        -1j * torch.pi / window_length *
        torch.arange(0, window_length, device=x.device)).unsqueeze(0)
    postprocess_arr = torch.exp(
        -1j * torch.pi / window_length * (window_length / 2 + 1) *
        torch.arange(0.5, window_length / 2 + 0.5,
                     device=x.device)).unsqueeze(0)

    # Loop over time frames
    i = 0
    for j in range(n_frames):
        # Window the signal
        x_segment = x[:, i:i + window_length] * window
        i = i + hop_length

        # Compute the fourier transform of the windowed signal
        x_segment = torch.fft.fft(x_segment * preprocess_arr)

        x_mdct[:, :, j] = torch.real(x_segment[:, :n_freqs] * postprocess_arr)

    x_mdct = x_mdct.reshape(shape[:-1] + x_mdct.shape[-2:])

    return x_mdct
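
For example, a quick shape check of the port above, using a sine window (which satisfies the Princen-Bradley condition at 50% overlap):

x = torch.randn(2, 16_000)
window_length = 1024
w = torch.sin(torch.pi / window_length *
              (torch.arange(window_length) + 0.5))  # sine window
y = mdct(x, w, window_length)
print(y.shape)  # torch.Size([2, 512, 33]): (batch, n_freqs, n_frames)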
mthrok commented 1 year ago

Hi @Kinyugo

Welcome to the torchaudio project, and thanks for the proposal. This sounds like a good addition to torchaudio. cc @pytorch/team-audio-core

Here are the general steps to add it to torchaudio.

  1. Add a stateless implementation (the base function) to torchaudio.functional. The implementation seems to fit in functional.py, near the spectrogram / mel scale functions.
  2. [can be deferred] Add an object-oriented implementation to torchaudio.transforms, which wraps the functional and caches the window and other parameters (a sketch follows below).
  3. Add tests* to torchaudio_unittest/functional. The difficulty here is how to define the "correctness" of the implementation. One way is to port the existing mdct implementation as a test utility and compare the results against it.
  4. [Optional] If the implementation should support autograd and TorchScript, those tests should be added as well. This should be easy, as the test fixtures are already there, so it would be just a few lines of code.
  5. Add documentation for functional and transforms.

*If you are not used to writing a test, let us know, we can take over.
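
As a rough illustration of step 2 (the names and defaults here are assumptions, not a final API), the transform would wrap the functional and cache the window:

from typing import Optional

import torch
from torch import Tensor

class MDCT(torch.nn.Module):
    def __init__(self, window_length: int = 1024, hop_length: Optional[int] = None):
        super().__init__()
        self.window_length = window_length
        self.hop_length = hop_length if hop_length is not None else window_length // 2
        # Cache the window as a buffer so it follows .to(device=...) calls
        window = torch.sin(torch.pi / window_length *
                           (torch.arange(window_length) + 0.5))  # sine window
        self.register_buffer("window", window)

    def forward(self, waveform: Tensor) -> Tensor:
        # `mdct` is the stateless functional from step 1 (e.g., the port above)
        return mdct(waveform, self.window, self.window_length, self.hop_length)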

Notes on the implementation:

  1. I think looping through the frames and calling fft each time will be slow. At a glance, there seems to be no dependency across the time axis. Would it be possible to vectorize the computation there (stack the segments and call fft once)? See the sketch after this list.
  2. Though the existing implementations (Spectrogram and such) accept a callable as the window function, this is unnecessary and does not add value (users can simply create the window before calling the function), so I suggest removing it.
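
A minimal sketch of the stacking idea, assuming the input is already center-padded (the function name is hypothetical, not torchaudio API):

import torch

def mdct_frames_vectorized(x, window, window_length, hop_length):
    # x: (batch, time), already padded; window: (window_length,)
    n_freqs = window_length // 2
    frames = x.unfold(-1, window_length, hop_length)  # (batch, n_frames, window_length)
    frames = frames * window                          # window every frame at once
    pre = torch.exp(-1j * torch.pi / window_length *
                    torch.arange(window_length, device=x.device))
    post = torch.exp(-1j * torch.pi / window_length * (window_length / 2 + 1) *
                     torch.arange(0.5, n_freqs + 0.5, device=x.device))
    spec = torch.fft.fft(frames * pre, dim=-1)        # one batched FFT for all frames
    coeffs = torch.real(spec[..., :n_freqs] * post)   # (batch, n_frames, n_freqs)
    return coeffs.transpose(-1, -2)                   # (batch, n_freqs, n_frames)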
mthrok commented 1 year ago

Also please refer to CONTRIBUTING.md for setting up the development environment.

Kinyugo commented 1 year ago

@mthrok Thank you for the guidelines on how to contribute.

About the implementation:

  1. Borrowing an idea from other implementations, I think it is possible to define the operations using 1D convolutions so that we ditch the for loops (a sketch follows this list).
  2. The window as a callable is a convenience that lets the user pass either a Kaiser-Bessel-derived or a Vorbis window function, both of which will be implemented.
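
A hedged sketch of the conv1d idea, using one common form of the MDCT basis (the function name and basis convention are illustrative assumptions, not a settled design):

import torch
import torch.nn.functional as F

def mdct_conv1d(x, window, window_length, hop_length):
    # x: (batch, time), already center-padded; window: (window_length,)
    n_freqs = window_length // 2
    n = torch.arange(window_length, dtype=x.dtype, device=x.device)
    k = torch.arange(n_freqs, dtype=x.dtype, device=x.device)
    # One common MDCT basis: cos(pi/N * (n + 1/2 + N/2) * (k + 1/2)), N = n_freqs
    basis = torch.cos(torch.pi / n_freqs *
                      (n + 0.5 + n_freqs / 2).unsqueeze(0) *
                      (k + 0.5).unsqueeze(1))
    kernels = (basis * window).unsqueeze(1)  # (n_freqs, 1, window_length)
    # A single conv1d performs framing, windowing, and the transform at once
    return F.conv1d(x.unsqueeze(1), kernels, stride=hop_length)  # (batch, n_freqs, n_frames)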

Let me know your thoughts and suggestions on how we can go about it.

Kinyugo commented 1 year ago

I have also managed to port a 1-to-1 implementation of the inverse. Since the transform is supposed to be perfectly invertible, we can check the reconstruction error to verify some correctness, though I am not sure about the mathematical correctness of the implementation.

import math
from typing import Callable, Optional, Union

import torch
from torch import Tensor
from torch.nn import functional as F

def mdct(x: Tensor,
         window: Union[Callable[..., Tensor], Tensor],
         window_length: int,
         hop_length: Optional[int] = None,
         center: bool = True,
         pad_mode: str = "constant") -> Tensor:
    # Initialize the window
    if callable(window):
        window = window(window_length)

    if hop_length is None:
        hop_length = window_length // 2

    # Flatten the input tensor
    shape = x.shape
    x = x.reshape((-1, shape[-1]))

    # Derive the number of frequencies and frames
    n_freqs = window_length // 2
    n_frames = int(math.ceil(shape[-1] / hop_length)) + 1

    # Center pad the signal so the last frame is fully covered
    # (right pad depends on the signal length, as in the forward padding above)
    if center:
        x = F.pad(x, (hop_length, (n_frames + 1) * hop_length - shape[-1]),
                  mode=pad_mode)

    # Initialize the mdct
    x_mdct = torch.zeros((x.shape[0], n_freqs, n_frames), device=x.device)

    # Prepare the pre&post processing
    preprocess_arr = torch.exp(
        -1j * torch.pi / window_length *
        torch.arange(0, window_length, device=x.device)).unsqueeze(0)
    postprocess_arr = torch.exp(
        -1j * torch.pi / window_length * (window_length / 2 + 1) *
        torch.arange(0.5, window_length / 2 + 0.5,
                     device=x.device)).unsqueeze(0)

    # Loop over time frames
    i = 0
    for j in range(n_frames):
        # Window the signal
        x_segment = x[:, i:i + window_length] * window
        i = i + hop_length

        # Compute the fourier transform of the windowed signal
        x_segment = torch.fft.fft(x_segment * preprocess_arr)

        x_mdct[:, :, j] = torch.real(x_segment[:, :n_freqs] * postprocess_arr)

    x_mdct = x_mdct.reshape(shape[:-1] + x_mdct.shape[-2:])

    return x_mdct

def imdct(x: Tensor,
          window: Union[Callable[..., Tensor], Tensor],
          window_length: int,
          hop_length: Optional[int] = None,
          center: bool = True,
          pad_mode: str = "constant") -> Tensor:
    # Initialize the window
    if callable(window):
        window = window(window_length)

    if hop_length is None:
        hop_length = window_length // 2

    # Flatten the input tensor
    shape = x.shape
    n_freqs, n_frames = x.shape[-2:]
    x = x.reshape((-1, n_freqs, n_frames))

    # Derive the number of samples
    n_samples = hop_length * (n_frames + 1)

    # Initialize the signal
    x_imdct = torch.zeros((x.shape[0], n_samples), device=x.device)

    # Prepare the pre&post processing
    preprocess_arr = (torch.exp(
        -1j * torch.pi / (2 * n_freqs) * (n_freqs + 1) *
        torch.arange(0, n_freqs, device=x.device))).unsqueeze(dim=-1)
    postprocess_arr = (torch.exp(
        -1j * torch.pi / (2 * n_freqs) *
        torch.arange(0.5 + n_freqs / 2, 2 * n_freqs + n_freqs / 2 + 0.5,
                     device=x.device)) / n_freqs).unsqueeze(dim=-1)

    x = torch.fft.fft(x * preprocess_arr, n=2 * n_freqs, dim=1)

    # Apply the window function to the frames after post-processing
    x = 2 * (torch.real(x * postprocess_arr) * window.unsqueeze(dim=-1))

    # Loop over the time frames
    i = 0
    for j in range(n_frames):
        # Recover the signal via the time-domain aliasing cancellation principle
        x_imdct[:, i:i + window_length] += x[:, :, j]
        i = i + hop_length

    # Remove padding
    if center:
        x_imdct = x_imdct[:, hop_length:-hop_length]
    x_imdct = x_imdct.reshape((*shape[:-2], -1))

    return x_imdct

x = torch.randn(1, 2, 131_072)
window_length = 1024
# Vorbis window: sin(pi/2 * sin^2(pi/L * (n + 0.5)))
w = torch.sin(torch.pi / 2 * torch.sin(
    torch.pi / window_length * torch.arange(0.5, window_length + 0.5)) ** 2)
y = mdct(x, w, window_length)
z = imdct(y, w, window_length)

loss = torch.mean(torch.abs(x - z))
print(loss)  # around 1e-5

I have yet to figure out how to optimize the loops.
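
For reference, one possible way to vectorize the overlap-add loop in imdct is torch.nn.functional.fold, which sums overlapping frames in a single call. A minimal sketch with a hypothetical helper name:

import torch
import torch.nn.functional as F

def overlap_add(frames: torch.Tensor, hop_length: int) -> torch.Tensor:
    # frames: (batch, window_length, n_frames) -> (batch, n_samples)
    batch, window_length, n_frames = frames.shape
    n_samples = hop_length * (n_frames - 1) + window_length
    # fold sums the overlapping windows, replacing the explicit Python loop
    out = F.fold(frames, output_size=(1, n_samples),
                 kernel_size=(1, window_length), stride=(1, hop_length))
    return out.reshape(batch, n_samples)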

Kinyugo commented 1 year ago

@mthrok I managed to get a working vectorized implementation of the MDCT algorithm. However, I am having trouble setting up the development environment on my local machine to open a PR. How do I go about it?

carolineechen commented 1 year ago

Hi @Kinyugo, you can refer to CONTRIBUTING.md to get started with the dev environment setup and process. Feel free to let us know if you encounter any errors while doing so!

Kinyugo commented 1 year ago

Hello @carolineechen. I followed the steps in the contributing guide but ran into some problems installing the package, i.e. running python setup.py develop. While building sox, I am running into an error:

subprocess.CalledProcessError: Command '['cmake', '--build', '.', '--target', 'install']' returned non-zero exit status 1.
carolineechen commented 1 year ago

Hi @Kinyugo, can you post the full error message, as well as your environment, so we can get a better idea of what the error is?

To get the environment versions, you can use

wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py
Kinyugo commented 1 year ago

Full Error Message

/home/kinyugo-maina/Learning/ML/AudioGeneration/torchaudio/audio/third_party/sox/../install/lib/libvorbisfile.a  /home/kinyugo-maina/Learning/ML/AudioGeneration/torchaudio/audio/third_party/sox/../install/lib/libvorbis.a  /home/kinyugo-maina/Learning/ML/AudioGeneration/torchaudio/audio/third_party/sox/../install/lib/libogg.a  /usr/lib/gcc/x86_64-linux-gnu/9/libgomp.so  /usr/lib/x86_64-linux-gnu/libpthread.so && :
/usr/bin/ld: cannot find -lmkl_intel_ilp64
/usr/bin/ld: cannot find -lmkl_core
/usr/bin/ld: cannot find -lmkl_intel_thread
collect2: error: ld returned 1 exit status
ninja: build stopped: subcommand failed.
Traceback (most recent call last):
  File "/home/kinyugo-maina/Learning/ML/AudioGeneration/torchaudio/audio/setup.py", line 182, in <module>
    _main()
  File "/home/kinyugo-maina/Learning/ML/AudioGeneration/torchaudio/audio/setup.py", line 147, in _main
    setup(
  File "/home/kinyugo-maina/miniconda3/envs/torchaudio/lib/python3.10/site-packages/setuptools/__init__.py", line 87, in setup
    return distutils.core.setup(**attrs)
  File "/home/kinyugo-maina/miniconda3/envs/torchaudio/lib/python3.10/site-packages/setuptools/_distutils/core.py", line 185, in setup
    return run_commands(dist)
  File "/home/kinyugo-maina/miniconda3/envs/torchaudio/lib/python3.10/site-packages/setuptools/_distutils/core.py", line 201, in run_commands
    dist.run_commands()
  File "/home/kinyugo-maina/miniconda3/envs/torchaudio/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 973, in run_commands
    self.run_command(cmd)
  File "/home/kinyugo-maina/miniconda3/envs/torchaudio/lib/python3.10/site-packages/setuptools/dist.py", line 1217, in run_command
    super().run_command(command)
  File "/home/kinyugo-maina/miniconda3/envs/torchaudio/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 992, in run_command
    cmd_obj.run()
  File "/home/kinyugo-maina/miniconda3/envs/torchaudio/lib/python3.10/site-packages/setuptools/command/develop.py", line 34, in run
    self.install_for_development()
  File "/home/kinyugo-maina/miniconda3/envs/torchaudio/lib/python3.10/site-packages/setuptools/command/develop.py", line 114, in install_for_development
    self.run_command('build_ext')
  File "/home/kinyugo-maina/miniconda3/envs/torchaudio/lib/python3.10/site-packages/setuptools/_distutils/cmd.py", line 319, in run_command
    self.distribution.run_command(command)
  File "/home/kinyugo-maina/miniconda3/envs/torchaudio/lib/python3.10/site-packages/setuptools/dist.py", line 1217, in run_command
    super().run_command(command)
  File "/home/kinyugo-maina/miniconda3/envs/torchaudio/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 992, in run_command
    cmd_obj.run()
  File "/home/kinyugo-maina/Learning/ML/AudioGeneration/torchaudio/audio/tools/setup_helpers/extension.py", line 78, in run
    super().run()
  File "/home/kinyugo-maina/miniconda3/envs/torchaudio/lib/python3.10/site-packages/setuptools/command/build_ext.py", line 84, in run
    _build_ext.run(self)
  File "/home/kinyugo-maina/miniconda3/envs/torchaudio/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 346, in run
    self.build_extensions()
  File "/home/kinyugo-maina/miniconda3/envs/torchaudio/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 466, in build_extensions
    self._build_extensions_serial()
  File "/home/kinyugo-maina/miniconda3/envs/torchaudio/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 492, in _build_extensions_serial
    self.build_extension(ext)
  File "/home/kinyugo-maina/Learning/ML/AudioGeneration/torchaudio/audio/tools/setup_helpers/extension.py", line 153, in build_extension
    subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
  File "/home/kinyugo-maina/miniconda3/envs/torchaudio/lib/python3.10/subprocess.py", line 369, in check_call
    raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['cmake', '--build', '.', '--target', 'install']' returned non-zero exit status 1.

Environment

Collecting environment information...
PyTorch version: 1.14.0.dev20221023
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: KDE neon User - 5.26 (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.24.1
Libc version: glibc-2.31

Python version: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:35:26) [GCC 10.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-52-generic-x86_64-with-glibc2.31
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] torch==1.14.0.dev20221023
[conda] blas                      2.116                       mkl    conda-forge
[conda] blas-devel                3.9.0            16_linux64_mkl    conda-forge
[conda] libblas                   3.9.0            16_linux64_mkl    conda-forge
[conda] libcblas                  3.9.0            16_linux64_mkl    conda-forge
[conda] liblapack                 3.9.0            16_linux64_mkl    conda-forge
[conda] liblapacke                3.9.0            16_linux64_mkl    conda-forge
[conda] mkl                       2022.1.0           h84fe81f_915    conda-forge
[conda] mkl-devel                 2022.1.0           ha770c72_916    conda-forge
[conda] mkl-include               2022.1.0           h84fe81f_915    conda-forge
[conda] pytorch                   1.14.0.dev20221023    py3.10_cpu_0    pytorch-nightly
[conda] pytorch-mutex             1.0                         cpu    pytorch-nightly
carolineechen commented 1 year ago

Could you try downgrading mkl to version 2021.2.0, and then running

python setup.py clean
python setup.py develop
Kinyugo commented 1 year ago

The issue still persists.

carolineechen commented 1 year ago

Hi @Kinyugo, sorry for the late reply, looks like this might be related to #2784, and there's a potential solution offered there. We also have a team member looking into this, hopefully it can be resolved soon and we'll update this issue when it is.

As you already have a working version offline, it might also be possible, if you would like, to create a draft PR anyway in the meantime (though it will be harder to verify that tests compile/pass locally), and we could provide some preliminary feedback.

Kinyugo commented 1 year ago

Hello @carolineechen, I will look into creating a draft PR, as I would also appreciate feedback on some of the design choices I made.

Godspeed on resolving the issue. Thanks for your support.

Kinyugo commented 1 year ago

Hi @mthrok & @carolineechen, I apologize for taking so long to get back to this. I have made an implementation of the MDCT algorithm in PyTorch here. Kindly let me know your thoughts; I will open a draft PR soon. I would also appreciate help with testing, as I cannot set up the development environment on my machine. Thanks for your support.

mthrok commented 1 year ago

cc @nateanl

turian commented 1 year ago

Not to hijack the thread, but @Kinyugo, what are your thoughts on PQMF?