chrisdonahue / wavegan

WaveGAN: Learn to synthesize raw audio with generative adversarial networks
MIT License
1.33k stars 282 forks source link

Can anyone help me with the generation of files #86

Closed vishal1o8 closed 4 years ago

vishal1o8 commented 4 years ago

Sorry, I'm new to this. I'm having a hard time generating new clips. I have successfully trained the WaveGAN model, ran a preview, and I'm satisfied with what I'm hearing. Which script should I run to get, say, 1000 clips from my trained WaveGAN model?

alexanderbarnhill commented 4 years ago

Hey @vishal1o8

This is the script that I've been using to generate samples. All you really need is the generate() function. But if you have any questions, let me know.


from scipy.io.wavfile import write
import os
import tensorflow.compat.v1 as tf
import numpy as np
import json
from shutil import copy

# The script below uses TF1-style graph/session APIs (import_meta_graph,
# Session.run with feed dicts), which need graph mode; under TensorFlow 2
# eager execution must be switched off first.
tf.disable_eager_execution()

def generate():
    """Restore the trained WaveGAN generator and write GENERATE_COUNT WAV clips.

    Relies on module-level globals set in the ``__main__`` block:
    ``OUTPUT``, ``CHECKPOINT``, ``CHECKPOINT_DIR``, ``GENERATE_COUNT``,
    ``SAMPLE_RATE``, ``p`` (the parsed parameters dict), ``infer_file``
    (path to the inference meta-graph), ``checkpoint`` (checkpoint prefix
    passed to ``Saver.restore``) and ``full_checkpoint`` (path to the
    checkpoint data file).

    Side effects:
      * creates ``OUTPUT`` if it does not exist;
      * copies the checkpoint data/index/meta files into ``OUTPUT`` and
        dumps the parameters used to ``OUTPUT/parameters.json``, so the
        generated clips can be traced back to the exact model;
      * writes ``sample-<i>.wav`` for every i in ``range(GENERATE_COUNT)``.
    """
    output = OUTPUT
    os.makedirs(output, exist_ok=True)

    # Keep a copy of the exact checkpoint next to the generated audio.
    copy(full_checkpoint, output)
    full_index = os.path.join(CHECKPOINT_DIR, 'model.ckpt-{}.index'.format(CHECKPOINT))
    full_meta = os.path.join(CHECKPOINT_DIR, 'model.ckpt-{}.meta'.format(CHECKPOINT))
    copy(full_index, output)
    # full_meta is already formatted; formatting it a second time would fail
    # on directory names that contain braces.
    copy(full_meta, output)

    parameters_output = os.path.join(output, 'parameters.json')
    with open(parameters_output, 'w') as outfile:
        json.dump(p, outfile)

    # Rebuild the inference graph from the meta-graph file.
    tf.reset_default_graph()
    saver = tf.train.import_meta_graph(infer_file)
    graph = tf.get_default_graph()
    z = graph.get_tensor_by_name('z:0')
    G_z = graph.get_tensor_by_name('G_z:0')

    # Size the latent vectors from the graph's z placeholder rather than
    # hard-coding it; fall back to the 100 the script previously assumed
    # when the graph does not record a static size.
    latent_dim = 100
    if z.shape.ndims:
        known_dim = z.shape.as_list()[-1]
        if known_dim is not None:
            latent_dim = known_dim

    # Random latent vectors drawn uniformly from [-1, 1).
    _z = (np.random.rand(GENERATE_COUNT, latent_dim) * 2.) - 1

    # Synthesize G(z); the context manager guarantees the session is closed.
    with tf.Session(graph=graph) as session:
        saver.restore(session, checkpoint)
        _G_z = session.run(G_z, {z: _z})

    # One WAV file per generated clip.
    for i, clip in enumerate(_G_z):
        sample_target = os.path.join(output, 'sample-{}.wav'.format(i))
        write(sample_target, SAMPLE_RATE, clip)

if __name__ == "__main__":
    # Run settings are read from parameters.json in the current working
    # directory. The names bound here are module globals read by generate().
    with open('parameters.json') as parameters:
        p = json.load(parameters)

    print(p)
    TARGET_INFER_DIR = p['TARGET_INFER_DIR']
    INFER_TITLE = p['INFER_TITLE']
    CHECKPOINT = p['CHECKPOINT']
    CHECKPOINT_DIR = p['CHECKPOINT_DIR']
    GENERATE_COUNT = p['GENERATE_COUNT']
    SAMPLE_RATE = p['SAMPLE_RATE']
    OUTPUT = p['OUTPUT']

    infer_file = os.path.join(TARGET_INFER_DIR, INFER_TITLE)
    checkpoint_name = "model.ckpt-{}".format(CHECKPOINT)
    full_checkpoint_name = "model.ckpt-{}.data-00000-of-00001".format(CHECKPOINT)
    checkpoint = os.path.join(CHECKPOINT_DIR, checkpoint_name)
    full_checkpoint = os.path.join(CHECKPOINT_DIR, full_checkpoint_name)

    # Validate with explicit checks: `assert` statements are stripped when
    # Python runs with -O, which would silently skip this validation.
    if not os.path.isfile(infer_file):
        raise SystemExit("Inference meta-graph {} doesn't exist".format(infer_file))
    if not os.path.isdir(CHECKPOINT_DIR):
        raise SystemExit("Checkpoint directory {} doesn't exist".format(CHECKPOINT_DIR))
    # generate() copies the data, index and meta files, so all three must exist.
    for suffix in (".data-00000-of-00001", ".index", ".meta"):
        if not os.path.isfile(checkpoint + suffix):
            print("Checkpoint {} doesn't exist".format(CHECKPOINT))
            raise SystemExit(1)

    generate()
vishal1o8 commented 4 years ago

Hey @alexanderbarnhill,

I was able to figure it out. Thanks for the help!