xDuck opened this issue 3 years ago
@xDuck As far as I know, jit.trace only works for models with fixed-shape inputs, and this model uses inputs of variable size. Does it make sense to use jit.trace? (I also tried jit.script(), but it produces even more errors...)
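For example (a minimal sketch, not code from this repo), tracing bakes the example input's shape-dependent control flow into the graph, while scripting keeps it:

import torch

# Minimal sketch (not from this repo): tracing records one concrete execution,
# so shape-dependent control flow is frozen to the example input's size.
def truncate_to_even(x):
    if x.size(0) % 2 == 1:   # Python-level branch on the input's shape
        x = x[:-1]
    return x

traced = torch.jit.trace(truncate_to_even, torch.randn(5))  # branch taken at trace time is baked in
scripted = torch.jit.script(truncate_to_even)                # control flow is preserved

print(traced(torch.randn(6)).shape)    # torch.Size([5]) -- still drops the last element
print(scripted(torch.randn(6)).shape)  # torch.Size([6]) -- correct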
@xDuck @KinamSalad I think the code should be modified to enable the use of torch.jit. It's in my plans for the next major update.
Thank you! I have done some work on it and got it almost complete. I ended up removing the ability to do batch runs (batch size is now always 1) because I didn't need them. I planned on going back to add it back in when I had time, but got busy.
The one thing I didn’t finish was the Length Regulator module.
@xDuck Yeah, I think the length regulator may be the main obstacle to scripting the whole model. Looking forward to your results!
I was able to find a couple of hours to work on this again. Here is the updated length regulator that compiles with JIT. I now have the whole model running through JIT, but I cheated by removing all of the batch handling in favor of only supporting single mode because I'm lazy, so I won't make a PR on this repo. The rest of the model is pretty straightforward to convert to JIT.
The other catch here is that the model no longer returns mel_len, but that can be derived from the outputs that already exist (see the usage sketch after the code below).
Credit to https://github.com/rishikksh20/FastSpeech2 - I referenced their code pretty heavily in this.
import torch
import torch.nn.functional as F
from typing import List


@torch.jit.script
def pad_2d_tensor(xs: List[torch.Tensor], pad_value: float = 0.0) -> torch.Tensor:
    # Pad each (length, dim) tensor to the longest length in the list, then stack into a batch.
    max_len = max([xs[i].size(0) for i in range(len(xs))])
    out_list = []
    for i, batch in enumerate(xs):
        one_batch_padded = F.pad(
            batch, (0, 0, 0, max_len - batch.size(0)), "constant", pad_value
        )
        out_list.append(one_batch_padded)
    out_padded = torch.stack(out_list)
    return out_padded


@torch.jit.script
def expand(x: torch.Tensor, d: torch.Tensor) -> List[torch.Tensor]:
    # Repeat each frame of one sequence by its duration; returns the repeated segments as a list.
    if d.sum() == 0:
        d = d.fill_(1)
    out: List[torch.Tensor] = []
    for x_, d_ in zip(x, d):
        if d_ != 0:
            out.append(x_.repeat(int(d_), 1))
    return out


@torch.jit.script
def repeat_one_sequence(x: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
    # Same as expand(), but concatenates the repeated frames into a single (sum(d), dim) tensor.
    if d.sum() == 0:
        d = d.fill_(1)
    out: List[torch.Tensor] = []
    for x_, d_ in zip(x, d):
        if d_ != 0:
            out.append(x_.repeat(int(d_), 1))
    return torch.cat(out, dim=0)


@torch.jit.script
def LR(x: torch.Tensor, duration: torch.Tensor) -> torch.Tensor:
    # Length regulator: expand each sequence in the batch by its durations, then pad to a common length.
    expanded = [repeat_one_sequence(x_, d_) for x_, d_ in zip(x, duration)]
    return pad_2d_tensor(expanded, 0.0)
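A quick sanity check using the functions above (the shapes here are placeholders, not the repo's actual config values); it also shows how mel_len can be recovered from the durations instead of being returned:

# Dummy shapes; hidden size 256 is just a placeholder, not the repo's config value.
x = torch.randn(1, 10, 256)               # (batch=1, phonemes, hidden)
duration = torch.randint(1, 5, (1, 10))   # rounded per-phoneme frame counts
mel = LR(x, duration)
print(mel.shape)                          # (1, duration.sum(), 256)
# mel_len no longer needs to be returned by the model -- derive it from the durations:
mel_len = duration.sum(dim=-1)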
@ming024 (Running my model in C++, I'm getting about 15x realtime for the FastSpeech2 portion.)
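For anyone following along, the export side is just the standard TorchScript save (the module and file names below are placeholders, not this repo's actual names):

# Placeholder names; `model` is the modified, JIT-compatible FastSpeech2 module.
scripted = torch.jit.script(model)
scripted.save("fastspeech2_jit.pt")
# On the C++ side, load it with torch::jit::load("fastspeech2_jit.pt") and call forward().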
@xDuck Great job!!!! Thanks for your work! I will try it in a few days!
Thanks a lot for your repo. I am wondering if you will be providing the update?
Hi, thanks for the repo! I am wondering if you have plans to convert the model to be JIT-traceable for exporting to C++? I tried JIT tracing it and it generated some critical warnings.
I made the following changes: tools.py:91 and synthesize:87.
It seems like most of the issues are with max_len being used in conditionals and array slices. I will look into this more, but wanted to see if you had tried this before.
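For reference, the problematic pattern looks something like this (illustrative only, not the exact code at tools.py:91 or synthesize:87):

import torch

# Illustrative only: a Python int derived from a tensor is baked into the trace
# as a constant, so conditionals and slices that use it do not generalize.
def mask_and_trim(x, lengths):
    max_len = int(torch.max(lengths))               # tensor -> Python int, frozen at trace time
    ids = torch.arange(0, max_len, device=x.device)
    mask = ids >= lengths.unsqueeze(1)              # (batch, max_len) padding mask
    return x[:, :max_len], mask                     # slice bound is also frozen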