Closed vishal1o8 closed 4 years ago
Hey @vishal1o8
This is the script that I've been using to generate samples. All you really need is the generate()
function. If you have any questions, let me know.
from scipy.io.wavfile import write
import os
import tensorflow.compat.v1 as tf
import numpy as np
import json
from shutil import copy
# This script drives TensorFlow through the TF1-style graph/session API
# (tensorflow.compat.v1), so eager execution is switched off up front.
tf.disable_eager_execution()
def generate():
    """Archive the chosen checkpoint and synthesize WAV samples from it.

    The function takes no arguments; it reads its configuration from
    module-level names assigned by the ``__main__`` block: ``OUTPUT``,
    ``CHECKPOINT``, ``CHECKPOINT_DIR``, ``full_checkpoint``, ``infer_file``,
    ``checkpoint``, ``GENERATE_COUNT``, ``SAMPLE_RATE`` and the parsed
    parameters dict ``p``.

    Steps:
      1. Create the output directory if needed and copy the checkpoint's
         ``.data``, ``.index`` and ``.meta`` files plus a ``parameters.json``
         dump of ``p`` into it.
      2. Import the inference meta graph and restore the checkpoint weights.
      3. Feed ``GENERATE_COUNT`` random latent vectors to tensor ``z:0`` and
         fetch tensor ``G_z:0``.
      4. Write each row of the result to ``sample-<i>.wav`` in the output
         directory.
    """
    # Width of each latent vector fed to 'z:0'.
    # NOTE(review): 100 is hard-coded here; confirm it matches the trained model.
    latent_dim = 100

    # Determine output target
    output = OUTPUT
    os.makedirs(output, exist_ok=True)

    # Keep a copy of the exact checkpoint files alongside the generated audio.
    full_index = os.path.join(CHECKPOINT_DIR, 'model.ckpt-{}.index'.format(CHECKPOINT))
    full_meta = os.path.join(CHECKPOINT_DIR, 'model.ckpt-{}.meta'.format(CHECKPOINT))
    for checkpoint_file in (full_checkpoint, full_index, full_meta):
        copy(checkpoint_file, output)

    # Record the parameters this run was started with.
    parameters_output = os.path.join(output, 'parameters.json')
    with open(parameters_output, 'w') as outfile:
        json.dump(p, outfile)

    # Load checkpoint into session
    tf.reset_default_graph()
    saver = tf.train.import_meta_graph(infer_file)
    graph = tf.get_default_graph()

    # Create random latent vectors, uniform in [-1, 1)
    _z = (np.random.rand(GENERATE_COUNT, latent_dim) * 2.) - 1

    # The context manager closes the session (and releases its resources)
    # even if restoring or running the graph fails.
    with tf.Session() as session:
        saver.restore(session, checkpoint)

        # Synthesize G(z)
        z = graph.get_tensor_by_name('z:0')
        G_z = graph.get_tensor_by_name('G_z:0')
        _G_z = session.run(G_z, {z: _z})

    # Generate WAV files, one per generated row
    for i in range(GENERATE_COUNT):
        filename = "sample-{}.wav".format(i)
        sample_target = os.path.join(output, filename)
        write(sample_target, SAMPLE_RATE, _G_z[i])
if __name__ == "__main__":
    # Script entry point: load settings from parameters.json in the current
    # working directory, derive the paths generate() reads as module-level
    # names, validate them, then run generation.
    with open('parameters.json') as parameters:
        p = json.load(parameters)
    print(p)

    TARGET_INFER_DIR = p['TARGET_INFER_DIR']
    INFER_TITLE = p['INFER_TITLE']
    CHECKPOINT = p['CHECKPOINT']
    CHECKPOINT_DIR = p['CHECKPOINT_DIR']
    GENERATE_COUNT = p['GENERATE_COUNT']
    SAMPLE_RATE = p['SAMPLE_RATE']
    OUTPUT = p['OUTPUT']

    # Meta graph to import, and the checkpoint prefix / data file to restore.
    infer_file = os.path.join(TARGET_INFER_DIR, INFER_TITLE)
    checkpoint_name = "model.ckpt-{}".format(CHECKPOINT)
    full_checkpoint_name = "model.ckpt-{}.data-00000-of-00001".format(CHECKPOINT)
    checkpoint = os.path.join(CHECKPOINT_DIR, checkpoint_name)
    full_checkpoint = os.path.join(CHECKPOINT_DIR, full_checkpoint_name)
    saved_checkpoint = os.path.join(OUTPUT, full_checkpoint_name)

    # Validate with explicit checks rather than assert: asserts are removed
    # when Python runs with -O and give no hint about which path is wrong.
    if not os.path.isfile(infer_file):
        raise SystemExit("Inference graph file {} doesn't exist".format(infer_file))
    if not os.path.isdir(CHECKPOINT_DIR):
        raise SystemExit("Checkpoint directory {} doesn't exist".format(CHECKPOINT_DIR))
    if not os.path.isfile(full_checkpoint):
        print("Checkpoint {} doesn't exist".format(CHECKPOINT))
        raise SystemExit(1)

    generate()
Hey @alexanderbarnhill,
I was able to figure it out. Thanks for the help!
Sorry, I'm new to this and am having a hard time generating new clips. I have successfully trained the WaveGAN model and run a preview, and I am satisfied with what I am hearing. Which script should I run to get, say, 1000 clips from my trained WaveGAN model?