k2-fsa / sherpa

Speech-to-text server framework with next-gen Kaldi
https://k2-fsa.github.io/sherpa
Apache License 2.0

Triton streaming support for old zipformer (pruned stateless 7 streaming) #412

Open uni-saurabh-vyas opened 1 year ago

uni-saurabh-vyas commented 1 year ago

Hello,

I tried to use the NVIDIA Triton streaming configuration with the pruned stateless 7 streaming model, but one encoder input, "avg_cache", seems to be missing. It appears to have been added in the new zipformer and was not present in the earlier Conformer transducer model; for reference, see https://github.com/k2-fsa/sherpa/blob/master/triton/model_repo_streaming/encoder/config.pbtxt.template

My question is: what effort is required to run this with Triton? Also, where can I find more information about "avg_cache"?

uni-saurabh-vyas commented 1 year ago

Found some information at https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py
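As far as we can tell from the streaming zipformer code, "avg_cache" (together with the length cache, "len_cache") lets a layer keep a cumulative mean over all frames seen so far, carried from one chunk to the next. The sketch below is a hypothetical, NumPy-only illustration of that idea for a single layer; the function name update_running_mean is made up, and the exact semantics should be checked against onnx_model_wrapper.py and zipformer.py.

```python
import numpy as np


def update_running_mean(chunk, avg_cache, len_cache):
    """Hypothetical per-layer update of a cumulative mean across chunks.

    chunk:     (T, C) frames of the current chunk
    avg_cache: (C,)   mean of all frames seen in previous chunks
    len_cache: int    number of frames seen in previous chunks
    Returns the updated (avg_cache, len_cache).
    """
    total = avg_cache * len_cache + chunk.sum(axis=0)  # undo the mean, add new frames
    new_len = len_cache + chunk.shape[0]
    return total / new_len, new_len


# Streaming over two chunks gives the same mean as one pass over all frames.
rng = np.random.default_rng(0)
frames = rng.standard_normal((10, 4))
avg, n = np.zeros(4), 0  # zero initial state, analogous to zero_data: true in Triton
for chunk in (frames[:6], frames[6:]):
    avg, n = update_running_mean(chunk, avg, n)
print(n, np.allclose(avg, frames.mean(axis=0)))  # 10 True
```

This only shows the carried state; in the real model the layer output at each frame may also depend on the running mean up to that frame.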

csukuangfj commented 1 year ago

My question is what is the effort required to run this with Triton?

@yuekaizhang Could you have a look at this issue?

uni-saurabh-vyas commented 1 year ago

Hi @csukuangfj @yuekaizhang

Here are some notes based on my understanding:

    {
      input_name: "avg_cache"
      output_name: "next_avg_cache"
      data_type: TYPE_FP16
      dims: [ ENCODER_LAYERS, ENCODER_DIM ]
      initial_state: {
       data_type: TYPE_FP16
       dims: [ ENCODER_LAYERS, ENCODER_DIM ]
       zero_data: true
       name: "initial state"
      }
    },
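Since the encoder needs one such entry per cache tensor, it can be convenient to generate them rather than edit config.pbtxt by hand. A minimal sketch, assuming the config is templated from Python; the helper name render_state is hypothetical, and the field names simply mirror the snippet above (Triton sequence-batching implicit state).

```python
def render_state(input_name, output_name, dims, data_type="TYPE_FP16"):
    """Render one Triton sequence-batching state entry (hypothetical helper)."""
    dims_str = ", ".join(str(d) for d in dims)
    return (
        "{\n"
        f'  input_name: "{input_name}"\n'
        f'  output_name: "{output_name}"\n'
        f"  data_type: {data_type}\n"
        f"  dims: [ {dims_str} ]\n"
        "  initial_state: {\n"
        f"   data_type: {data_type}\n"
        f"   dims: [ {dims_str} ]\n"
        "   zero_data: true\n"
        '   name: "initial state"\n'
        "  }\n"
        "}"
    )


# Placeholders are kept symbolic, as in config.pbtxt.template.
print(render_state("avg_cache", "next_avg_cache", ["ENCODER_LAYERS", "ENCODER_DIM"]))
```

Other caches (for example len_cache) would be rendered the same way with their own data type and dims, which should be taken from the exported ONNX model rather than assumed.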
yuekaizhang commented 1 year ago

Your understanding is correct. Although we have ONNX export code for Triton in icefall, we do not yet support a streaming zipformer Triton recipe. You would need to modify https://github.com/k2-fsa/sherpa/blob/master/triton/model_repo_streaming yourself.

You may, however, refer to the following two files when refactoring our current Triton ONNX export for the zipformer setup:

https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-zipformer-transducer-model.cc#L172

https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py

uni-saurabh-vyas commented 1 year ago

I started working on it, but I am a bit confused about one thing.

In https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py#L291, I see you already have an ONNX export script for the streaming zipformer, right?

Also, according to the docstring of the ONNX forward pass, it has two inputs:

    - x, a tensor of shape (N, T, C); dtype is torch.float32
    - x_lens, a tensor of shape (N,); dtype is torch.int64

and it has two outputs:

    - encoder_out, a tensor of shape (N, T, C)
    - encoder_out_lens, a tensor of shape (N,)

From that, it looks like offline inference, in which case the normal forward method should be called without states; but it calls encoder_model.streaming_forward, which is used for streaming.

Here is what I was working on: https://gist.github.com/uni-saurabh-vyas/0e2d72ecfd2d8f834ec92a33f5b9c5f6

I got confused about how self.num_encoders differs from self.num_encoder_layers when working out the dims for the states, because of this line in https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py#L339C4-L339C88 for len_cache:

    len_cache = torch.cat(states[: encoder_model.num_encoders]).transpose(0, 1)  # B,15

You have mentioned 15 as num_encoder, but in zipformer.py, num_encoder=5 and num_encoder_layers=(2, 4, 3, 2, 4), which add up to 15.
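The two numbers count different things: the number of encoder stacks versus the total number of layers across those stacks. A small illustration, assuming the num_encoder_layers value quoted above is the configured one:

```python
# Value quoted in the thread for pruned_transducer_stateless7_streaming.
num_encoder_layers = (2, 4, 3, 2, 4)  # layers in each encoder stack

num_encoders = len(num_encoder_layers)  # number of encoder stacks
total_layers = sum(num_encoder_layers)  # total layers, i.e. the 15 in "# B,15"

print(num_encoders, total_layers)  # 5 15
```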

yuekaizhang commented 1 year ago

Regarding

    len_cache = torch.cat(states[: encoder_model.num_encoders]).transpose(0, 1)  # B,15

here we combine 5 state tensors, one per encoder stack and sized 2, 4, ... along the layer dimension, into a single tensor with shape (B, 15). encoder_model.num_encoders should equal 5.
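A NumPy sketch of that packing, plus the inverse split a Triton wrapper would need to turn the (B, 15) tensor back into per-stack states. It assumes each per-stack len_cache tensor has shape (num_layers, B), which is what concatenating along dim 0 followed by transpose(0, 1) implies; np.concatenate and np.split stand in for torch.cat and torch.split, and the function names are hypothetical.

```python
import numpy as np

num_encoder_layers = (2, 4, 3, 2, 4)  # layers per encoder stack
B = 3  # batch size


def pack_len_cache(per_stack):
    """Concatenate per-stack (num_layers_i, B) caches into one (B, sum) tensor."""
    return np.concatenate(per_stack, axis=0).T


def unpack_len_cache(packed, layers=num_encoder_layers):
    """Split a (B, sum) tensor back into per-stack (num_layers_i, B) caches."""
    split_points = np.cumsum(layers)[:-1]  # [2, 6, 9, 11]
    return np.split(packed.T, split_points, axis=0)


# Fill stack i with the value i so the round trip is easy to check.
per_stack = [np.full((n, B), i, dtype=np.int64) for i, n in enumerate(num_encoder_layers)]
packed = pack_len_cache(per_stack)
print(packed.shape)  # (3, 15)

restored = unpack_len_cache(packed)
print([r.shape for r in restored])  # [(2, 3), (4, 3), (3, 3), (2, 3), (4, 3)]
```

Keeping the split points derived from num_encoder_layers, rather than hard-coding 15, means the same wrapper works if the number of layers per stack changes.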

uni-saurabh-vyas commented 1 year ago

I created a PR at https://github.com/k2-fsa/sherpa/pull/430