ming024 / FastSpeech2

An implementation of Microsoft's "FastSpeech 2: Fast and High-Quality End-to-End Text to Speech"
MIT License

Modify model to allow JIT tracing #35

Open · xDuck opened this issue 3 years ago

xDuck commented 3 years ago

Hi, thanks for the repo! I am wondering if you have plans to make the model JIT-traceable for exporting to C++? I tried to JIT-trace it and got some critical warnings:

FastSpeech2/env/lib/python3.7/site-packages/torch/tensor.py:593: RuntimeWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  'incorrect results).', category=RuntimeWarning)
FastSpeech2/utils/tools.py:97: TracerWarning: Converting a tensor to a NumPy array might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  max_len = max_len.detach().cpu().numpy()[0]
FastSpeech2/transformer/Models.py:82: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if not self.training and src_seq.shape[1] > self.max_seq_len:
FastSpeech2/transformer/Models.py:90: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  :, :max_len, :
FastSpeech2/model/modules.py:186: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  expand_size = predicted[i].item()
FastSpeech2/model/modules.py:180: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  return output, torch.LongTensor(mel_len).to(device)
FastSpeech2/utils/tools.py:94: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  max_len = torch.max(lengths).item()
FastSpeech2/transformer/Models.py:145: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if not self.training and enc_seq.shape[1] > self.max_seq_len:
FastSpeech2/transformer/Models.py:154: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  max_len = min(max_len, self.max_seq_len)
FastSpeech2/transformer/Models.py:158: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  dec_output = enc_seq[:, :max_len, :] + self.position_enc[
FastSpeech2/transformer/Models.py:159: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  :, :max_len, :
FastSpeech2/transformer/Models.py:161: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  mask = mask[:, :max_len]
FastSpeech2/transformer/Models.py:162: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  slf_attn_mask = slf_attn_mask[:, :, :max_len]

I made the following changes:

tools.py:91

def get_mask_from_lengths(lengths, max_len=None):
    batch_size = lengths.shape[0]
    if max_len is None:
        max_len = torch.max(lengths).item()
    else:
        # max_len now arrives as a 1-element tensor (see synthesize below),
        # so unwrap it to a Python int before using it in arange/expand
        max_len = max_len.detach().cpu().numpy()[0]
    ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)

    return mask

and

synthesize.py:87

def synthesize(model, step, configs, vocoder, batchs, control_values):
    preprocess_config, model_config, train_config = configs
    pitch_control, energy_control, duration_control = control_values

    for batch in batchs:
        batch = to_device(batch, device)
        with torch.no_grad():
            # trace the model on one example batch; max_len (batch[5]) is
            # wrapped in a tensor so it can be unwrapped inside
            # get_mask_from_lengths above
            traced_script_module = torch.jit.trace(
                model, (batch[2], batch[3], batch[4], torch.tensor([batch[5]]))
            )
            traced_script_module.save("traced_fastspeech_model.pt")
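
For completeness, the traced module can then be reloaded and run like this (a minimal sketch; the names speakers, texts, text_lens, and max_len are my assumptions for whatever batch[2] through torch.tensor([batch[5]]) correspond to above):

# Hedged sketch: reload the traced module and run it on a new utterance
model = torch.jit.load("traced_fastspeech_model.pt")
with torch.no_grad():
    # inputs are assumed to follow the same layout as the traced example
    outputs = model(speakers, texts, text_lens, max_len)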

It seems like most of the issues come from max_len being used in Python conditionals and array slices. I will look into this more, but wanted to see if you had tried this before.
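
For reference, here is a script-friendly variant of get_mask_from_lengths (a minimal sketch, not the repo's code; it assumes max_len is passed as a plain Python int instead of a tensor):

@torch.jit.script
def get_mask_from_lengths(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # scripting (instead of tracing) records the control flow, so max_len
    # is no longer baked in as a constant
    batch_size = lengths.shape[0]
    ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1)
    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
    return mask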

KinamSalad commented 3 years ago

@xDuck As far as I know, jit.trace only works for models with fixed-shape inputs, and this model takes inputs of variable size. Does it make sense to use jit.trace here? (I also tried jit.script(), but it produces even more errors...)
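
A toy example (not from this repo) of why tracing and variable inputs don't mix when there is data-dependent control flow:

import torch

def f(x):
    if x.sum() > 0:  # tensor -> Python bool: TracerWarning, branch gets frozen
        return x * 2
    return x * -1

traced = torch.jit.trace(f, torch.ones(3))
print(traced(-torch.ones(3)))  # tensor([-2., -2., -2.]): still the x * 2 branch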

ming024 commented 3 years ago

@xDuck @KinamSalad I think the code should be modified to enable the use of torch.jit. It's in my plans for the next major update.

xDuck commented 3 years ago

Thank you! I have done some work on it and got it almost complete. I ended up removing the ability to do batch runs (batch size is now always 1) because I didn't need them. I planned to add that back when I had time, but got busy.

The one thing I didn’t finish was the Length Regulator module.


ming024 commented 3 years ago

@xDuck Yeah, I think the length regulator may be the main obstacle to scripting the whole model. Looking forward to your results!

xDuck commented 3 years ago

I was able to find a couple of hours to work on this again. Here is the updated length regulator, which compiles with JIT. I now have the whole model running through JIT, but I cheated by removing all of the batch handling in favor of single-sample mode only (because I'm lazy), so I won't open a PR on this repo. The rest of the model is pretty straightforward to convert to JIT.

The other catch is that the model no longer returns mel_len, but that can be derived from outputs that already exist.
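
For example, something like this recovers it (a sketch; d_rounded is my name for the rounded duration output of the variance adaptor):

# Hedged sketch: mel_len is the total number of expanded frames,
# i.e. the per-utterance sum of the rounded durations
mel_len = d_rounded.sum(dim=1)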

Credit to https://github.com/rishikksh20/FastSpeech2 - I referenced their code pretty heavily in this.

from typing import List

import torch
import torch.nn.functional as F

@torch.jit.script
def pad_2d_tensor(xs: List[torch.Tensor], pad_value: float = 0.0):
    # pad every (time, hidden) tensor to the longest time axis, then stack into a batch
    max_len = max([xs[i].size(0) for i in range(len(xs))])

    out_list = []
    for i, batch in enumerate(xs):
        one_batch_padded = F.pad(
            batch, (0, 0, 0, max_len - batch.size(0)), "constant", pad_value
        )
        out_list.append(one_batch_padded)

    out_padded = torch.stack(out_list)
    return out_padded

@torch.jit.script
def expand(x: torch.Tensor, d: torch.Tensor):
    # list variant of repeat_one_sequence below; returns the repeated rows
    # without concatenating them (not used by LR)
    if d.sum() == 0:
        d = d.fill_(1)
    out = []
    for x_, d_ in zip(x, d):
        if d_ != 0:
            out.append(x_.repeat(int(d_), 1))
    return out

@torch.jit.script
def repeat_one_sequence(x: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
    # repeat each phoneme-level row of x d[i] times along the time axis
    if d.sum() == 0:
        d = d.fill_(1)
    out = []
    for x_, d_ in zip(x, d):
        if d_ != 0:
            out.append(x_.repeat(int(d_), 1))
    return torch.cat(out, dim=0)

@torch.jit.script
def LR(x: torch.Tensor, duration: torch.Tensor):
    # length-regulate each item in the batch, then pad to a common length
    output = [repeat_one_sequence(x_, d_) for x_, d_ in zip(x, duration)]
    return pad_2d_tensor(output, 0.0)
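
A quick smoke test of the scripted regulator (a sketch with made-up shapes, not from the repo):

x = torch.randn(1, 5, 256)                    # (batch=1, phonemes, hidden)
durations = torch.tensor([[2, 1, 3, 1, 2]])   # frames per phoneme
mel = LR(x, durations)
print(mel.shape)                              # torch.Size([1, 9, 256]); 9 = durations.sum()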

@ming024 Running my model in C++, I am getting about 15x realtime for the FastSpeech2 portion.

ming024 commented 3 years ago

@xDuck Great job!!!! Thanks for your work! I will try it in a few days!

YoLi-sw commented 1 year ago

> @xDuck Great job!!!! Thanks for your work! I will try it in a few days!

Thanks a lot for your repo. I am wondering if you could provide the update?