
NU-Wave2 — Official PyTorch Implementation

NU-Wave 2: A General Neural Audio Upsampling Model for Various Sampling Rates
Seungu Han, Junhyeok Lee @ MINDsLab Inc., SNU


Official PyTorch + PyTorch Lightning implementation of NU-Wave 2.

The official checkpoint can be downloaded from here.

We have added additional samples of non-English (Korean) speech and of an ablation study without BSFT to the demo page (https://mindslab-ai.github.io/nuwave2). Please check them out!

We also trained a model targeting 16 kHz (3.2 kHz ~ 16 kHz source). Its checkpoint can be downloaded from here.

Requirements

Clone our repository:

git clone --recursive https://github.com/mindslab-ai/nuwave2.git
cd nuwave2
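
Then install the Python dependencies listed in requirements.txt, e.g. with pip:

    pip install -r requirements.txt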

Preprocessing

Before running our project, you need to download the dataset and preprocess it into .wav files.

  1. Download the VCTK dataset (version 0.92)
  2. Remove speakers p280 and p315
  3. Set the path of the downloaded dataset as data: base_dir in hparameter.yaml
  4. Run utils/flac2wav.py (a reference sketch of this conversion follows the list)
    python utils/flac2wav.py
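
For reference, the core of this step is just decoding each speaker's FLAC files to WAV. Below is a minimal sketch using the soundfile library, not the repository's actual utils/flac2wav.py; the paths and the mic1 pattern come from the data section of hparameter.yaml, and the real script also applies the silence trimming referenced by timestamp_path, which this sketch omits.

    # Minimal FLAC -> WAV sketch (not the repository's exact utils/flac2wav.py).
    from pathlib import Path
    import soundfile as sf

    base_dir = Path('/DATA1/VCTK-0.92/wav48_silence_trimmed/')     # FLAC source
    out_dir = Path('/DATA1/VCTK-0.92/wav48_silence_trimmed_wav/')  # WAV target

    for flac_path in base_dir.glob('*/*mic1.flac'):   # dir/spk/format layout
        wav, sr = sf.read(flac_path)                  # decode FLAC to a float array
        wav_path = out_dir / flac_path.parent.name / (flac_path.stem + '.wav')
        wav_path.parent.mkdir(parents=True, exist_ok=True)
        sf.write(wav_path, wav, sr)                   # write WAV at the original rate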

Training

  1. Adjust hparameter.yaml, especially the train section.
    train:
      batch_size: 12 # Dependent on GPU memory size
      lr: 2e-4
      weight_decay: 0.00
      num_workers: 8 # Dependent on CPU cores
      gpus: 2 # number of GPUs
      opt_eps: 1e-9
      beta1: 0.9
      beta2: 0.99
    • Adjust the data section in hparameter.yaml as well. The 108 in the denominators of cv_ratio is the speaker count left after removing p280 and p315, so 100 speakers go to training, 8 to validation, and none to a test split.
      data:
        timestamp_path: 'vctk-silence-labels/vctk-silences.0.92.txt'
        base_dir: '/DATA1/VCTK-0.92/wav48_silence_trimmed/'
        dir: '/DATA1/VCTK-0.92/wav48_silence_trimmed_wav/' #dir/spk/format
        format: '*mic1.wav'
        cv_ratio: (100./108., 8./108., 0.00) #train/val/test
  2. Run trainer.py.
    $ python trainer.py
    • If you want to resume training from a checkpoint, check the parser arguments below (a usage example follows this list).
      parser = argparse.ArgumentParser()
      parser.add_argument('-r', '--resume_from', type=int,
          required=False, help="Resume Checkpoint epoch number")
      parser.add_argument('-s', '--restart', action="store_true",
          required=False, help="Significant change occurred, use this")
      parser.add_argument('-e', '--ema', action="store_true",
          required=False, help="Start from ema checkpoint")
      args = parser.parse_args()
    • During training, the TensorBoard logger logs the loss, spectrograms, and audio.
      $ tensorboard --logdir=./tensorboard --bind_all
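
For example, resuming training from the EMA checkpoint of a hypothetical epoch 199:

    $ python trainer.py -r 199 -e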

Evaluation

Run for_test.py. Note that --sr (the input sampling rate) is required in addition to the checkpoint number:

python for_test.py -r {checkpoint_number} --sr {input_sampling_rate} {-e: optional, if ema} {--save: optional}

Please check the parser arguments:

    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--resume_from', type=int,
                required=True, help="Resume Checkpoint epoch number")
    parser.add_argument('-e', '--ema', action="store_true",
                required=False, help="Start from ema checkpoint")
    parser.add_argument('--save', action="store_true",
                required=False, help="Save file")
    parser.add_argument('--sr', type=int,
                required=True, help="input sampling rate")
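
For example, testing the EMA checkpoint of a hypothetical epoch 500 on 12 kHz inputs and saving the output files:

    python for_test.py -r 500 --sr 12000 -e --save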

Inference

Note: If your input is a downsampled audio sample (12 kHz, 16 kHz, etc.) whose valid frequency content fills the band implied by its sampling rate, pass '--sr {sampling rate of input audio}' without the '--gt' flag.
On the other hand, if you have a 48 kHz audio sample with full valid frequency content and just want to check whether the model works well, pass '--sr {sampling rate you want to simulate}' and add the '--gt' flag.
Please check this issue for more information.

    parser = argparse.ArgumentParser()
    parser.add_argument('-c',
                        '--checkpoint',
                        type=str,
                        required=True,
                        help="Checkpoint path")
    parser.add_argument('-i',
                        '--wav',
                        type=str,
                        default=None,
                        help="Path to input audio file")
    parser.add_argument('--sr',
                        type=int,
                        required=True,
                        help="Sampling rate of input audio")
    parser.add_argument('--steps',
                        type=int,
                        required=False,
                        help="Steps for sampling")
    parser.add_argument('--gt', action="store_true",
                        required=False, help="Whether the input audio is 48 kHz ground truth audio.")
    parser.add_argument('--device',
                        type=str,
                        default='cuda',
                        required=False,
                        help="Device, 'cuda' or 'cpu'")


Repository Structure

.
|-- Dockerfile
|-- LICENSE
|-- README.md
|-- dataloader.py           # Dataloader for train/val(=test)
|-- diffusion.py            # DPM
|-- for_test.py             # Test with a for loop.
|-- hparameter.yaml         # Config
|-- inference.py            # Inference
|-- lightning_model.py      # NU-Wave 2 implementation.
|-- model.py                # NU-Wave 2 model based on lmnt-com's DiffWave implementation
|-- requirements.txt        # requirement libraries
|-- trainer.py              # Lightning trainer
|-- utils
|   |-- flac2wav.py             # Preprocessing
|   |-- stft.py                 # STFT layer
|   `-- tblogger.py             # Tensorboard Logger for lightning
|-- docs                    # For github.io
|   |-- ...
`-- vctk-silence-labels     # For trimming
    |-- ...

Citation & Contact

If this repository is useful for your research, please consider citing it!

@article{han2022nu,
  title={NU-Wave 2: A General Neural Audio Upsampling Model for Various Sampling Rates},
  author={Han, Seungu and Lee, Junhyeok},
  journal={arXiv preprint arXiv:2206.08545},
  year={2022}
}

If you have a question or any kind of inquiry, please contact Seungu Han at hansw032@snu.ac.kr.