revsic / jax-variational-diffwave

Jax/Flax implementation of Variational-DiffWave.
MIT License

Evaluation code #2

Status: Open. matrix-alpha opened this issue 2 years ago.

matrix-alpha commented 2 years ago

Hi, thanks for the implementation. Could you provide the evaluation code?

revsic commented 2 years ago

Here is sample inference code. Or did you mean evaluations like KDSD and FDSD?

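import json

import jax

# Config and VLBDiffWaveApp are defined in this repository's modules.
# Load the model configuration.
with open('l1.json') as f:
    config = Config.load(json.load(f))

# Build the model and restore trained weights from a checkpoint.
diffwave = VLBDiffWaveApp(config.model)
diffwave.restore('./l1/l1_99.ckpt')

# mel: [B, T, mel], a batch of mel-spectrograms to vocode.
# Generate audio with 50 denoising steps.
audio, _ = diffwave(mel, timesteps=50, key=jax.random.PRNGKey(0))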
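FDSD and KDSD (Fréchet and kernel DeepSpeech distances) compare statistics of speech-recognizer features extracted from real and generated audio. Below is a minimal sketch of just the two distance computations in JAX. It assumes the features have already been extracted into [N, D] arrays (feature extraction is not shown). `frechet_distance` and `kernel_distance` are hypothetical helper names, not part of this repository. The KID-style cubic polynomial kernel is also an assumption of this sketch.

```python
import jax
import jax.numpy as jnp


def _sqrt_psd(mat):
    # Square root of a symmetric PSD matrix via eigendecomposition;
    # small negative eigenvalues from round-off are clamped to zero.
    w, v = jnp.linalg.eigh(mat)
    return (v * jnp.sqrt(jnp.maximum(w, 0.0))) @ v.T


def frechet_distance(feat_a, feat_b):
    """Frechet distance between Gaussians fitted to two feature sets.

    feat_a: [N, D], feat_b: [M, D] (hypothetical pre-extracted features).
    """
    mu_a, mu_b = feat_a.mean(axis=0), feat_b.mean(axis=0)
    cov_a = jnp.cov(feat_a, rowvar=False)
    cov_b = jnp.cov(feat_b, rowvar=False)
    # cov_a @ cov_b is similar to S @ cov_b @ S with S = sqrt(cov_a), which is
    # symmetric PSD, so Tr(sqrt(cov_a @ cov_b)) = sum of sqrt of its eigenvalues.
    s = _sqrt_psd(cov_a)
    eig = jnp.linalg.eigvalsh(s @ cov_b @ s)
    tr_sqrt = jnp.sum(jnp.sqrt(jnp.maximum(eig, 0.0)))
    diff = mu_a - mu_b
    return diff @ diff + jnp.trace(cov_a) + jnp.trace(cov_b) - 2.0 * tr_sqrt


def kernel_distance(feat_a, feat_b):
    """Unbiased squared MMD with a cubic polynomial kernel (KID-style)."""
    d = feat_a.shape[1]

    def kernel(x, y):
        # k(x, y) = (x . y / d + 1) ** 3, evaluated for all pairs.
        return (x @ y.T / d + 1.0) ** 3

    n, m = feat_a.shape[0], feat_b.shape[0]
    k_aa = kernel(feat_a, feat_a)
    k_bb = kernel(feat_b, feat_b)
    k_ab = kernel(feat_a, feat_b)
    # Drop the diagonal (self-similarity) terms for the unbiased estimate.
    term_aa = (k_aa.sum() - jnp.trace(k_aa)) / (n * (n - 1))
    term_bb = (k_bb.sum() - jnp.trace(k_bb)) / (m * (m - 1))
    return term_aa + term_bb - 2.0 * k_ab.mean()


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
real = jax.random.normal(key_a, (500, 4))
fake = jax.random.normal(key_b, (500, 4))
print(float(frechet_distance(real, fake)), float(kernel_distance(real, fake)))
```

Both scores should be close to zero when the two feature sets come from the same distribution, and they grow as the generated features drift away from the real ones. The eigendecomposition route avoids a general matrix square root, since both matrices involved are symmetric.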