PaddlePaddle / Paddle

PArallel Distributed Deep LEarning: Machine Learning Framework from Industrial Practice (PaddlePaddle core framework: high-performance single-machine and distributed training and cross-platform deployment for deep learning & machine learning)
http://www.paddlepaddle.org/
Apache License 2.0

How do I export a static model when the model has multiple functions? #68697

Open yeyupiaoling opened 2 days ago

yeyupiaoling commented 2 days ago

Please ask your question

How do I export a static model when the model has multiple functions?

For example, in the model below both functions are needed at inference time, but the export only captures one of them. How should the model be exported in this case?


class ConformerModel(paddle.nn.Layer):

    def get_encoder_out(self, speech: paddle.Tensor, speech_lengths: paddle.Tensor) -> \
            Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        encoder_outs, encoder_mask = self.encoder(speech,
                                                  speech_lengths,
                                                  decoding_chunk_size=-1,
                                                  num_decoding_left_chunks=-1)  # (B, maxlen, encoder_dim)
        ctc_probs = self.ctc.log_softmax(encoder_outs)
        encoder_lens = encoder_mask.squeeze(1).sum(1)
        return encoder_outs, ctc_probs, encoder_lens

    @paddle.jit.to_static
    def get_encoder_out_chunk(self,
                              speech: paddle.Tensor,
                              offset: int,
                              required_cache_size: int,
                              att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
                              cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
                              ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        xs, att_cache, cnn_cache = self.encoder.forward_chunk(xs=speech,
                                                              offset=offset,
                                                              required_cache_size=required_cache_size,
                                                              att_cache=att_cache,
                                                              cnn_cache=cnn_cache)
        ctc_probs = self.ctc.softmax(xs)
        return ctc_probs, att_cache, cnn_cache

The input_spec for each function:

static_model = paddle.jit.to_static(
    self.get_encoder_out_chunk,
    input_spec=[
        paddle.static.InputSpec(shape=[1, None, self.input_size], dtype=paddle.float32),  # [B, T, D]
        paddle.static.InputSpec(shape=[1], dtype=paddle.int32),  # offset, int, but needs to be a tensor
        paddle.static.InputSpec(shape=[1], dtype=paddle.int32),  # required_cache_size, int
        paddle.static.InputSpec(shape=[None, None, None, None], dtype=paddle.float32),  # att_cache
        paddle.static.InputSpec(shape=[None, None, None, None], dtype=paddle.float32)  # cnn_cache
    ])

static_model = paddle.jit.to_static(
    self.get_encoder_out,
    input_spec=[
        paddle.static.InputSpec(shape=[None, None, self.input_size], dtype=paddle.float32),  # [B, T, D]
        paddle.static.InputSpec(shape=[None], dtype=paddle.int64),  # audio_length, [B]
    ])
zoooo0820 commented 1 day ago

You can refer to part 5 of section 3.4 at https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/jit/basic_usage_cn.html#moxingbaocunhejiazaizhuyishixiang and see whether it solves this problem.