ming024 / FastSpeech2

An implementation of Microsoft's "FastSpeech 2: Fast and High-Quality End-to-End Text to Speech"
MIT License

How to export FastSpeech2 to ONNX? #98

Open · youngstu opened 3 years ago

youngstu commented 3 years ago

How do I export FastSpeech2 to ONNX? I want to export the model to ONNX format and then convert it to TFLite for deployment.

jerryuhoo commented 3 years ago

I solved it. The problem causing my error was that ONNX export doesn't support torch.bucketize(). So I rewrote the bucketize function following https://github.com/pytorch/pytorch/issues/7284. Add this method in model/modules.py:

def bucketize(self, tensor, bucket_boundaries):
    # ONNX-friendly replacement for torch.bucketize: an element's bucket
    # index equals the number of boundaries it exceeds
    result = torch.zeros_like(tensor, dtype=torch.int32)
    for boundary in bucket_boundaries:
        result += (tensor > boundary).int()
    return result.long()

Replace every torch.bucketize call with self.bucketize.
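
For example (assuming this repo's VarianceAdaptor, where the pitch boundaries live in self.pitch_bins), the call site in get_pitch_embedding changes from

embedding = self.pitch_embedding(torch.bucketize(prediction, self.pitch_bins))

to

embedding = self.pitch_embedding(self.bucketize(prediction, self.pitch_bins))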

For the inputs, my code is:

input_names = ['speakers', 'texts','src_lens', 'max_src_len']
output_names = ['output', 'postnet_output', 'p_predictions', 'e_predictions', 'log_d_predictions', 'd_rounded', 'src_masks', 'mel_masks', 'src_lens', 'mel_lens']
dynamic_axes = {
    "texts": {1: "texts_len"}, 
    "output": {1: "output_len"}, 
    "postnet_output": {1: "postnet_output_len"}, 
    "p_predictions": {1: "p_predictions_len"}, 
    "e_predictions": {1: "e_predictions_len"}, 
    "log_d_predictions": {1: "log_d_predictions_len"}, 
    "d_rounded": {1: "d_rounded_len"}, 
    "src_masks": {1: "src_masks_len"}
}

# batch indices 2-5 here are speakers, texts, src_lens, max_src_len
dummy_input_1 = batch[2]
dummy_input_2 = batch[3]
dummy_input_3 = batch[4]
dummy_input_4 = batch[5]
# max_src_len is a plain int, so wrap it as a tensor first
dummy_input_4 = torch.from_numpy(np.array(dummy_input_4)).to(device)
torch.onnx.export(model, args=(dummy_input_1, dummy_input_2, dummy_input_3, dummy_input_4), f="FastSpeech.onnx", input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, opset_version=11)
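
To sanity-check the exported model you can run it with onnxruntime; a minimal sketch (phoneme_ids is a placeholder for your encoded input of shape (1, src_len), dtype int64; the key names match input_names above):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("FastSpeech.onnx")
outputs = sess.run(
    None,  # fetch all declared outputs
    {
        "speakers": np.array([0], dtype=np.int64),
        "texts": phoneme_ids,  # placeholder: your encoded text, (1, src_len) int64
        "src_lens": np.array([phoneme_ids.shape[1]], dtype=np.int64),
        "max_src_len": np.array(phoneme_ids.shape[1], dtype=np.int64),
    },
)
postnet_output = outputs[1]  # second name in output_names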
Pydataman commented 2 years ago

mark

Tian14267 commented 2 years ago

@jerryuhoo Excuse me, how did you solve the dynamic input at inference time? I also gave the dynamic_axes, but I can't get dynamic input at inference: every output is identical to the output from when I converted the ONNX model, so it never produces dynamic output. The details are in this link.

lucasjinreal commented 2 years ago

@jerryuhoo I got the same problem as you. Even though texts and text_lens are exported as dynamic axes, the graph somehow isn't fully traced as dynamic; I can only get it through onnxruntime when the input shape matches the one used at export time.

python -m onnxsim fastspeech2.onnx fastspeech2_sim.onnx --dynamic-input-shape --input-shape tones:1,58,8 texts:1,58 text_lens:1
Simplifying...
Checking 0/3...
Checking 1/3...
Checking 2/3...
Ok!

So I think the workaround here would be to forcibly pad the input to the export size and keep the input shape fixed.

But that way, I don't know how to cut the postnet output according to the real input text length.
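
If I go the fixed-padding route, a rough sketch of what I mean (EXPORT_LEN and PAD_ID are assumptions for illustration; since src_lens still carries the real length, padded positions get masked, so the mel_lens output should give the number of valid frames to keep):

import numpy as np

EXPORT_LEN = 58  # the fixed src length used at export time
PAD_ID = 0       # assumed padding symbol id

def pad_texts(phones):
    # phones: (1, real_len) int64 -> (1, EXPORT_LEN) plus the real length
    padded = np.full((1, EXPORT_LEN), PAD_ID, dtype=np.int64)
    padded[0, : phones.shape[1]] = phones[0]
    return padded, phones.shape[1]

# after session.run(...), trim with the mel_lens output:
# mel = postnet_output[0, : int(mel_lens[0])]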


OnceJune commented 2 years ago

@jinfagang How do you convert torch.linspace in the variance predictor? I got the error message "Exporting the operator linspace to ONNX opset version 11 is not supported". My torch version is 1.7.0.
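
A workaround that should export under opset 11 (a sketch, not code from this repo: it rebuilds the same values from torch.arange, which opset 11 does support; assumes steps >= 2):

import torch

def linspace_onnx(start, end, steps):
    # same values as torch.linspace(start, end, steps) for steps >= 2,
    # built only from ops that export under ONNX opset 11
    step = (end - start) / (steps - 1)
    return start + step * torch.arange(steps, dtype=torch.float32)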

zhanminmin commented 2 years ago

@jinfagang I think the .item() call makes it constant, as this warning suggests:

TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

max_len = torch.max(lengths).item()
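
The usual fix (a minimal sketch of get_mask_from_lengths from this repo's utils/tools.py, rewritten under the assumption that keeping max_len as a tensor lets it stay inside the traced graph):

import torch

def get_mask_from_lengths(lengths, max_len=None):
    if max_len is None:
        max_len = torch.max(lengths)  # no .item(): stays a traced tensor
    ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0)
    return ids >= lengths.unsqueeze(1)  # True at padded positions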