WaveRNN + VQ-VAE

This is a PyTorch implementation of WaveRNN. Currently 3 top-level networks are provided:

Audio samples.

It has been tested with the following datasets.

Multispeaker datasets:

  - VCTK

Single-speaker datasets:

  - LJ-Speech

Preparation

Requirements

Create config.py

cp config.py.example config.py

Preparing VCTK

You can skip this section if you don't need a multi-speaker dataset.

  1. Download and uncompress the VCTK dataset.
  2. python preprocess_multispeaker.py /path/to/dataset/VCTK-Corpus/wav48 /path/to/output/directory
  3. In config.py, set multi_speaker_data_path to point to the output directory (see the example below).
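
As an illustration only, the resulting setting in config.py might look like this (the path is a placeholder for your output directory):

multi_speaker_data_path = '/path/to/output/directory'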

Preparing LJ-Speech

You can skip this section if you don't need a single-speaker dataset.

  1. Download and uncompress the LJ speech dataset.
  2. python preprocess16.py /path/to/dataset/LJSpeech-1.1/wavs /path/to/output/directory
  3. In config.py, set single_speaker_data_path to point to the output directory (see the example below).
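
Again for illustration, with a placeholder path:

single_speaker_data_path = '/path/to/output/directory'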

Usage

wavernn.py is the entry point:

$ python wavernn.py

By default, it trains a VQ-VAE model. The -m option can be used to tell the script to train a different model.

Trained models are saved under the model_checkpoints directory.

By default, the script picks up the latest snapshot and continues training from there. To train a new model from scratch, use the --scratch option.

Every 50k steps, the model is run to generate test audio outputs. The output goes under the model_outputs directory.

When the -g option is given, the script generates output using the saved model instead of training it.
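
As an illustration, a typical sequence of invocations using the options described above might be:

$ python wavernn.py --scratch   # train the default VQ-VAE model from scratch
$ python wavernn.py             # resume training from the latest snapshot
$ python wavernn.py -g          # generate audio with the saved model instead of training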

Deviations from the papers

I deviated from the papers in some details, sometimes because I was lazy and sometimes because I was unable to get good results without the change. Below is a (probably incomplete) list of deviations.

All models:

VQ-VAE:

Context stacks

The VQ-VAE implementation uses a WaveRNN-based decoder instead of the WaveNet-based decoder described in the paper. This is a WaveRNN network augmented with a context stack to extend the receptive field. This network is defined in layers/overtone.py.

The network has 6 convolutions with stride 2 that produce a 64x-downsampled 'summary' of the waveform, followed by 4 layers of upsampling RNNs, the last of which is the WaveRNN layer. It also has U-Net-like skip connections between layers operating at the same frequency.
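
The actual network is the one in layers/overtone.py; the snippet below is only a rough sketch, in PyTorch, of the downsampling half of such a context stack, showing how 6 stride-2 convolutions yield a 64x-downsampled summary. Layer sizes, names, and activation choices here are illustrative assumptions, not the ones used in this repository.

import torch
from torch import nn

class ContextStackSketch(nn.Module):
    """Illustrative sketch: 6 stride-2 convolutions -> 64x downsampling.

    The real decoder in layers/overtone.py also has 4 upsampling RNN
    layers (the last being the WaveRNN layer) and U-Net-like skip
    connections between layers running at the same rate; those parts
    are omitted here.
    """
    def __init__(self, channels=128):
        super().__init__()
        self.downs = nn.ModuleList([
            nn.Conv1d(1 if i == 0 else channels, channels,
                      kernel_size=4, stride=2, padding=1)
            for i in range(6)
        ])

    def forward(self, wav):
        # wav: (batch, 1, T) raw waveform
        x = wav
        skips = []
        for conv in self.downs:
            x = torch.relu(conv(x))
            skips.append(x)  # would feed the U-Net-like skip connections
        # x: (batch, channels, T // 64) "summary" of the waveform
        return x, skips

For example, a 1-second clip at 16 kHz, shape (1, 1, 16000), comes out with temporal length 16000 / 64 = 250.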

Acknowledgement

The code is based on fatchord/WaveRNN.