TensorSpeech / TensorFlowTTS

TensorFlowTTS: Real-Time State-of-the-art Speech Synthesis for TensorFlow 2 (supports English, French, Korean, Chinese, and German, and is easy to adapt to other languages)
https://tensorspeech.github.io/TensorFlowTTS/
Apache License 2.0

Error saving a PB (SavedModel) model with the newest code #443

Closed: ZhaoZeqing closed this issue 3 years ago

ZhaoZeqing commented 3 years ago

I trained a Tacotron2 model with the newest code on TensorFlow 2.3.1 and want to convert it from an .h5 model to a PB (SavedModel) model, but I got this error (TensorFlow 2.2 gives the same error):

```
Traceback (most recent call last):
  File "h5_to_pb.py", line 86, in <module>
    h5_to_pb.h5_to_pb_tacotron2(tacotron2_config, tacotron2_checkpoint, tacotron2_pb_path)
  File "h5_to_pb.py", line 65, in h5_to_pb_tacotron2
    tf.saved_model.save(tacotron2, pb_path, signatures=tacotron2.inference)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py", line 976, in save
    obj, export_dir, signatures, options, meta_graph_def)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py", line 1051, in _build_meta_graph
    signature_serialization.validate_saveable_view(checkpoint_graph_view)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/signature_serialization.py", line 268, in validate_saveable_view
    saveable_view.root):
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py", line 110, in list_dependencies
    extra_dependencies = self.list_extra_dependencies(obj)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py", line 139, in list_extra_dependencies
    self._serialization_cache)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 3015, in _list_extra_dependencies_for_serialization
    .list_extra_dependencies_for_serialization(serialization_cache))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/base_serialization.py", line 74, in list_extra_dependencies_for_serialization
    return self.objects_to_serialize(serialization_cache)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 75, in objects_to_serialize
    serialization_cache).objects_to_serialize)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 95, in _get_serialized_attributes
    serialization_cache)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/model_serialization.py", line 51, in _get_serialized_attributes_internal
    default_signature = save_impl.default_save_signature(self.obj)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 205, in default_save_signature
    fn.get_concrete_function()
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 1167, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 1073, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 697, in _initialize
    *args, **kwds))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 2855, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 3075, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saving_utils.py", line 134, in _wrapped_model
    outputs = model(inputs, training=False)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 985, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py", line 302, in wrapper
    return func(*args, **kwargs)
TypeError: call() missing 4 required positional arguments: 'input_lengths', 'speaker_ids', 'mel_gts', and 'mel_lengths'
```

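The last frames of the traceback show the cause: while building the default save signature, Keras's `_wrapped_model` calls `model(inputs, training=False)` with a single input, but Tacotron2's `call()` also requires `input_lengths`, `speaker_ids`, `mel_gts` and `mel_lengths`. The traceback also shows that passing `signatures=tacotron2.inference` does not skip this step: Keras still builds its default signature while `validate_saveable_view` walks the model's dependencies. The plain-Python sketch below reproduces the mismatch with a hypothetical stand-in class; `FakeTacotron2` and the helper functions are not part of TensorFlowTTS or Keras, and the `call()` signature is only inferred from the error message.

```python
class FakeTacotron2:
    """Hypothetical stand-in; call() mirrors the signature implied by the error."""

    def call(self, input_ids, input_lengths, speaker_ids, mel_gts, mel_lengths,
             training=False):
        # Training-style forward pass: needs ground-truth mels and their lengths.
        return input_ids

    def inference(self, input_ids, input_lengths, speaker_ids):
        # Inference-style entry point: only text inputs and speaker IDs.
        return input_ids


def default_signature_call(model, inputs):
    # Simplified imitation of Keras's _wrapped_model: one input plus
    # training=False (real Keras goes through __call__, which forwards to call()).
    return model.call(inputs, training=False)


def try_default_signature(model, inputs):
    """Return "ok" if the default-signature call succeeds, else the error text."""
    try:
        default_signature_call(model, inputs)
        return "ok"
    except TypeError as err:
        return str(err)


message = try_default_signature(FakeTacotron2(), [[1, 2, 3]])
print(message)
# inference() only asks for inputs a caller actually has at synthesis time.
print(FakeTacotron2().inference([[1, 2, 3]], [3], [0]))
```

The printed message contains the same `missing 4 required positional arguments` text as the real traceback (newer Python versions prefix the method name with the class name).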
Here is my code:

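```python
import tensorflow as tf
from tensorflow_tts.inference import AutoConfig, TFAutoModel


def h5_to_pb_tacotron2(config_path, h5_model, pb_path):
    # Build Tacotron2 from its YAML config and load the trained .h5 weights.
    config = AutoConfig.from_pretrained(config_path)
    tacotron2 = TFAutoModel.from_pretrained(
        config=config,
        pretrained_path=h5_model,
        is_build=True,
    )

    # Run inference() once on dummy character IDs before exporting.
    char_ids = [[1, 2, 3, 4, 5, 6, 7, 8]]
    (
        mel_outputs,
        post_mel_outputs,
        stop_outputs,
        alignment_historys,
    ) = tacotron2.inference(
        input_ids=char_ids,
        input_lengths=[len(char_ids[0])],
        speaker_ids=[0],
    )

    tacotron2.load_weights(h5_model)
    # Export with inference() as the serving signature; on TF 2.2/2.3 this
    # line raises the TypeError shown above.
    tf.saved_model.save(tacotron2, pb_path, signatures=tacotron2.inference)
```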
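In the dummy `inference()` call in the snippet above, `input_lengths` holds the unpadded length of each row of `input_ids`, and `speaker_ids` holds one speaker per row. The pure-Python sketch below builds those three arguments for a batch of several sentences. The helper name `make_inference_inputs` and the padding ID `0` are assumptions for illustration, not TensorFlowTTS API, and it assumes the model accepts a zero-padded batch.

```python
def make_inference_inputs(sequences, speaker_id=0, pad_id=0):
    """Hypothetical helper: pad character-ID sequences for inference().

    pad_id=0 is an assumption; use whatever padding ID your text processor uses.
    """
    max_len = max(len(seq) for seq in sequences)
    # Right-pad every sequence to the longest one in the batch.
    input_ids = [list(seq) + [pad_id] * (max_len - len(seq)) for seq in sequences]
    # Real (unpadded) length of each row, so padding can be ignored.
    input_lengths = [len(seq) for seq in sequences]
    # One speaker ID per row of the batch.
    speaker_ids = [speaker_id] * len(sequences)
    return input_ids, input_lengths, speaker_ids


ids, lengths, speakers = make_inference_inputs([[1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3]])
print(ids, lengths, speakers)
```

The returned lists could then be passed as `tacotron2.inference(input_ids=ids, input_lengths=lengths, speaker_ids=speakers)`, mirroring the single-sentence call in the original snippet.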
dathudeptrai commented 3 years ago

@ZhaoZeqing https://github.com/TensorSpeech/TensorFlowTTS/issues/418#issuecomment-739887308

ZhaoZeqing commented 3 years ago

@dathudeptrai It worked! Thanks!

dathudeptrai commented 3 years ago

@ZhaoZeqing I created a pull request to fix this problem :D #446

ZhaoZeqing commented 3 years ago

@dathudeptrai Many thanks! Do you have any suggestions for adding speaker embeddings (i-vector, d-vector, ...) to Tacotron2 and FastSpeech2?