lonce opened this issue 6 years ago
We don't have a specific script for it, but it wouldn't be hard to write: create a SampleRNN instance with the same parameters as those used for training, call load_state_dict with the checkpoint you want to use, create a Generator passing the model in the constructor, and then call it to generate samples.
Right. Sometimes the latest checkpoint or checkpoint with lowest training loss doesn't actually generate good audio (unconditionally). It might get "trapped" in an attractor state. As part of the Dadabots process we listen to the short sequences to find a good checkpoint before generating longer sequences.
Hi,
I'm currently trying to write a script to generate audio from a saved checkpoint. Could I possibly get some more specific instructions on how to go about this?
I should clarify, I currently have a script, but I'm getting an error when loading the checkpoint. It says "Missing key(s) in state_dict: " then lists a bunch of keys, and then says "Unexpected key(s) in state_dict: " then also lists a bunch of keys.
Here is a pastebin link to the exact output (https://pastebin.com/cFNrnr7e)
EDIT: I've fixed the error. The checkpoint's state_dict keys all have "model." prefixed to them, so that prefix had to be trimmed out:
```python
import torch
from collections import OrderedDict

pretrained_state = torch.load(specific_checkpoint_path)
new_pretrained_state = OrderedDict()
for k, v in pretrained_state.items():
    layer_name = k.replace("model.", "")
    new_pretrained_state[layer_name] = v
```
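One subtlety worth noting: `str.replace` removes every occurrence of `"model."`, not just the leading prefix, so a key that legitimately contains `"model."` somewhere in the middle would be mangled. A minimal, dependency-free sketch of a prefix-only strip (the checkpoint keys below are made up for illustration):

```python
from collections import OrderedDict

def strip_prefix(state_dict, prefix="model."):
    """Remove a leading prefix from every key, leaving other occurrences intact."""
    stripped = OrderedDict()
    for k, v in state_dict.items():
        new_key = k[len(prefix):] if k.startswith(prefix) else k
        stripped[new_key] = v
    return stripped

# Hypothetical checkpoint keys, for illustration only.
checkpoint = OrderedDict([
    ("model.rnn.weight", 1),
    ("model.submodel.bias", 2),
])
print(list(strip_prefix(checkpoint).keys()))  # ['rnn.weight', 'submodel.bias']
```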
New issue: the .wav file that I'm generating and saving can't be played back by anything.
EDIT: Disregard, I solved my issue.
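The fix isn't described, but a common cause of unplayable generated .wav files is writing raw float samples without converting them to the PCM format the header declares. A minimal sketch of a correct write using only the standard library (the filename and sample data here are placeholders, not from the original script):

```python
import wave
import array

def write_wav_int16(path, samples, sample_rate=16000):
    """Write float samples in [-1.0, 1.0] as 16-bit mono PCM."""
    # Clamp each sample and scale to the int16 range.
    pcm = array.array("h", (int(max(-1.0, min(1.0, s)) * 32767) for s in samples))
    with wave.open(path, "wb") as wf:
        wf.setnchannels(1)            # mono
        wf.setsampwidth(2)            # 2 bytes = 16-bit
        wf.setframerate(sample_rate)
        wf.writeframes(pcm.tobytes())

# Placeholder: half a second of silence.
write_wav_int16("generated.wav", [0.0] * 8000)
```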
My solution for this is described in the following steps (and implemented in this fork):
```python
def main(exp, frame_sizes, dataset, **params):
    params = dict(
        default_params,
        exp=exp, frame_sizes=frame_sizes, dataset=dataset,
        **params
    )
    # Save the training parameters so generation can reuse them later
    import json
    with open(os.path.join(results_path, 'sample_rnn_params.json'), 'w') as fp:
        json.dump(params, fp, sort_keys=True, indent=4)
    ...
```
```python
class GeneratorPlugin(Plugin):
    ...
    def register(self, trainer):
        self.generate = Generator(trainer.model.model, trainer.cuda)

    # New method: register a trained model directly, without a trainer
    def register_generate(self, model, cuda):
        self.generate = Generator(model, cuda)
    ...
```
```python
from model import SampleRNN
import torch
from collections import OrderedDict
import os
import json
from trainer.plugins import GeneratorPlugin

# Paths
RESULTS_PATH = 'results/exp:TEST-frame_sizes:16,4-n_rnn:2-piano/'
PRETRAINED_PATH = RESULTS_PATH + 'checkpoints/best-ep65-it79430'
GENERATED_PATH = RESULTS_PATH + 'generated/'
if not os.path.exists(GENERATED_PATH):
    os.mkdir(GENERATED_PATH)

# Load model parameters from .json for audio generation
params_path = RESULTS_PATH + 'sample_rnn_params.json'
with open(params_path, 'r') as fp:
    params = json.load(fp)

# Create model with the same parameters as used in training
model = SampleRNN(
    frame_sizes=params['frame_sizes'],
    n_rnn=params['n_rnn'],
    dim=params['dim'],
    learn_h0=params['learn_h0'],
    q_levels=params['q_levels'],
    weight_norm=params['weight_norm']
)

# Strip "model." from key names, since the checkpoint was saved with that prefix
pretrained_state = torch.load(PRETRAINED_PATH)
new_pretrained_state = OrderedDict()
for k, v in pretrained_state.items():
    layer_name = k.replace("model.", "")
    new_pretrained_state[layer_name] = v
    # print("k: {}, layer_name: {}, v: {}".format(k, layer_name, np.shape(v)))

# Load pretrained model
model.load_state_dict(new_pretrained_state)

# Generator plugin
generator = GeneratorPlugin(GENERATED_PATH, params['n_samples'], params['sample_length'], params['sample_rate'])

# Call the new register function to accept the trained model and the cuda setting
generator.register_generate(model.cuda(), params['cuda'])

# Generate new audio
generator.epoch('Test')
```
P.S.: Thank you @kurah for the "Unexpected key(s)" error solution presented above.
@gcunhase just curious if you or anyone here has a method to supply your own "seed input" for generation. As in, I want to supply some new novel input and see what it generates from that.
Thanks for this code contribution!
Is there a way to just generate samples based on a given checkpoint without training? The Generator is buried in the trainer code and teasing it out looks daunting.
Best,