NVIDIA / flowtron

Flowtron is an auto-regressive flow-based generative network for text to speech synthesis with control over speech variation and style transfer
https://nv-adlr.github.io/Flowtron
Apache License 2.0

Problems with flowtron_libritts2p3k #126

Open AI-Guru opened 3 years ago

AI-Guru commented 3 years ago

Hello!

I am very happy with Flowtron! I tested the flowtron_ljs model and it worked like a charm!

But I run into problems with the other two models, for example flowtron_libritts2p3k.

python inference.py -c config.json -f models/flowtron_libritts2p3k.pt -w models/waveglow_256channels_universal_v5.pt -t "I’ve seen things you people wouldn’t believe." -i 0

Gives me this output:

Traceback (most recent call last):
  File "inference.py", line 130, in <module>
    infer(args.flowtron_path, args.waveglow_path, args.output_dir, args.text,
  File "inference.py", line 53, in infer
    state_dict = torch.load(flowtron_path, map_location='cpu')['state_dict']
KeyError: 'state_dict'

Any ideas?

rafaelvalle commented 3 years ago

Take a look at the keys in flowtron_libritts2p3k. I think we saved the whole model instead of just the state dict. The code below should solve your issue.

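import torch

ck = torch.load(flowtron_path, map_location='cpu')
if 'model' in ck:
    # the checkpoint stores the whole model object; ask it for its weights
    state_dict = ck['model'].state_dict()
else:
    # the checkpoint stores only the weights
    state_dict = ck['state_dict']
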
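To "take a look at the keys" without guessing, a small inspection step helps. The sketch below is a hypothetical helper, `summarize_checkpoint`, not part of Flowtron; it assumes you have already called `torch.load` and only reports the top-level keys and the type of each value, which is enough to tell whether a file holds a `'model'` object or a `'state_dict'`.

```python
def summarize_checkpoint(checkpoint):
    """Return (key, type name) pairs for the top level of a loaded checkpoint.

    Hypothetical helper: `checkpoint` is whatever torch.load(...) returned.
    Only the top-level keys are inspected, so any dict-like object works.
    """
    if not isinstance(checkpoint, dict):
        # The file stores a bare object (no dict wrapper), so there are no keys
        return [("<no keys>", type(checkpoint).__name__)]
    return [(key, type(value).__name__) for key, value in checkpoint.items()]


# Usage (requires PyTorch and the checkpoint file):
#   import torch
#   ck = torch.load("models/flowtron_libritts2p3k.pt", map_location="cpu")
#   for key, kind in summarize_checkpoint(ck):
#       print(key, kind)
```

If the output lists `model` with a module class name, use the `'model'` branch above; if it lists `state_dict`, the original `inference.py` line already works.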
Syed044 commented 3 years ago

KeyError                                  Traceback (most recent call last)
in
      1 model_path = "models/model_10100"
----> 2 state_dict = torch.load(model_path, map_location='cpu')['state_dict']
      3 model = Flowtron(**model_config)
      4 model.load_state_dict(state_dict)
      5 _ = model.eval().cuda()

KeyError: 'state_dict'

What do I do about this error? Rafael, or anyone who drops in to help: I am a newbie, so please explain it in an easy way.
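In plain terms: `torch.load` returns a dictionary, and `['state_dict']` asks for one entry in it. The `KeyError` means your file has no entry with that name, usually because the weights were saved under a different key such as `'model'`. The sketch below is a hypothetical helper, `extract_state_dict`, not part of Flowtron; it assumes the two layouts mentioned in this thread and, if neither is present, fails with a message that lists the keys the file actually contains.

```python
def extract_state_dict(checkpoint):
    """Return a state dict from a loaded Flowtron-style checkpoint.

    Hypothetical helper: handles the two layouts mentioned in this thread,
    a pickled model under 'model' or a plain state dict under 'state_dict'.
    """
    if 'model' in checkpoint:
        # The whole model object was saved; ask it for its weights
        return checkpoint['model'].state_dict()
    if 'state_dict' in checkpoint:
        # Only the weights were saved
        return checkpoint['state_dict']
    # Neither key: fail with a message showing what the file does contain
    raise KeyError(
        "checkpoint has neither 'model' nor 'state_dict'; found keys: {}".format(
            sorted(checkpoint.keys())
        )
    )


# Usage in the notebook above (requires PyTorch, Flowtron and model_config):
#   ck = torch.load(model_path, map_location='cpu')
#   model = Flowtron(**model_config)
#   model.load_state_dict(extract_state_dict(ck))
#   _ = model.eval().cuda()
```

One caveat, stated as a hedge: recent PyTorch releases (2.6 and later) default `torch.load` to `weights_only=True`, which may refuse to unpickle a full model object stored under `'model'`; for a checkpoint you trust, passing `weights_only=False` to `torch.load` is the usual workaround.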
AI-Guru commented 3 years ago
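import torch
from flowtron import Flowtron  # model class from the Flowtron repo

# load flowtron
# (model_config and flowtron_path come from inference.py: config.json and the -f argument)
model = Flowtron(**model_config).cuda()
checkpoint = torch.load(flowtron_path, map_location='cpu')
if 'model' in checkpoint:
    # the checkpoint holds the full model object
    state_dict = checkpoint['model'].state_dict()
else:
    # the checkpoint holds only the weights
    state_dict = checkpoint['state_dict']
model.load_state_dict(state_dict)
model.eval()
print("Loaded checkpoint '{}'".format(flowtron_path))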

This should help!