jik876 / hifi-gan

HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis
MIT License

Tracing to torchscript #51

Open ctlaltdefeat opened 3 years ago

ctlaltdefeat commented 3 years ago

Has anyone been able to successfully convert the generator model to torchscript?

I get a bizarre error. Tracing and saving work:

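import json

import torch

# env and models are modules from the hifi-gan repository (imports not shown)

# Dummy input used only for tracing: shape (batch, mel bins, frames)
zero = torch.full((1, 80, 10), -11.52).cuda()
with open("hifi-gan/config.json") as f:
    data = f.read()
h = env.AttrDict(json.loads(data))
vocoder = models.Generator(h).cuda()
vocoder.load_state_dict(
    torch.load("hifi-gan/pretrained_universal/g_02500000")["generator"]
)
vocoder.remove_weight_norm()
vocoder.eval()
# Trace in inference mode and save the traced module to disk
with torch.no_grad():
    traced_vocoder = torch.jit.trace(vocoder, zero)
    torch.jit.save(traced_vocoder, "vocoder.pth")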

Loading the saved model then fails with a strange error:

traced_vocoder = torch.jit.load("vocoder.pth")
/opt/conda/lib/python3.8/site-packages/torch/jit/_serialization.py in load(f, map_location, _extra_files)
    159     cu = torch._C.CompilationUnit()
    160     if isinstance(f, str) or isinstance(f, pathlib.Path):
--> 161         cpp_module = torch._C.import_ir_module(cu, f, map_location, _extra_files)
    162     else:
    163         cpp_module = torch._C.import_ir_module_from_buffer(

RuntimeError: Found character '45' in string, strings must be qualified Python identifiers

ErenBalatkan commented 3 years ago

I was able to convert it to TorchScript via jit.script after some slight modifications. I can share the repo tonight.

epochsimate commented 3 years ago

@ErenBalatkan would appreciate it, thanks!

evrrn commented 3 years ago

@ErenBalatkan have you noticed any speed up for the scripted model? Thanks!

ErenBalatkan commented 3 years ago

You can find my modified version here

I have also included a simple benchmark for comparing scripted version to PyTorch.

> @ErenBalatkan have you noticed any speed up for the scripted model? Thanks!

I observed a speedup of around 10% on my work laptop (CPU) and around 5% on my desktop, on both CPU and GPU.
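
For readers who want to reproduce this kind of comparison, here is a minimal timing sketch. It is not the benchmark from the modified repo, and the names mean_runtime and speedup_percent are made up for illustration. It times any two zero-argument callables, for example one that runs the eager model and one that runs the scripted model on the same input.

```python
import time


def mean_runtime(fn, warmup=3, repeats=20):
    """Return the mean wall-clock time in seconds of calling fn()."""
    # Warm-up calls are excluded so one-time setup costs do not skew the result.
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats


def speedup_percent(baseline_seconds, candidate_seconds):
    """Percentage by which the candidate is faster than the baseline."""
    return 100.0 * (baseline_seconds - candidate_seconds) / baseline_seconds


# Stand-in workloads; replace them with calls to the eager and scripted models.
baseline = mean_runtime(lambda: sum(range(20000)))
candidate = mean_runtime(lambda: sum(range(10000)))
print(f"speedup: {speedup_percent(baseline, candidate):.1f}%")
```

When timing on a GPU, the callable should end with torch.cuda.synchronize(), since CUDA kernels run asynchronously and the timer would otherwise stop before the work has finished.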

ctlaltdefeat commented 3 years ago

Thank you for the modified version that compiles as a scripted module. However, I still get the same error when calling torch.jit.load. It is a bit puzzling that it works for you, given that I haven't done anything special to my installation and other scripted modules load fine.

ErenBalatkan commented 3 years ago

Hmm, it works fine on both my home and work computers. I suggest trying the script with Nvidia's PyTorch Docker container; it may help with your problem.

https://ngc.nvidia.com/catalog/containers/nvidia:pytorch

ctlaltdefeat commented 3 years ago

It's definitely an environment issue and/or a PyTorch bug, as I confirmed that it works on a different setup. In any case, I'll leave this issue open so that your modifications can be merged if the authors wish.

Axelwickm commented 3 years ago

I have been working with TorchScript on another project and stumbled across this issue with the exact same error message.

For me, the issue was that I was importing the modules I was tracing from paths containing dashes (character 45). Maybe the dash in hifi-gan is the problem for you too? I don't know why this information is incorporated into the TorchScript binary file, but changing the dashes in the path to underscores fixed the error when loading in C++ for me.
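
As a quick check of this explanation: character code 45 is the ASCII dash, and a dotted name is a qualified Python identifier only if every dot-separated part is a valid identifier. The helper below is hypothetical and is not PyTorch's own check; it only mimics the rule the error message describes.

```python
def is_qualified_identifier(name: str) -> bool:
    """Return True if every dot-separated part of name is a valid Python identifier."""
    return all(part.isidentifier() for part in name.split("."))


# Character 45 from the error message is the dash.
print(chr(45))

# A dashed directory name fails the check; the underscore variant passes.
print(is_qualified_identifier("hifi-gan.models"))
print(is_qualified_identifier("hifi_gan.models"))
```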

ctlaltdefeat commented 3 years ago

> I have been working with TorchScript on another project and stumbled across this issue with the exact same error message.

> For me, the issue was that I was importing the modules I was tracing from paths containing dashes (character 45). Maybe the dash in hifi-gan is the problem for you too? I don't know why this information is incorporated into the TorchScript binary file, but changing the dashes in the path to underscores fixed the error when loading in C++ for me.

Thanks! I was indeed using dashes in the path and loading the modules with importlib. When I instead just added those paths to sys.path, the error went away. It does seem like a weird TorchScript bug.
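
A minimal sketch of that workaround, using a throwaway directory and a made-up module name (demo_models) rather than the real hifi-gan checkout. The dashed directory goes on sys.path and the module is imported under its plain name, so the dash never ends up in the module's name. The helper import_from_dir is hypothetical.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def import_from_dir(directory: str, module_name: str):
    """Import module_name from directory by putting the directory on sys.path."""
    if directory not in sys.path:
        sys.path.insert(0, directory)
    # Equivalent to a plain "import <module_name>" statement.
    return importlib.import_module(module_name)


# Demo: a directory whose name contains a dash, holding one small module.
with tempfile.TemporaryDirectory() as tmp:
    repo_dir = Path(tmp) / "hifi-gan"
    repo_dir.mkdir()
    (repo_dir / "demo_models.py").write_text("class Generator:\n    pass\n")
    demo_models = import_from_dir(str(repo_dir), "demo_models")
    # The class reports a plain module name with no dash in it.
    print(demo_models.Generator.__module__)
```

With the real repository, this would amount to adding the hifi-gan directory to sys.path and then writing a plain import of models and env.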

SarBH commented 2 years ago

Hey there. I have the same error with a different model. Even after removing the uses of importlib from my code, I still get the same behavior as described in #51: the model traces without errors, but loading it hits the error with character 45. I've also tried removing all extra_files from the trace, to no avail.

Does anyone have another idea, or can anyone point me to the best way to debug this, given that the failure is inside a cpp_module?
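
One possible starting point, offered as a heuristic sketch rather than a confirmed recipe: a file written by torch.jit.save is a zip archive, so it can be inspected without loading it. The code below assumes that serialized classes are referenced by names starting with `__torch__` inside the archive's .py and .pkl entries. The function find_suspicious_names and the synthetic demo archive are made up for illustration, and the scan may report false positives on binary data.

```python
import io
import re
import zipfile

# Assumption: serialized classes are referenced by names starting with "__torch__".
# A match runs until whitespace, a delimiter, or a non-printable character.
QUALIFIED_NAME = re.compile(r"__torch__[^\x00-\x20\x7f-\xff'\"()\[\],:=]*")


def is_qualified_identifier(name):
    """Return True if every dot-separated part is a valid Python identifier."""
    return all(part.isidentifier() for part in name.split("."))


def find_suspicious_names(archive):
    """Collect '__torch__...' names in a zip archive that are not dotted identifiers."""
    suspicious = set()
    with zipfile.ZipFile(archive) as zf:
        for entry in zf.namelist():
            if not entry.endswith((".py", ".pkl")):
                continue  # skip tensor storages and other binary payloads
            text = zf.read(entry).decode("latin-1")
            for name in QUALIFIED_NAME.findall(text):
                if not is_qualified_identifier(name):
                    suspicious.add(name)
    return suspicious


# Demo on a synthetic in-memory archive (not a real model file).
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as zf:
    zf.writestr(
        "vocoder/code/__torch__/demo.py",
        "def forward(self: __torch__.hifi-gan.models.Generator, x: Tensor):\n"
        "  m = __torch__.torch.nn.modules.conv.Conv1d\n",
    )
print(find_suspicious_names(buffer))
```

For a real model, the same function can be called with the path to the saved file, for example find_suspicious_names("vocoder.pth"). Any name it reports points at the module path that needs renaming.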

zhangsanfeng86 commented 2 years ago

> You can find my modified version here

> I have also included a simple benchmark for comparing scripted version to PyTorch.

> > @ErenBalatkan have you noticed any speed up for the scripted model? Thanks!

> I observed a speedup of around 10% on my work laptop (CPU) and around 5% on my desktop, on both CPU and GPU.

@ErenBalatkan Could you upload again?

exercise-book-yq commented 2 years ago

> I was able to convert it to TorchScript via jit.script after some slight modifications. I can share the repo tonight.

Can you share your repo again? Thank you very much!

exercise-book-yq commented 2 years ago

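> You can find my modified version here

> I have also included a simple benchmark for comparing scripted version to PyTorch.

> > @ErenBalatkan have you noticed any speed up for the scripted model? Thanks!

> I observed a speedup of around 10% on my work laptop (CPU) and around 5% on my desktop, on both CPU and GPU.

> @ErenBalatkan Could you upload again?

Can you share your repo again?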

vionwinnie commented 1 year ago

@ErenBalatkan can you share your script again? Would be super helpful for my project. Appreciate it!