PaddlePaddle / Parakeet

PAddle PARAllel text-to-speech toolKIT (supporting Tacotron2, Transformer TTS, FastSpeech2/FastPitch, SpeedySpeech, WaveFlow and Parallel WaveGAN)
Other

vocoder predict error #103

Closed: milkliker closed this issue 3 years ago

milkliker commented 3 years ago

Following the documentation at https://paddle-parakeet.readthedocs.io/en/latest/basic.html#text-to-spectrogram, I ran:

from pathlib import Path

import soundfile as sf
import yacs.config
from parakeet.models import ConditionalWaveFlow

# load the pretrained model
checkpoint_dir = Path("waveflow_pretrained")
# load_cfg treats a plain str as YAML text, not a path, so pass an open file object
with open(checkpoint_dir / "config.yaml") as f:
    config = yacs.config.CfgNode.load_cfg(f)
checkpoint_path = str(checkpoint_dir / "step-2000000")
vocoder = ConditionalWaveFlow.from_pretrained(config, checkpoint_path)
vocoder.eval()

# synthesize
# mel_output and audio_path come from the text-to-spectrogram step earlier in the docs
audio = vocoder.predict(mel_output)
sf.write(audio_path, audio, config.data.sample_rate)

/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/functional/conv.py in _conv_nd(x, weight, bias, stride, padding, padding_algorithm, dilation, groups, data_format, channel_dim, op_type, use_cudnn, use_mkldnn, name)
    117         "padding_algorithm", padding_algorithm, "data_format",
    118         data_format)
--> 119     pre_bias = getattr(core.ops, op_type)(x, weight, *attrs)
    120     if bias is not None:
    121         out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)

ValueError: (InvalidArgument) The number of input's channels should be equal to filter's channels * groups for Op(Conv). But received: the input's channels is 764, the input's shape is [1, 764, 1, 1263]; the filter's channels is 80, the filter's shape is [256, 80, 1, 1]; the groups is 1, the data_format is NCHW. The error may come from wrong data_format setting.
  [Hint: Expected input_channels == filter_dims[1] * groups, but received input_channels:764 != filter_dims[1] * groups:80.] (at /paddle/paddle/fluid/operators/conv_op.cc:96)
  [operator < conv2d > error]
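The hint in the error states the rule Paddle enforces for conv2d: the input's channel count must equal filter_dims[1] * groups. As a small illustration, that check can be mirrored in plain Python with the shapes copied from the log. The helper conv_channels_match is hypothetical (not a Paddle API); it assumes the channel axis is 1 for NCHW and that the filter is laid out as [out_channels, in_channels / groups, kh, kw], which matches the reported [256, 80, 1, 1] filter.

```python
def conv_channels_match(input_shape, filter_shape, groups=1, data_format="NCHW"):
    """Hypothetical mirror of the hint: input_channels == filter_dims[1] * groups."""
    channel_axis = 1 if data_format == "NCHW" else -1  # NHWC keeps channels last
    return input_shape[channel_axis] == filter_shape[1] * groups


# Shapes copied from the traceback: 764 input channels vs. 80 * 1 expected
print(conv_channels_match([1, 764, 1, 1263], [256, 80, 1, 1]))  # False
```

The filter expects 80 input channels (one per mel bin in this model), so any input whose channel axis is not 80 fails this check.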

milkliker commented 3 years ago

Solved it: mel_output needs to be transposed. I suggest updating the official documentation.

# synthesize
# transpose mel_output so the mel-channel axis comes first, as the vocoder expects
audio = vocoder.predict(mel_output.T)
sf.write(audio_path, audio, config.data.sample_rate)
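The transpose works because, as the traceback suggests, mel_output here is laid out as (frames, n_mels), while a conv layer inside the vocoder expects 80 input channels (filter shape [256, 80, 1, 1]) and received 764. A minimal defensive sketch, assuming mel_output is a 2-D NumPy array with 80 mel bins; the helper to_channels_first and the constant N_MELS are hypothetical, not part of Parakeet:

```python
import numpy as np

# Assumption: 80 mel bins, matching the filter's 80 input channels in the traceback
N_MELS = 80


def to_channels_first(mel: np.ndarray, n_mels: int = N_MELS) -> np.ndarray:
    """Hypothetical helper: return mel as (n_mels, frames), transposing if needed."""
    if mel.ndim != 2:
        raise ValueError(f"expected a 2-D mel spectrogram, got shape {mel.shape}")
    if mel.shape[0] == n_mels:
        return mel  # already (n_mels, frames)
    if mel.shape[1] == n_mels:
        return mel.T  # (frames, n_mels) -> (n_mels, frames)
    raise ValueError(f"neither axis of shape {mel.shape} has {n_mels} mel channels")


# Example: 764 frames of 80-bin mel, laid out frames-first as in the report
mel_output = np.zeros((764, 80), dtype=np.float32)
print(to_channels_first(mel_output).shape)  # (80, 764)
```

With this, the call would become vocoder.predict(to_channels_first(mel_output)). When both axes equal n_mels, the layout cannot be inferred from the shape alone, so the helper returns the array unchanged.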