espnet / espnet_onnx

Onnx wrapper for espnet inference model
MIT License

Question on stream_asr.end() function for streaming asr #76

Open espnetUser opened 1 year ago

espnetUser commented 1 year ago

Hi @Masao-Someki,

In the README, the streaming ASR example shows the use of the start() and end() methods:

```python
from espnet_onnx import StreamingSpeech2Text

stream_asr = StreamingSpeech2Text(tag_name)

# start streaming asr
stream_asr.start()
while streaming:
    # get_next_chunk() is a placeholder for your own audio source
    # (the README writes <some code to get wav> here)
    wav = get_next_chunk()
    assert len(wav) == stream_asr.hop_size
    stream_text = stream_asr(wav)[0][0]

# You can get non-streaming asr result with end function
nbest = stream_asr.end()
```

In a real streaming scenario, should the start() and end() methods be called whenever the microphone is opened and closed, respectively?

I am asking because I noticed that the end() function in https://github.com/espnet/espnet_onnx/blob/master/espnet_onnx/asr/asr_streaming.py#151 calls self.batch_beam_search(), which restarts decoding from position 0 and therefore causes a rather large delay for longer speech inputs. If I change that line to use the self.beam_search() method instead, the entire utterance is no longer decoded again at the end, and the delay goes away.
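The cost argument can be made concrete with a toy model. The classes and functions below are hypothetical and are not espnet_onnx's BatchBeamSearch or BeamSearch. The model assumes one decoder step per encoder frame and compares two ways of finishing: re-decoding the whole utterance from position 0, or continuing from the position where streaming decoding stopped.

```python
from dataclasses import dataclass


@dataclass
class ToyDecoder:
    """Hypothetical stand-in for a beam search that only counts decoder steps."""

    position: int = 0  # index of the next encoder frame to decode
    steps: int = 0  # total decoder steps performed so far

    def decode(self, start: int, end: int) -> None:
        # Assumption for illustration: one decoder step per encoder frame.
        self.steps += end - start
        self.position = end


def finish_by_redecoding(dec: ToyDecoder, total_frames: int) -> None:
    # Analogous to running a fresh full search over the whole utterance at end().
    dec.decode(0, total_frames)


def finish_incrementally(dec: ToyDecoder, total_frames: int) -> None:
    # Analogous to continuing the running search from where streaming stopped.
    dec.decode(dec.position, total_frames)


def simulate(finish, total_frames: int = 500, streamed_frames: int = 480) -> int:
    """Return the number of extra decoder steps spent in the final call."""
    dec = ToyDecoder()
    for start in range(0, streamed_frames, 40):  # streaming chunks of 40 frames
        dec.decode(start, min(start + 40, streamed_frames))
    before = dec.steps
    finish(dec, total_frames)
    return dec.steps - before


print(simulate(finish_by_redecoding))  # 500
print(simulate(finish_incrementally))  # 20
```

In this toy, the cost of the final call grows with utterance length when everything is re-decoded, but it only covers the not-yet-decoded frames when decoding continues incrementally. The sketch does not show whether the two approaches produce identical hypotheses in espnet_onnx.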

Could you please clarify why self.batch_beam_search() is used in the stream_asr.end() function?

Thanks!

espnetUser commented 1 year ago

Hi @Masao-Someki,

I am seeing a problem similar to @sanjuktasr's: my espnet_onnx model performs poorly compared with the espnet2 PyTorch version.

I am focusing only on the streaming encoder, though, and noticed that the encoder outputs differ considerably between the ONNX and PyTorch models. I went through this issue and followed some of the suggestions, but so far nothing has resolved the problem.

Based on your discussion with @sanjuktasr, I started debugging this in more detail following your list of points above, and found the following:

> If you compute a long sequence with your encoder, are the ONNX and PyTorch outputs the same?

Interestingly, the ONNX and PyTorch outputs match exactly for the first iteration but then quickly drift apart:

[Figure: encoder output over time for one of the 512 output dimensions, ONNX vs. PyTorch]

The figure shows one of the 512 encoder outputs over time. It matches the PyTorch output exactly up to frame 45, which is the end of the first iteration's chunk of input data (https://github.com/espnet/espnet_onnx/blob/master/espnet_onnx/asr/asr_streaming.py#L138-L141).
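Locating the exact frame where two implementations start to disagree is easy to automate. The helper below is a hypothetical sketch, not part of espnet or espnet_onnx. It assumes both encoder outputs have been dumped as arrays shaped (frames, dims) and returns the first frame whose worst-case absolute difference exceeds a tolerance.

```python
import numpy as np


def first_divergent_frame(reference, candidate, atol: float = 1e-4) -> int:
    """Return the first frame index where any feature differs by more than atol, or -1."""
    n = min(len(reference), len(candidate))  # compare only the overlapping frames
    if n == 0:
        return -1
    diff = np.abs(np.asarray(reference[:n]) - np.asarray(candidate[:n]))
    per_frame = diff.reshape(n, -1).max(axis=1)  # worst deviation in each frame
    bad = np.flatnonzero(per_frame > atol)
    return int(bad[0]) if bad.size else -1


# Synthetic example: outputs agree up to frame 45, then one dimension drifts.
reference = np.zeros((100, 512))
candidate = reference.copy()
candidate[46:, 3] += 0.5
print(first_divergent_frame(reference, candidate))  # 46
```

The tolerance is an assumption. ONNX Runtime and PyTorch usually differ by small floating-point noise even when the computation is correct, so `atol` may need tuning.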

> Is the STFT output the same as with torch.stft?

Because the encoder outputs of espnet2 PyTorch and ONNX match for the first chunk, the STFT can be ruled out as the cause here, right?

> Since the espnet_onnx simulation script computes the encoder block incrementally while the espnet script computes it all at once, please confirm that the encoder outputs are the same.

I am using the streaming encoder together with the espnet2 script (https://github.com/espnet/espnet/blob/master/espnet2/bin/asr_inference_streaming.py) to extract the encoder output for the PyTorch model. As far as I understand, this script does not compute the outputs all at once but processes the input block by block, so the STFT is applied chunk by chunk rather than to the entire waveform. The script contains trimming code to handle STFT padding effects (https://github.com/espnet/espnet/blob/master/espnet2/bin/asr_inference_streaming.py#L256-L281), which I don't see in the espnet_onnx streaming code.
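Chunk-wise STFT needs trimming because of padding at the chunk edges. The numpy sketch below illustrates this; it is not espnet's code. It mimics only the framing of `torch.stft(center=True, pad_mode="reflect")`, which pads `n_fft // 2` samples on each side, and leaves out the FFT itself. When a chunk is framed on its own, frames that lie fully inside the chunk match the offline result. The last frame, however, sees reflected padding instead of the real samples that follow.

```python
import numpy as np


def frame_centered(x: np.ndarray, n_fft: int, hop: int) -> np.ndarray:
    """Split x into frames after reflect-padding n_fft // 2 samples on each side.

    Mimics the framing of torch.stft(center=True, pad_mode="reflect");
    the FFT is omitted because the framing alone shows the edge effect.
    """
    padded = np.pad(x, n_fft // 2, mode="reflect")
    n_frames = 1 + (len(padded) - n_fft) // hop
    return np.stack([padded[i * hop : i * hop + n_fft] for i in range(n_frames)])


rng = np.random.default_rng(0)
signal = rng.standard_normal(64)
n_fft, hop = 8, 4

full = frame_centered(signal, n_fft, hop)  # whole utterance at once
chunk = frame_centered(signal[:32], n_fft, hop)  # first streaming chunk only

k = len(chunk)
# Frames fully inside the chunk agree with the offline result ...
interior_ok = np.allclose(chunk[: k - 1], full[: k - 1])
# ... but the last frame uses reflected padding instead of the real next samples.
edge_ok = np.allclose(chunk[k - 1], full[k - 1])
print(interior_ok, edge_ok)  # True False
```

Such edge frames seem to be what the trimming code linked above compensates for. The exact trimming rule lives in those lines and is not reproduced here.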

Because the first-iteration outputs match between espnet_onnx and espnet2, I think the differences must somehow come from the different audio chunking/buffering/trimming code in the espnet_onnx and espnet streaming scripts.

I would appreciate any pointers on how to make the encoder outputs of the espnet_onnx and espnet2 PyTorch models match.

Thanks!

Masao-Someki commented 1 year ago

Hi @espnetUser, thank you for your comment. Since the first iteration matches completely, I think this problem is related to the STFT or other front-end parts. In your figure, the PyTorch line seems shifted to the right while the shapes of the two lines are similar, so I think this is caused by the padding or trimming part, as you mentioned. I'm still unsure of the cause, but if the _extract_feats function uses torch.stft with the default padding settings, that might be the cause of this problem.
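When two curves have the same shape but one is offset, the offset can be measured directly. The sketch below uses an exhaustive lag search that minimizes the mean squared error over the overlapping region. The function is hypothetical and not part of either code base. It assumes both curves are 1-D arrays sampled on the same frame grid.

```python
import numpy as np


def estimate_lag(reference: np.ndarray, candidate: np.ndarray, max_lag: int) -> int:
    """Return the lag (in frames) that best aligns candidate with reference.

    A positive result means candidate[t + lag] ~= reference[t], i.e. the
    candidate curve is shifted to the right by `lag` frames.
    """
    best_lag, best_err = 0, np.inf
    n = min(len(reference), len(candidate))
    for lag in range(-max_lag, max_lag + 1):
        if lag >= 0:
            a, b = reference[: n - lag], candidate[lag:n]
        else:
            a, b = reference[-lag:n], candidate[: n + lag]
        err = np.mean((a - b) ** 2)  # mismatch over the overlapping region
        if err < best_err:
            best_lag, best_err = lag, err
    return best_lag


rng = np.random.default_rng(1)
ref = np.cumsum(rng.standard_normal(300))  # a smooth-ish synthetic curve
shifted = np.concatenate([np.zeros(3), ref[:-3]])  # same curve, 3 frames later
print(estimate_lag(ref, shifted, max_lag=10))  # 3
```

A lag that stays constant over time points to a fixed alignment offset, such as padding. A lag that grows with each chunk suggests frames are being gained or lost at every block boundary.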

espnetUser commented 1 year ago

Hi @Masao-Someki, thank you very much for your prompt reply.

> So I think this is caused by the padding or trimming part, as you mentioned.

Over the last few days I have been comparing the frontends of espnet and espnet_onnx to determine whether the differences can be explained by the padding/trimming parts.

Here is a plot of filterbank channel 3 over time (frames) for the original streaming espnet (PYTORCH-FBANK), the original espnet_onnx (ONNX_ORG_FBANK), and a modified espnet_onnx (ONNX_MOD_FBANK) in which I replaced feats, feat_length = self.frontend(speech, speech_length) with the espnet method apply_frontend() from https://github.com/espnet/espnet/blob/master/espnet2/bin/asr_inference_streaming.py#L203

[Figure: filterbank channel 3 over time (frames) for PYTORCH-FBANK, ONNX_ORG_FBANK and ONNX_MOD_FBANK]

The figure shows that the frontend features of espnet and espnet_onnx do differ after initial_wav_length, which can be explained by trimming/padding effects. To eliminate any mismatch in the filterbank features, I replaced the espnet_onnx frontend code with the espnet apply_frontend() method, and the frontend features then match exactly (see the green and orange curves).
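The general principle behind making streaming features match offline features is to carry unconsumed samples over between calls. The class below is a hypothetical, simplified framer that uses no centering or padding. It is not espnet's apply_frontend(), which also has to deal with the STFT padding discussed earlier. The sketch only illustrates the carry-over idea: with the buffer, chunk-wise framing yields exactly the offline frames, whatever the chunk size.

```python
import numpy as np


class StreamingFramer:
    """Hypothetical helper, not espnet's apply_frontend().

    Carries unconsumed samples over between calls so that chunk-wise framing
    (without centering/padding) yields exactly the frames of offline framing.
    """

    def __init__(self, n_fft: int, hop: int) -> None:
        self.n_fft, self.hop = n_fft, hop
        self.buffer = np.zeros(0)

    def __call__(self, chunk: np.ndarray) -> np.ndarray:
        x = np.concatenate([self.buffer, chunk])
        n_frames = 0 if len(x) < self.n_fft else 1 + (len(x) - self.n_fft) // self.hop
        frames = [x[i * self.hop : i * self.hop + self.n_fft] for i in range(n_frames)]
        # The next frame starts at n_frames * hop; keep everything from there on.
        self.buffer = x[n_frames * self.hop :]
        return np.array(frames).reshape(n_frames, self.n_fft)


def frame_offline(x: np.ndarray, n_fft: int, hop: int) -> np.ndarray:
    """Reference: frame the whole signal at once."""
    n_frames = 0 if len(x) < n_fft else 1 + (len(x) - n_fft) // hop
    return np.array([x[i * hop : i * hop + n_fft] for i in range(n_frames)]).reshape(n_frames, n_fft)


rng = np.random.default_rng(0)
signal = rng.standard_normal(100)
framer = StreamingFramer(n_fft=8, hop=4)
streamed = np.concatenate([framer(signal[i : i + 13]) for i in range(0, len(signal), 13)])
print(np.array_equal(streamed, frame_offline(signal, 8, 4)))  # True
```

The buffer never holds more than `n_fft + hop` samples, so memory stays bounded no matter how long the stream runs.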

However, the encoder outputs are still shifted after the first iteration, even though the filterbank features fed to the encoder now match exactly:

[Figure: encoder output over time, ONNX vs. PyTorch, with identical filterbank inputs]

So I suspect there is an internal mismatch in how the encoder buffers or processes the chunks, which leads to the shift in the encoder outputs.

sanjuktasr commented 1 year ago

Check the encoder states. As far as I remember, there was an issue with the value of the next_state variable, which is why the encoder output mismatch started from the second iteration.
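The reasoning here is that a wrong next state leaves the first block untouched and corrupts only the later ones. A toy example makes this visible. The functions below are hypothetical and bear no relation to the real encoder's state dict. Each block's output depends on a carried state, and the buggy variant hands the wrong value to the next call.

```python
import numpy as np


def encode_block(block: np.ndarray, state: float):
    """Toy stateful 'encoder': outputs depend on the block and the carried state."""
    out = np.cumsum(block) + state
    return out, float(out[-1])  # next state: last output value


def encode_block_buggy(block: np.ndarray, state: float):
    out = np.cumsum(block) + state
    return out, float(out[-2])  # wrong next state (deliberate bug)


def run(encoder, blocks):
    state, outputs = 0.0, []  # every run starts from the same initial state
    for block in blocks:
        out, state = encoder(block, state)
        outputs.append(out)
    return outputs


blocks = [np.ones(4), np.ones(4)]
good, bad = run(encode_block, blocks), run(encode_block_buggy, blocks)
print(np.array_equal(good[0], bad[0]), np.array_equal(good[1], bad[1]))  # True False
```

The first block uses only the shared initial state, so it always matches. From the second block on, the buggy run is off by one step, which resembles the shifted curves in the plots above.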

espnetUser commented 1 year ago

@sanjuktasr: Thanks for your reply.

> As far as I remember, there was an issue with the next_state variable value.

```python
encoder_out, next_states = self.forward_encoder(feats, states)
```

Would this be the right place to look for the computation of the next_state variable?

https://github.com/espnet/espnet_onnx/blob/master/espnet_onnx/export/asr/models/encoders/contextual_block_xformer.py#L61-L160

espnetUser commented 1 year ago

@sanjuktasr: Do you remember which entry in the next_states dict was causing the mismatch?
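One way to find the offending entry without relying on memory is to dump the states from both implementations after each block and compare them entry by entry. The helper below is hypothetical. It assumes both sides have already been converted to dicts of arrays under matching key names. That is an assumption, since the ONNX and PyTorch state names may differ and need to be mapped first.

```python
import numpy as np


def diff_states(reference: dict, candidate: dict, atol: float = 1e-4) -> dict:
    """Report entries of two state dicts that differ by more than atol.

    Missing keys or shape mismatches are reported as inf.
    """
    report = {}
    for key in sorted(set(reference) | set(candidate)):
        if key not in reference or key not in candidate:
            report[key] = float("inf")
            continue
        a, b = np.asarray(reference[key]), np.asarray(candidate[key])
        if a.shape != b.shape:
            report[key] = float("inf")
            continue
        max_diff = float(np.max(np.abs(a - b))) if a.size else 0.0
        if max_diff > atol:
            report[key] = max_diff
    return report


# Illustrative keys only; real state names are not assumed here.
ref_states = {"a": np.zeros(3), "b": np.ones(2)}
onnx_states = {"a": np.zeros(3), "b": np.array([1.0, 1.5])}
print(diff_states(ref_states, onnx_states))  # {'b': 0.5}
```

Run after the first and second blocks, such a report would show which entry diverges first, before the difference spreads into encoder_out.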

espnetUser commented 1 year ago

@Masao-Someki, @sanjuktasr: I checked the streaming espnet_onnx encoder code for anything suspicious and found a "-1" in the res_frame_num calculation on this line:

https://github.com/espnet/espnet_onnx/blob/master/espnet_onnx/export/asr/models/encoders/contextual_block_xformer.py#L109

which is not part of the espnet2 streaming encoder code:

https://github.com/espnet/espnet/blob/master/espnet2/asr/encoder/contextual_block_conformer_encoder.py#L487

After removing the "-1", the encoder outputs were much more in line with those of my espnet2 PyTorch model:

[Figure: encoder output over time, ONNX vs. PyTorch, after removing the "-1"]
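A toy block-processing loop shows how an off-by-one in a residual-frame count can produce a shift like this. The sketch is hypothetical and is not the espnet_onnx or espnet2 encoder code. After each call it carries over the frames that the next block still needs (`residual`). With the extra `-1`, one frame too few is carried over, so each call boundary makes the next block start one frame later than it should.

```python
def block_starts(chunk_sizes, block, hop, off_by_one=False):
    """Toy block processing: return the absolute start frame of every block.

    Hypothetical illustration only; this is not the espnet_onnx encoder code.
    """
    buffer_start = 0  # absolute index of the first frame in the buffer
    buffered = 0  # number of frames carried over from the previous call
    starts = []
    for n_new in chunk_sizes:
        total = buffered + n_new
        n_blocks = 0 if total < block else (total - block) // hop + 1
        starts += [buffer_start + i * hop for i in range(n_blocks)]
        residual = total - hop * n_blocks  # frames still needed by the next block
        if off_by_one:
            residual -= 1  # one frame too few is carried over
        residual = max(residual, 0)
        buffer_start += total - residual  # frames dropped from the front
        buffered = residual
    return starts


offline = list(range(0, 18 - 4 + 1, 2))  # all block starts for 18 frames
print(block_starts([6, 6, 6], block=4, hop=2) == offline)  # True
print(block_starts([6, 6, 6], block=4, hop=2, off_by_one=True))  # [0, 2, 5, 7, 10, 12, 14]
```

Without the `-1`, the chunked loop reproduces the offline block grid exactly. With it, blocks after the first call boundary are misaligned, and one block is lost over three calls.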

@Masao-Someki: Could you please double-check the "-1" in the res_frame_num calculation in the espnet_onnx encoder code?

Masao-Someki commented 1 year ago

@espnetUser Thank you, I hadn't noticed this! I think you are correct. It should be implemented in the same manner as line 94, so the -1 is incorrect.

espnetUser commented 11 months ago

Thanks for confirming @Masao-Someki.

Two follow-up questions:

  1. Should I prepare a PR to fix this?
  2. I noticed that the new ESPnet release implements an ONNX-convertible make_pad_mask function, and there is a work-in-progress PR (https://github.com/espnet/espnet_onnx/pull/89). Is there a timeline for when espnet_onnx will support direct export of ESPnet encoders to ONNX format?
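In ESPnet, make_pad_mask marks the padded positions in a batch of variable-length sequences. An export-friendly version generally builds the mask from array operations, comparing a range of positions against the lengths, rather than from Python loops over the batch. The numpy sketch below shows only that pattern. It is not ESPnet's implementation, and its name and signature are illustrative.

```python
import numpy as np


def make_pad_mask(lengths, max_len: int) -> np.ndarray:
    """Return a (batch, max_len) bool mask that is True at padded positions.

    Built from one broadcast comparison, with no Python loop over the batch.
    """
    positions = np.arange(max_len)[None, :]  # shape (1, max_len)
    return positions >= np.asarray(lengths)[:, None]  # shape (batch, max_len)


mask = make_pad_mask(np.array([3, 1]), 4)
print(mask.astype(int))
```

For lengths 3 and 1 with `max_len=4`, the first row marks only the last position as padding, and the second row marks the last three.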

Thanks!

Masao-Someki commented 10 months ago

Sorry for the late reply, @espnetUser.

  1. A PR for this would be nice, but since the bugfix is tiny, I will include it in #96.
  2. I started implementing this in #96, and everything should be done by next weekend. Since this change is based on the new make_pad_mask and is incompatible with older espnet versions, I need to be careful about version conflicts in the dependencies.

espnetUser commented 10 months ago

@Masao-Someki: Thank you for the reply and the update on https://github.com/espnet/espnet_onnx/pull/96. Looking forward to trying it out soon! :)