revsic / jax-variational-diffwave

Jax/Flax implementation of Variational-DiffWave.
MIT License

jax-variational-diffwave

WARNING: Maintenance of this repository is on hold because the continuous-time diffusion approach did not work.

Jax/Flax implementation of Variational-DiffWave (Zhifeng Kong et al., 2020; Diederik P. Kingma et al., 2021).
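For reference, Kingma et al. (2021) define a variance-preserving forward process q(z_t | x) = N(alpha_t x, sigma_t^2 I) with alpha_t^2 = sigmoid(-gamma(t)) and sigma_t^2 = sigmoid(gamma(t)), where gamma(t) is the negative log-SNR, and train the diffusion term of the VLB with the continuous-time loss 0.5 * E[gamma'(t) * ||eps - eps_hat(z_t, t)||^2], t ~ U(0, 1). The following is a minimal JAX sketch of that objective, not this repository's code. It assumes a fixed linear gamma (VDM learns gamma with a monotone network), hypothetical schedule endpoints `GAMMA_MIN`/`GAMMA_MAX`, and a hypothetical noise predictor `eps_fn`; the prior and reconstruction terms of the VLB are omitted.

```python
import jax
import jax.numpy as jnp

# Hypothetical endpoints of a linear negative log-SNR schedule (not from this repository).
# VDM learns gamma(t) with a monotone network; a fixed line keeps the sketch small.
GAMMA_MIN, GAMMA_MAX = -10.0, 10.0


def gamma(t):
    """Negative log-SNR, increasing in t, so samples get noisier as t goes from 0 to 1."""
    return GAMMA_MIN + (GAMMA_MAX - GAMMA_MIN) * t


def alpha_sigma(t):
    """Variance-preserving coefficients: alpha_t^2 = sigmoid(-gamma), sigma_t^2 = sigmoid(gamma)."""
    g = gamma(t)
    return jnp.sqrt(jax.nn.sigmoid(-g)), jnp.sqrt(jax.nn.sigmoid(g))


def diffuse(x, t, eps):
    """Reparameterized sample z_t = alpha_t * x + sigma_t * eps; x is [B, T], t is [B]."""
    alpha, sigma = alpha_sigma(t)
    return alpha[:, None] * x + sigma[:, None] * eps


def continuous_loss(eps_fn, x, key):
    """One-sample Monte-Carlo estimate of 0.5 * gamma'(t) * ||eps - eps_hat(z_t, t)||^2
    with t ~ U(0, 1), averaged over the batch. eps_fn is a hypothetical noise predictor."""
    t_key, eps_key = jax.random.split(key)
    t = jax.random.uniform(t_key, (x.shape[0],))
    eps = jax.random.normal(eps_key, x.shape)
    z_t = diffuse(x, t, eps)
    # gamma'(t) by autodiff; for this linear schedule it is the constant GAMMA_MAX - GAMMA_MIN.
    dgamma = jax.vmap(jax.grad(gamma))(t)
    sq_err = jnp.sum((eps - eps_fn(z_t, t)) ** 2, axis=-1)
    return jnp.mean(0.5 * dgamma * sq_err)
```

With a trained network, `eps_fn` would wrap the model's forward pass and also receive the mel conditioning; computing gamma'(t) by autodiff keeps the loss valid if the linear schedule is swapped for a learned one.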

Requirements

Tested in a Python 3.7.9 conda environment; dependencies are listed in requirements.txt.

Usage

To train the model, run train.py. Checkpoints are written to TrainConfig.ckpt and TensorBoard summaries to TrainConfig.log.

python train.py --data-dir /datasets/ljspeech --from-raw
tensorboard --logdir ./log/

To resume training from a previous checkpoint, pass --load-epoch and point --config to the saved configuration:

python train.py --load-epoch 10 --config ./ckpt/l1.json

[TODO] To synthesize the test set, run synth.py.

python synth.py

[TODO] Pretrained checkpoints are released on the Releases page.

To use a pretrained model, download the files and unzip them, then check out the git commit tag that matches the checkpoint. A sample script follows.

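import json

import jax

# Config and VLBDiffWaveApp are defined in this repository; import them from the checked-out source tree.

with open('l1.json') as f:
    config = Config.load(json.load(f))

# Build the model and restore the pretrained weights.
diffwave = VLBDiffWaveApp(config.model)
diffwave.restore('./l1/l1_99.ckpt')

# mel: [B, T, mel], a batch of mel-spectrograms
audio, _ = diffwave(mel, timesteps=50, key=jax.random.PRNGKey(0))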
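The `timesteps=50` argument sets how many reverse steps the sampler takes. The repository's internal sampler is not reproduced here; as a hedged illustration only, the sketch below implements generic discrete-time ancestral sampling for a variational diffusion model (Kingma et al., 2021). At each step from t to s = t - 1/timesteps it estimates x from the predicted noise and draws z_s from the Gaussian posterior q(z_s | z_t, x) with mean (alpha_{t|s} sigma_s^2 / sigma_t^2) z_t + (alpha_s sigma_{t|s}^2 / sigma_t^2) x and variance sigma_{t|s}^2 sigma_s^2 / sigma_t^2. The linear log-SNR schedule, the helper `alpha_sigma_sq`, the noise predictor `eps_fn`, and the conditioning argument `cond` (standing in for `mel`) are all hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical linear negative log-SNR schedule (not from this repository).
GAMMA_MIN, GAMMA_MAX = -10.0, 10.0


def alpha_sigma_sq(t):
    """Return (alpha_t^2, sigma_t^2) = (sigmoid(-gamma(t)), sigmoid(gamma(t)))."""
    g = GAMMA_MIN + (GAMMA_MAX - GAMMA_MIN) * t
    return jax.nn.sigmoid(-g), jax.nn.sigmoid(g)


def sample(eps_fn, cond, shape, timesteps, key):
    """Ancestral sampling from z_1 ~ N(0, I) to t = 0 in `timesteps` uniform steps.

    eps_fn(z_t, t, cond) is a hypothetical noise predictor standing in for the trained
    network, and cond stands in for the mel-spectrogram conditioning.
    """
    key, init_key = jax.random.split(key)
    z = jax.random.normal(init_key, shape)
    for i in range(timesteps, 0, -1):
        t, s = i / timesteps, (i - 1) / timesteps
        a2_t, s2_t = alpha_sigma_sq(t)
        a2_s, s2_s = alpha_sigma_sq(s)
        # q(z_t | z_s): alpha_{t|s}^2 = alpha_t^2 / alpha_s^2, sigma_{t|s}^2 = sigma_t^2 - alpha_{t|s}^2 sigma_s^2.
        a2_ts = a2_t / a2_s
        s2_ts = s2_t - a2_ts * s2_s
        # Recover an estimate of the clean signal from the predicted noise.
        eps_hat = eps_fn(z, jnp.full(shape[:1], t), cond)
        x_hat = (z - jnp.sqrt(s2_t) * eps_hat) / jnp.sqrt(a2_t)
        # Mean of the Gaussian posterior q(z_s | z_t, x = x_hat).
        mean = (jnp.sqrt(a2_ts) * s2_s / s2_t) * z + (jnp.sqrt(a2_s) * s2_ts / s2_t) * x_hat
        if i > 1:
            key, noise_key = jax.random.split(key)
            std = jnp.sqrt(s2_ts * s2_s / s2_t)
            z = mean + std * jax.random.normal(noise_key, shape)
        else:
            z = mean  # final step: keep the posterior mean instead of adding noise
    a2_0, _ = alpha_sigma_sq(0.0)
    return z / jnp.sqrt(a2_0)  # undo the alpha_0 scaling of z_0
```

More steps give a finer discretization of the reverse process at proportionally higher cost, since the network runs once per step.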