kan-bayashi / ParallelWaveGAN

Unofficial Parallel WaveGAN (+ MelGAN & Multi-band MelGAN & HiFi-GAN & StyleMelGAN) with Pytorch
https://kan-bayashi.github.io/ParallelWaveGAN/
MIT License
1.54k stars, 339 forks

How to convert a model to TorchScript? #382

Open. Opened by zhuziying 1 year ago.

zhuziying commented 1 year ago

```python
import sys
sys.path.insert(1,'/root/Downloads/ParallelWaveGAN-0.5.3/parallel_wavegan/utils')
import torch
import utils
module = utils.load_model('pretrained_model/checkpoint-400000steps.pkl')
print(module)
model = torch.load('pretrained_model/checkpoint-400000steps.pkl',map_location=torch.device('cpu'))
print('load model successful!')
x = torch.zeros(5, 10, 5, dtype=torch.float64)
x = x + (0.1*0.5)*torch.randn(5, 10, 5)
c = torch.rand(80,80,5)
print(x)
print('-------------------')
print(c)
print('-------------------')
print(x.size(-1))
print('-------------------')
print(c.size(-1))
trace_model = torch.jit.trace(module,(x,c))
```

The error is:

```
Traceback (most recent call last):
  File "demo.py", line 19, in <module>
    trace_model = torch.jit.trace(module,(x,c))
  File "/root/anaconda3/envs/pwgan/lib/python3.7/site-packages/torch/jit/_trace.py", line 768, in trace
    _module_class,
  File "/root/anaconda3/envs/pwgan/lib/python3.7/site-packages/torch/jit/_trace.py", line 983, in trace_module
    argument_names,
  File "/root/anaconda3/envs/pwgan/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/anaconda3/envs/pwgan/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1178, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/root/Downloads/ParallelWaveGAN-0.5.3/parallel_wavegan/models/parallel_wavegan.py", line 159, in forward
    assert c.size(-1) == x.size(-1)
AssertionError
```
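For readers hitting the same error: the assertion compares lengths after `c` has already gone through the upsampling network, not the raw `c.size(-1)` printed in the script. The sketch below lists every shape problem in the inputs above. It assumes the generator's default settings (`upsample_scales` of `[4, 4, 4, 4]`, `aux_context_window=2`, `aux_channels=80`, `in_channels=1`; the real values are in the `config.yml` next to the checkpoint) and assumes the default `ConvInUpsampleNetwork` trims `aux_context_window` frames from each side before stretching. `diagnose_trace_inputs` is a hypothetical helper, not part of the library.

```python
from math import prod
from typing import List, Sequence, Tuple

# Assumed ParallelWaveGANGenerator defaults; the real values live in the
# generator_params section of the config.yml next to the checkpoint.
UPSAMPLE_SCALES = (4, 4, 4, 4)
AUX_CONTEXT_WINDOW = 2
AUX_CHANNELS = 80
IN_CHANNELS = 1


def diagnose_trace_inputs(
    x_shape: Tuple[int, int, int],
    c_shape: Tuple[int, int, int],
    upsample_scales: Sequence[int] = UPSAMPLE_SCALES,
    aux_context_window: int = AUX_CONTEXT_WINDOW,
    aux_channels: int = AUX_CHANNELS,
    in_channels: int = IN_CHANNELS,
) -> List[str]:
    """List reasons why (x, c) with these (batch, channels, length) shapes
    would not pass through the generator's forward(). Hypothetical helper."""
    problems = []
    if x_shape[0] != c_shape[0]:
        problems.append(f"batch mismatch: x has {x_shape[0]}, c has {c_shape[0]}")
    if x_shape[1] != in_channels:
        problems.append(f"x needs {in_channels} channel(s), got {x_shape[1]}")
    if c_shape[1] != aux_channels:
        problems.append(f"c needs {aux_channels} feature channels, got {c_shape[1]}")
    # Assumption: conv_in drops aux_context_window frames on each side, then
    # the upsampler stretches each remaining frame by prod(upsample_scales).
    upsampled = (c_shape[2] - 2 * aux_context_window) * prod(upsample_scales)
    if upsampled != x_shape[2]:
        problems.append(f"length mismatch: c upsamples to {upsampled}, x has {x_shape[2]}")
    return problems


# The shapes used in the script above: x is (5, 10, 5), c is (80, 80, 5).
for reason in diagnose_trace_inputs((5, 10, 5), (80, 80, 5)):
    print(reason)
```

Under these assumptions it reports three problems: the batch sizes differ (5 vs 80), `x` has 10 channels instead of 1, and `c`'s 5 frames upsample to 256 samples while `x` has only 5. Separately, `x` ends up as float64 while the checkpoint weights are float32, so even with correct shapes the trace would likely fail on a dtype mismatch.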

How should I set the values of the parameters x and c?
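One possible answer, as a sketch under the same assumed defaults (`upsample_scales` of `[4, 4, 4, 4]`, `aux_context_window=2`, 80 mel channels, 1 noise channel; check the checkpoint's `config.yml`): choose a number of mel frames, give `c` that many frames plus `aux_context_window` frames of context on each side, and give `x` one noise sample per output waveform sample, i.e. frames × prod(`upsample_scales`). `trace_input_shapes` is a hypothetical helper for computing those shapes.

```python
from math import prod
from typing import Sequence, Tuple


def trace_input_shapes(
    num_frames: int,
    batch_size: int = 1,
    upsample_scales: Sequence[int] = (4, 4, 4, 4),  # assumed default
    aux_context_window: int = 2,                    # assumed default
    aux_channels: int = 80,                         # assumed: 80-bin mel
    in_channels: int = 1,                           # assumed: 1 noise channel
) -> Tuple[Tuple[int, int, int], Tuple[int, int, int]]:
    """Return (x_shape, c_shape) so that, after c is upsampled inside the
    generator, c.size(-1) == x.size(-1). Hypothetical helper."""
    if num_frames < 1:
        raise ValueError("num_frames must be positive")
    hop_size = prod(upsample_scales)
    # c carries aux_context_window extra frames on each side; the assumed
    # conv_in layer consumes exactly this context before upsampling.
    c_shape = (batch_size, aux_channels, num_frames + 2 * aux_context_window)
    # x holds one noise sample per output waveform sample.
    x_shape = (batch_size, in_channels, num_frames * hop_size)
    return x_shape, c_shape


x_shape, c_shape = trace_input_shapes(num_frames=100)
print(x_shape, c_shape)  # (1, 1, 25600) (1, 80, 104)
```

With those shapes, `x = torch.randn(*x_shape)` and `c = torch.randn(*c_shape)` give float32 tensors that should satisfy the assertion, after which `torch.jit.trace(module.eval(), (x, c))` can be retried. Because tracing only records the operations executed for these example inputs, it is worth checking the traced module on a second input length before relying on it.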