WARNING: Maintenance of this repository is suspended; the continuous-time diffusion experiments failed.
JAX/Flax implementation of Variational-DiffWave (Zhifeng Kong et al., 2020; Diederik P. Kingma et al., 2021).
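The second citation is the variational diffusion model (VDM) of Kingma et al. (2021). In that model the forward process is q(z_t | x) = N(α_t x, σ_t² I), with α_t² = sigmoid(-γ(t)) and σ_t² = sigmoid(γ(t)). The continuous-time diffusion loss is ½ γ'(t) ||ε - ε̂(z_t; t)||², with t ~ U(0, 1).

The following is a minimal JAX sketch of that loss. It is not this repository's code and makes these assumptions:

* `gamma` is a hypothetical fixed linear schedule. VDM learns γ(t) with a monotonic network instead.
* `eps_model` is a hypothetical stand-in for the WaveNet denoiser.
* Mel conditioning is omitted.

```python
import jax
import jax.numpy as jnp

# Hypothetical linear schedule; VDM learns gamma(t) instead. Values are illustrative.
GAMMA_MIN, GAMMA_MAX = -5.0, 5.0


def gamma(t):
    """Log noise-to-signal ratio gamma(t), increasing in t."""
    return GAMMA_MIN + (GAMMA_MAX - GAMMA_MIN) * t


def alpha_sigma(t):
    """Variance-preserving: alpha^2 = sigmoid(-gamma), sigma^2 = sigmoid(gamma)."""
    g = gamma(t)
    return jnp.sqrt(jax.nn.sigmoid(-g)), jnp.sqrt(jax.nn.sigmoid(g))


def diffuse(x, t, eps):
    """Sample z_t ~ q(z_t | x) for audio x of shape [B, T] and times t of shape [B]."""
    alpha, sigma = alpha_sigma(t)
    return alpha[:, None] * x + sigma[:, None] * eps


def diffusion_loss(eps_model, x, key):
    """Per-example continuous-time diffusion loss 0.5 * gamma'(t) * ||eps - eps_hat||^2."""
    key_t, key_eps = jax.random.split(key)
    t = jax.random.uniform(key_t, (x.shape[0],))   # t ~ U(0, 1), one per example
    eps = jax.random.normal(key_eps, x.shape)      # eps ~ N(0, I)
    eps_hat = eps_model(diffuse(x, t, eps), t)     # hypothetical denoiser call
    dgamma = jax.vmap(jax.grad(gamma))(t)          # gamma'(t) via autodiff
    return 0.5 * dgamma * jnp.sum((eps - eps_hat) ** 2, axis=-1)
```

With the linear schedule γ'(t) is simply constant. Autodiff is kept so the same loss still works if `gamma` is replaced by a learned schedule.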
Tested in a Python 3.7.9 conda environment; dependencies are listed in requirements.txt.
To train the model, run train.py. Checkpoints are written to TrainConfig.ckpt and TensorBoard summaries to TrainConfig.log.
python train.py --data-dir /datasets/ljspeech --from-raw
tensorboard --logdir ./log/
To resume training from a previous checkpoint, the --load-step option is available.
python train.py --load-epoch 10 --config ./ckpt/l1.json
[TODO] To synthesize the test set, run synth.py.
python synth.py
[TODO] Pretrained checkpoints are released on the Releases page.
To use a pretrained model, download the files and unzip them, then check out the git repository at the matching commit tag. The following is a sample script.
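import json
import jax
# Config and VLBDiffWaveApp are defined in this repository; import them
# from their modules at the checked-out commit tag.
with open('l1.json') as f:
    config = Config.load(json.load(f))
diffwave = VLBDiffWaveApp(config.model)
diffwave.restore('./l1/l1_99.ckpt')
# mel: [B, T, mel], batch of conditioning mel-spectrograms
audio, _ = diffwave(mel, timesteps=50, key=jax.random.PRNGKey(0))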
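The `timesteps=50` argument presumably sets the number of reverse (denoising) steps.

As a rough illustration of what such a sampler does, here is a minimal ancestral sampler using the discrete-time VDM update of Kingma et al. (2021). It steps from t = 1 down to t = 0 and draws z_s from q(z_s | z_t, x̂). It is not the repository's `VLBDiffWaveApp` implementation and makes these assumptions:

* `gamma` (a linear schedule), `eps_model` and `sample` are hypothetical.
* Mel conditioning is omitted.
* The final latent z_0 is returned directly as audio.

```python
import jax
import jax.numpy as jnp


def gamma(t):
    """Hypothetical linear log noise-to-signal schedule, gamma(0) = -5, gamma(1) = 5."""
    return -5.0 + 10.0 * t


def sample(eps_model, shape, timesteps, key):
    """Ancestral sampling from t = 1 to t = 0 in `timesteps` equal steps."""
    key, init_key = jax.random.split(key)
    z = jax.random.normal(init_key, shape)  # z_1 is approximately N(0, I)
    for i in range(timesteps, 0, -1):
        t, s = i / timesteps, (i - 1) / timesteps
        g_t, g_s = gamma(t), gamma(s)
        c = -jnp.expm1(g_s - g_t)  # 1 - SNR(t) / SNR(s), in (0, 1)
        eps_hat = eps_model(z, jnp.full(shape[:1], t))  # hypothetical denoiser call
        # Posterior mean (alpha_s / alpha_t) * (z_t - sigma_t * c * eps_hat).
        mean = jnp.sqrt(jax.nn.sigmoid(-g_s) / jax.nn.sigmoid(-g_t)) * (
            z - jnp.sqrt(jax.nn.sigmoid(g_t)) * c * eps_hat)
        # Posterior variance sigma_s^2 * c.
        key, noise_key = jax.random.split(key)
        z = mean + jnp.sqrt(jax.nn.sigmoid(g_s) * c) * jax.random.normal(noise_key, shape)
    return z
```

For example, `sample(lambda z, t: jnp.zeros_like(z), (2, 16), 50, jax.random.PRNGKey(0))` runs the loop with a dummy denoiser and returns an array of shape (2, 16). More steps bring the discrete sampler closer to the continuous-time reverse process, at proportionally higher cost.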