p0p4k / vits2_pytorch

unofficial vits2-TTS implementation in pytorch
https://arxiv.org/abs/2307.16430
MIT License

Export to JIT script #46

Open OnceJune opened 10 months ago

OnceJune commented 10 months ago

Hi, I tried to export the model to a TorchScript (JIT) script but got this error: reversed(Tensor 0) -> (Tensor 0): Expected a value of type 'Tensor' for argument '0' but instead found type 'torch.torch.nn.modules.container.ModuleList'. It seems JIT doesn't support reversed(). How can I solve this?

p0p4k commented 10 months ago

Does this help?

OnceJune commented 10 months ago

It doesn't seem to be the same issue. You can reproduce it with the script below:

import torch
import torch.nn as nn


class ReverseModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(5)])

    def forward(self, x):
        # Apply the layers in reverse order, the same pattern as the flow loop in models.py.
        reversed_layers = list(reversed(self.layers))
        for layer in reversed_layers:
            x = layer(x)
        return x


model = ReverseModule()
# Fails with the reversed() error quoted above: TorchScript rejects reversed() on a ModuleList.
scripted_module = torch.jit.script(model)

Here, reversed_layers = list(reversed(self.layers)) follows the same pattern as https://github.com/p0p4k/vits2_pytorch/blob/379527df0b9181ff664b6a25e55158eb3ececfbe/models.py#L122, and torch.jit does not support reversed() there.
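A minimal sketch of one possible workaround, assuming the only goal is to keep reversed() out of forward: build a second ModuleList in reversed order once in __init__ (plain eager Python, where reversed() works on a ModuleList), then iterate it normally in forward, which TorchScript does support. This approach was not proposed in the thread. The 5 x Linear(10, 10) stack is the toy from the repro script, not the VITS2 flows, and both lists hold the same module objects, so parameters are shared rather than copied.

```python
import torch
import torch.nn as nn


class ReverseModule(nn.Module):
    def __init__(self, num_layers: int = 5, dim: int = 10):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
        # Reverse once at construction time, outside TorchScript.
        # The entries are the same Linear objects, so no parameters are duplicated.
        self.reversed_layers = nn.ModuleList(list(reversed(self.layers)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Plain iteration over a ModuleList compiles under torch.jit.script.
        for layer in self.reversed_layers:
            x = layer(x)
        return x


torch.manual_seed(0)
model = ReverseModule()
scripted = torch.jit.script(model)

# Reference result: apply the original layers last-to-first in eager mode.
x = torch.randn(2, 10)
expected = x
for i in range(len(model.layers) - 1, -1, -1):
    expected = model.layers[i](expected)

out = scripted(x)
```

The comparison at the end only checks that the scripted module applies the layers in the intended order; it says nothing about whether the rest of the model scripts cleanly.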

p0p4k commented 10 months ago

Will something like this work for flows?

import torch
import torch.nn as nn


class ReverseModule(nn.Module):
    def __init__(self):
        super().__init__()
        layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(5)])
        # Store the layers in a ModuleDict keyed by their index as a string.
        self.layers = nn.ModuleDict({str(i): layer for i, layer in enumerate(layers)})
        # ModuleDict preserves insertion order, so inserting the keys from last to first
        # stores the same layer objects in reversed order.
        self.reversed_layers = nn.ModuleDict()
        for i in range(len(self.layers) - 1, -1, -1):
            self.reversed_layers[str(i)] = self.layers[str(i)]

    def forward(self, x):
        # Iterate in insertion order, i.e. from the last layer to the first.
        for _, layer in self.reversed_layers.items():
            x = layer(x)
        return x


model = ReverseModule()
scripted_module = torch.jit.script(model)
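For flows, the reversal has to coexist with the forward direction. A sketch under explicit assumptions: ToyAffineFlow and ToyFlowStack are hypothetical stand-ins, not classes from this repo, and the repo's real flow layers presumably take extra arguments (masks, conditioning) that are omitted here. The stack keeps a forward-order ModuleList plus a reverse-order one precomputed in __init__, and a bool flag selects the direction, so forward never calls reversed().

```python
import torch
import torch.nn as nn


class ToyAffineFlow(nn.Module):
    """Hypothetical invertible layer: y = x * exp(log_scale) + shift."""

    def __init__(self, dim: int):
        super().__init__()
        self.log_scale = nn.Parameter(torch.randn(dim) * 0.1)
        self.shift = nn.Parameter(torch.randn(dim))

    def forward(self, x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
        if reverse:
            # Exact inverse of the forward affine map.
            return (x - self.shift) * torch.exp(-self.log_scale)
        return x * torch.exp(self.log_scale) + self.shift


class ToyFlowStack(nn.Module):
    """Hypothetical flow stack with a TorchScript-friendly reverse pass."""

    def __init__(self, dim: int = 10, n_flows: int = 4):
        super().__init__()
        self.flows = nn.ModuleList([ToyAffineFlow(dim) for _ in range(n_flows)])
        # Reverse order computed once in eager mode; same modules, shared parameters.
        self.flows_reversed = nn.ModuleList(list(reversed(self.flows)))

    def forward(self, x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
        if not reverse:
            for flow in self.flows:
                x = flow(x, False)
        else:
            # Inverting a composition applies each inverse in the opposite order.
            for flow in self.flows_reversed:
                x = flow(x, True)
        return x


torch.manual_seed(0)
stack = torch.jit.script(ToyFlowStack())
x = torch.randn(3, 10)
z = stack(x)            # forward direction
x_rec = stack(z, True)  # inverse direction, should recover x
```

Recovering x from z only works if the inverses run last-to-first, which is why the order of the reverse pass matters for flows and not just for this toy.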

EmreOzkose commented 9 months ago

Hi @OnceJune, which versions of PyTorch and of this repo are you using? I am trying to export to JIT and am running into lots of issues, such as typing errors and missing modules.