k2-fsa / sherpa

Speech-to-text server framework with next-gen Kaldi
https://k2-fsa.github.io/sherpa
Apache License 2.0

Hi, I want to ask about a dimension mismatch issue #136

lucasjinreal opened this issue 1 year ago

lucasjinreal commented 1 year ago

Hi, I'm trying to understand the LSTM transducer completely. I have successfully migrated your model into a single file, and I am also able to load the pretrained model. This is the model I got:

lstm-transducer-librispeech-stateless (but with stateless2 weights; I didn't find a stateless pretrained model)

Transducer(
  (encoder): RNN(
    (encoder_embed): Conv2dSubsampling(
      (conv): Sequential(
        (0): ScaledConv2d(1, 8, kernel_size=(3, 3), stride=(1, 1))
        (1): ActivationBalancer()
        (2): DoubleSwish()
        (3): ScaledConv2d(8, 32, kernel_size=(3, 3), stride=(2, 2))
        (4): ActivationBalancer()
        (5): DoubleSwish()
        (6): ScaledConv2d(32, 128, kernel_size=(3, 3), stride=(2, 2))
        (7): ActivationBalancer()
        (8): DoubleSwish()
      )
      (out): ScaledLinear(in_features=2304, out_features=512, bias=True)
      (out_norm): BasicNorm()
      (out_balancer): ActivationBalancer()
    )
    (encoder): RNNEncoder(
      (layers): ModuleList(
        (0): RNNEncoderLayer(
          (lstm): ScaledLSTM(512, 1024, proj_size=512)
          (feed_forward): Sequential(
            (0): ScaledLinear(in_features=512, out_features=2048, bias=True)
            (1): ActivationBalancer()
            (2): DoubleSwish()
            (3): Dropout(p=0.1, inplace=False)
            (4): ScaledLinear(in_features=2048, out_features=512, bias=True)
          )
          (norm_final): BasicNorm()
          (balancer): ActivationBalancer()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (1)-(11): 11 more RNNEncoderLayer blocks identical to (0) above, omitted for brevity
      )
    )
  )
  (decoder): Decoder(
    (embedding): ScaledEmbedding(500, 512, padding_idx=0)
    (conv): ScaledConv1d(512, 512, kernel_size=(2,), stride=(1,), groups=512, bias=False)
  )
  (joiner): Joiner(
    (encoder_proj): ScaledLinear(in_features=512, out_features=512, bias=True)
    (decoder_proj): ScaledLinear(in_features=512, out_features=512, bias=True)
    (output_linear): ScaledLinear(in_features=512, out_features=500, bias=True)
  )
  (simple_am_proj): ScaledLinear(in_features=512, out_features=500, bias=True)
  (simple_lm_proj): ScaledLinear(in_features=512, out_features=500, bias=True)
)
11:11:04 09.23 INFO demo_file.py:112]: asr model loaded!

I just loaded a LibriSpeech wav for inference and got this error:

scaling.py", line 369, in _conv_forward
    return F.conv2d(
RuntimeError: Given groups=1, weight of size [8, 1, 3, 3], expected input[1, 483, 1, 80] to have 1 channels, but got 483 channels instead

Root cause:

transducer.py", line 63, in run_encoder
    encoder_out, _ = self.encoder(features, x_lens, states)

Do you know why?

The features are computed like this:

logging.info("Constructing Fbank computer")
  opts = kaldifeat.FbankOptions()
  opts.device = "cpu"
  opts.frame_opts.dither = 0
  opts.frame_opts.snip_edges = False
  opts.frame_opts.samp_freq = sample_rate
  opts.mel_opts.num_bins = 80
  fbank = kaldifeat.Fbank(opts)
  logging.info("FBank feat will run on CPU.")

  logging.info(f"Reading sound files: {sound_file}")
  wave_samples = read_sound_files(
      filenames=[sound_file],
      expected_sample_rate=sample_rate,
  )[0]

  logging.info("Decoding started")
  features = fbank(wave_samples)

shape:

[features]:  torch.Size([483, 80]) cpu torch.float32

Do you know why?

lucasjinreal commented 1 year ago

Just realized I should expand the batch dim for a single file. But I got a new error after fixing this:

    assert x.size(0) == lengths.max().item()
AssertionError
[screenshot of the assertion error]

Why does this assert happen?

lucasjinreal commented 1 year ago

From the docstring of the Transducer function, x_lens is (N,), which is (1,) here since my batch size is 1. Why do I get the above error?

Logging x.shape and lengths gives: torch.Size([119, 1, 512]) and tensor([-1])

csukuangfj commented 1 year ago

lstm-transducer-librispeech-stateless (but with stateless2 weights; I didn't find a stateless pretrained model)

You can find all pre-trained models in the icefall documentation.

For instance,

lucasjinreal commented 1 year ago

@csukuangfj do you know why I got the above error?

csukuangfj commented 1 year ago

11:11:04 09.23 INFO demo_file.py:112]: asr model loaded!

Could you show your demo_file.py?

lucasjinreal commented 1 year ago

@csukuangfj I put all the code here: https://github.com/jinfagang/aural/blob/master/demo_file.py I am not used to the Kaldi-like file organization, so I restructured it a little. The model should be the same, and the weights load successfully. But inference on a single wav file is not right. Please help me take a look.

csukuangfj commented 1 year ago

https://github.com/jinfagang/aural/blob/master/demo_file.py#L101

 features = features.unsqueeze(0)

https://github.com/jinfagang/aural/blob/master/demo_file.py#L124

    encoder_out, encoder_out_lens, hx, cx = asr_model.run_encoder(features, states)

https://github.com/jinfagang/aural/blob/master/aural/modeling/meta_arch/transducer.py#L62

x_lens = torch.tensor([features.size(0)], dtype=torch.long)

The above line is not right. You should use features.size(1).
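
For reference, a minimal sketch (variable names as in the snippets above) of computing x_lens for batched features of shape (N, T, C):

# features: (N, T, C) after features.unsqueeze(0); one length per utterance in the batch
x_lens = torch.full((features.size(0),), features.size(1), dtype=torch.long)
assert x_lens.shape == (features.size(0),)  # (N,)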

csukuangfj commented 1 year ago

From the docstring of the Transducer function, x_lens is (N,), which is (1,)

It means x_lens.shape is (N,).

lucasjinreal commented 1 year ago

@csukuangfj I copied code from the ncnn sherpa demo; is that because ncnn doesn't care about batch, so you squeezed all of them? I noticed that this should also be changed:

    states = (
        torch.zeros(num_encoder_layers, 1, d_model),
        torch.zeros(
            num_encoder_layers,
            1,
            rnn_hidden_size,
        ),
    )

Should that 1 be the batch size?

csukuangfj commented 1 year ago

Also, https://github.com/jinfagang/aural/blob/master/demo_file.py#L115

    states = (
        torch.zeros(num_encoder_layers, d_model),
        torch.zeros(
            num_encoder_layers,
            rnn_hidden_size,
        ),
    )

This is not right. You omitted the batch_size dimension.

Please use https://github.com/jinfagang/aural/blob/master/aural/modeling/encoders/rnn.py#L258

def get_init_states
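
A rough sketch of what such an initializer returns, assuming the state shapes discussed above (this is a guess at the linked helper, not its exact signature):

def get_init_states(num_encoder_layers, d_model, rnn_hidden_size, batch_size=1):
    # hidden state: (num_layers, N, proj_size=d_model); cell state: (num_layers, N, rnn_hidden_size)
    hx = torch.zeros(num_encoder_layers, batch_size, d_model)
    cx = torch.zeros(num_encoder_layers, batch_size, rnn_hidden_size)
    return hx, cx
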
csukuangfj commented 1 year ago

@csukuangfj I copied code from the ncnn sherpa demo; is that because ncnn doesn't care about batch, so you squeezed all of them?

Yes, ncnn does not support batch size, so I removed it.

Please refer to https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py and https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py

if you want to support batch_size.

csukuangfj commented 1 year ago

my encoder output: [encoder_out]: torch.Size([1, 119, 512]) cpu torch.float32

It means 1==N, 119==T, 512==dim

So

T = encoder_out.size(0)

This line is not right; it should be T = encoder_out.size(1).


hyp = [blank_id] * context_size
decoder_input = torch.tensor(hyp, dtype=torch.int32)  # (1, context_size)

From your code, decoder_input.shape is (context_size,). Please reshape it to (1, context_size).


encoder_out_t = encoder_out[t]

Please also consider the batch size.


I suggest that you have a look at https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py#L470
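
For orientation, a minimal batch-size-1 greedy-search sketch along the lines of the linked beam_search.py; run_decoder and run_joiner are the hypothetical wrappers from the code under discussion, and the shape comments are assumptions:

def greedy_search(model, encoder_out):
    # encoder_out: (1, T, C), so T is dimension 1
    assert encoder_out.ndim == 3 and encoder_out.size(0) == 1
    T = encoder_out.size(1)
    blank_id, context_size = 0, 2
    hyp = [blank_id] * context_size
    decoder_input = torch.tensor([hyp], dtype=torch.int32)  # (1, context_size): 2-D, as required
    decoder_out = model.run_decoder(decoder_input)          # assumed (1, C)
    for t in range(T):
        encoder_out_t = encoder_out[:, t, :]                # (1, C): keep the batch dim
        joiner_out = model.run_joiner(encoder_out_t, decoder_out)  # assumed (1, vocab_size)
        y = joiner_out.argmax(dim=1).item()
        if y != blank_id:
            hyp.append(y)
            decoder_input = torch.tensor([hyp[-context_size:]], dtype=torch.int32)
            decoder_out = model.run_decoder(decoder_input)
    return hyp[context_size:]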

lucasjinreal commented 1 year ago

@csukuangfj Hi, I just changed it to:

    for t in range(T):
        encoder_out_t = encoder_out[:,t,:].unsqueeze(1)
        print_shape(encoder_out_t)
        joiner_out = model.run_joiner(encoder_out_t, decoder_out)
        #  print(joiner_out.shape) # [500]
        y = joiner_out.argmax(dim=0).tolist()
        if y != blank_id:
            hyp.append(y)
            decoder_input = hyp[-context_size:]
            decoder_input = torch.tensor(decoder_input, dtype=torch.int32)
            decoder_out = model.run_decoder(decoder_input).squeeze(0)
    return hyp[context_size:]

I got this error:

  assert encoder_out.ndim in (2, 4)
AssertionError

lucasjinreal commented 1 year ago

Since decoder_out is [decoder_out]: torch.Size([1, 2, 512]) cpu torch.float32, decoder_out.ndim should equal encoder_out.ndim, so why should it be in (2, 4)?

csukuangfj commented 1 year ago

decoder_out = model.run_decoder(decoder_input).squeeze(0)

Since decoder_out is [decoder_out]: torch.Size([1, 2, 512])

Please use need_pad=False for the decoder.
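
For context: with need_pad=True the decoder left-pads the token sequence before its kernel-size-2 conv (see the ScaledConv1d above), so 2 input tokens produce 2 output frames, i.e. (1, 2, 512); with need_pad=False they produce a single frame. A sketch, assuming the run_decoder wrapper forwards the flag:

decoder_out = model.run_decoder(decoder_input)  # internally: self.decoder(x, need_pad=False)
# decoder_out: (1, 1, 512) -> drop the middle (U) axis, not the batch axis
decoder_out = decoder_out.squeeze(1)            # (1, 512)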

csukuangfj commented 1 year ago

decoder_out.ndim should equal encoder_out.ndim, so why should it be in (2, 4)?

It is used to catch errors like the above one.

csukuangfj commented 1 year ago

            decoder_input = torch.tensor(decoder_input, dtype=torch.int32)
            decoder_out = model.run_decoder(decoder_input).squeeze(0)

Please make sure the input for run_decoder is of shape (N, U). That is, it should be a 2-D tensor.

csukuangfj commented 1 year ago

decoder_out = model.run_decoder(decoder_input).squeeze(0)

The code for ncnn does not support batch size; that is why it uses squeeze(0) here. You will get errors if your code needs to support batch size.

lucasjinreal commented 1 year ago

@csukuangfj Hi, I can now run the decoder forward, but this for loop does not seem to work:

def greedy_search(model, encoder_out: torch.Tensor):
    print_shape(encoder_out)
    assert encoder_out.ndim == 3
    T = encoder_out.size(1)
    context_size = 2
    blank_id = 0  # hard-code to 0
    hyp = [blank_id] * context_size
    decoder_input = torch.tensor(hyp, dtype=torch.int32).unsqueeze(0)  # (1, context_size)
    print_shape(decoder_input)
    decoder_out = model.run_decoder(decoder_input).squeeze(1)
    print_shape(decoder_out, encoder_out)
    #  print(decoder_out.shape)  # (512,)
    for t in range(T):
        encoder_out_t = encoder_out[:,t,:]
        print_shape(encoder_out_t)
        joiner_out = model.run_joiner(encoder_out_t, decoder_out)
        print(joiner_out.shape) # [500]
        y = joiner_out.argmax(dim=1)
        # how to do with batch?
        print(y)
        if y != blank_id:
            hyp.append(y)
            decoder_input = hyp[-context_size:]
            print(decoder_input)
            decoder_input = torch.tensor(decoder_input, dtype=torch.int32)
            decoder_out = model.run_decoder(decoder_input)
    return hyp[context_size:]

How should I adapt it to support batching?

lucasjinreal commented 1 year ago

I got all 0s from argmax; is that normal?

[encoder_out]:  torch.Size([1, 119, 512]) cpu torch.float32
[decoder_input]:  torch.Size([1, 2]) cpu torch.int32
[decoder_out]:  torch.Size([1, 512]) cpu torch.float32
[encoder_out]:  torch.Size([1, 119, 512]) cpu torch.float32
[encoder_out_t]:  torch.Size([1, 512]) cpu torch.float32
torch.Size([1, 500])
y tensor([0])
... (the same three lines repeat with y tensor([0]) for the next 11 frames) ...
[encoder_out_t]:  torch.Size([1, 512]) cpu torch.float32
torch.Size([1, 500])
y tensor([224])
[0, tensor([224])]
Traceback (most recent call last):
  File "/Users/xx/work/codes/cv/aural/demo_file.py", line 139, in <module>
    main()
  File "/Users/xx/work/codes/cv/aural/demo_file.py", line 133, in main
    hyp = greedy_search(asr_model, encoder_out)
  File "/Users/xx/work/codes/cv/aural/demo_file.py", line 76, in greedy_search
    decoder_out = model.run_decoder(decoder_input)
  File "/Users/xx/work/codes/cv/aural/aural/modeling/meta_arch/transducer.py", line 71, in run_decoder
    out = self.decoder(x, need_pad=False)
  File "/Users/xx/miniforge3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/xx/work/codes/cv/aural/aural/modeling/decoders/decoder.py", line 106, in forward
    embedding_out = embedding_out.permute(0, 2, 1)
RuntimeError: number of dims don't match in permute

lucasjinreal commented 1 year ago

[screenshot of the decoded result]

I have got the right result now.

May I ask when the LSTM transducer code will be open-sourced in icefall?

csukuangfj commented 1 year ago

May I ask when the LSTM transducer code will be open-sourced in icefall?

The code is already open-sourced in icefall.

Please see the lstm_transducer_stateless2 recipe linked above.

lucasjinreal commented 1 year ago

@csukuangfj Hi, I mean the WenetSpeech Chinese model. I might also need the corresponding BPE model.

csukuangfj commented 1 year ago

@csukuangfj Hi, I mean the WenetSpeech Chinese model. I might also need the corresponding BPE model.

Maybe next week. I am still training the model.

lucasjinreal commented 1 year ago

@csukuangfj Hoping for it. BTW, does this function:

def greedy_search_single_batch(model, encoder_out: torch.Tensor, max_sym_per_frame: int) -> List[int]

have a suggested value for max_sym_per_frame?

csukuangfj commented 1 year ago

If the model is trained using pruned RNN-T, we suggest always using max-sym-per-frame==1.

If you use a different value for max-sym-per-frame, the search won't be able to support batch processing and it slows down decoding significantly. Furthermore, a larger value for max-sym-per-frame has little impact on WER, so there is no reason not to use 1.
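
For example, with the function quoted above (names as in the question, not a confirmed API):

hyp = greedy_search_single_batch(model, encoder_out, max_sym_per_frame=1)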

lucasjinreal commented 1 year ago

@csukuangfj thank you.

BTW, how are these models exported separately? I only saw logic that exports the whole model rather than the 3 parts.

[screenshot of the export code]

The export jit here saves the whole model.

[screenshot of the export code]

csukuangfj commented 1 year ago

Every model folder has its own export.py.

I am not sure which export.py you are referring to.

For the lstm-transducer model, the export.py is https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/lstm_transducer_stateless2/export.py#L158


It has 3 parts because PyTorch does not support applying torch.jit.script() to LSTM. See https://github.com/k2-fsa/icefall/pull/479#issuecomment-1209274159 for more details.
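
A rough sketch of that three-part export, assuming torch.jit.trace for the LSTM encoder (since torch.jit.script fails on it), an encoder.get_init_states() helper like the one linked earlier, and placeholder dummy shapes; see the linked export.py for the actual code:

# Hypothetical sketch; names and shapes are assumptions, not the export.py API.
encoder.eval()
x = torch.zeros(1, 100, 80)                       # (N, T, C) dummy fbank features
x_lens = torch.tensor([100], dtype=torch.long)
states = encoder.get_init_states()
traced_encoder = torch.jit.trace(encoder, (x, x_lens, states))
traced_encoder.save("encoder_jit_trace.pt")
torch.jit.script(decoder).save("decoder_jit.pt")  # no LSTM inside, so scripting works
torch.jit.script(joiner).save("joiner_jit.pt")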

lucasjinreal commented 1 year ago

@csukuangfj thank you. I can now export successfully.

However, I want to further export to ONNX for onnxruntime inference, but I got an error:

 raise symbolic_registry.UnsupportedOperatorError(
torch.onnx.symbolic_registry.UnsupportedOperatorError: Exporting the operator ::resolve_conj to ONNX opset version 13 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.

Why does it introduce a resolve_conj op here? It shouldn't happen.

Here is my export: https://github.com/jinfagang/aural/blob/master/export.py

csukuangfj commented 1 year ago

For the lstm-transducer model, I am afraid you cannot export it via onnx.

The reason is that we are using LSTM with projection. However, onnx does not support LSTM with projections, I think.

csukuangfj commented 1 year ago

Why does it introduce a resolve_conj op here? It shouldn't happen.

torch.onnx.export supports passing opset. You can try a different opset and see how the error message changes.
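
For instance (opset_version is a standard torch.onnx.export argument; model and inputs here are placeholders):

for opset in (11, 12, 13, 14):
    try:
        torch.onnx.export(model, (x,), f"model_opset{opset}.onnx", opset_version=opset)
        print(f"opset {opset}: exported OK")
    except Exception as e:
        print(f"opset {opset}: {type(e).__name__}: {e}")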

lucasjinreal commented 1 year ago

@csukuangfj no, ONNX supports LSTM very well. This operator maybe comes from some other operation.

pnnx will also export this op; this op is very weird.

csukuangfj commented 1 year ago

You will meet the following error: https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset9.py#L4235

return symbolic_helper._unimplemented("LSTM", "LSTMs with projections", input)

pnnx will also export this op

You can use my modified version of pnnx. Please see https://k2-fsa.github.io/icefall/recipes/librispeech/lstm_pruned_stateless_transducer.html#export-model-for-ncnn

csukuangfj commented 1 year ago

Note that ncnn does not support LSTM with projections.

lucasjinreal commented 1 year ago

@csukuangfj You should use an opset higher than 12; after that, LSTM is supported. But somehow, with opset 13, a weird op is introduced.

Do you know which operation caused that op?

csukuangfj commented 1 year ago

You should use an opset higher than 12; after that, LSTM is supported

I know that PyTorch can export LSTM models without projection to onnx. Have you tested that it also supports LSTM with projection?


Do you know which operation caused that op?

I don't know which module causes that issue, but I suggest the following method to figure it out:
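
One such method, sketched here as an assumption about what was meant: export each part of the model separately and see which one trips the unsupported op. example_inputs is a hypothetical dict of per-module dummy inputs.

for name, mod in [("encoder", model.encoder), ("decoder", model.decoder), ("joiner", model.joiner)]:
    try:
        torch.onnx.export(mod, example_inputs[name], f"{name}.onnx", opset_version=13)
        print(name, "exports fine")
    except Exception as e:
        print(name, "introduces the failing op:", e)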

lucasjinreal commented 1 year ago

@csukuangfj Hi, I am not very familiar with your model. Can you tell me what an LSTM with projection is? Where is it in your code?

lucasjinreal commented 1 year ago

@csukuangfj I just found that most ASR model exports have this problem; see: https://github.com/microsoft/onnxruntime/issues/11812 . Longformer also has it, not just LSTM. Can you help me take a deeper look at how to export to ONNX?

csukuangfj commented 1 year ago

Please refer to nn.LSTM from PyTorch. It has an argument proj_size.
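
A minimal repro of that limitation (assumes torch >= 1.8; the export call is expected to raise the LSTMs-with-projections error quoted earlier):

import torch
from torch import nn

lstm = nn.LSTM(input_size=512, hidden_size=1024, proj_size=512, num_layers=1)
x = torch.rand(23, 1, 512)   # (T, N, C)
y, (h, c) = lstm(x)          # eager mode works; y's last dim is 512 due to the projection
torch.onnx.export(lstm, (x,), "lstm_proj.onnx", opset_version=13)  # expected to fail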

lucasjinreal commented 1 year ago

@csukuangfj Hi, please run this code; you will find that the root cause is not about LSTM:

import torch
from torch import nn

class BadFirst(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x_slice = x[:, 0]
        print(f"x_slice: {x_slice}")
        return x_slice

if __name__ == "__main__":
    m = BadFirst().eval()
    x = torch.rand(10, 5)

    res = m(x) # this works
    torch.onnx.export(m, x, "badfirst.onnx")

csukuangfj commented 1 year ago

Thanks!

Have you tried to convert LSTM with proj_size > 0 via onnx?

lucasjinreal commented 1 year ago

@csukuangfj Mine by default sets proj_size to d_model:

        self.lstm = ScaledLSTM(
            input_size=d_model,
            hidden_size=rnn_hidden_size,
            proj_size=d_model if rnn_hidden_size > d_model else 0,
            # proj_size=0,
            num_layers=1,
            dropout=0.0,
        )

I tried proj_size == 0, but then the model dimensions are wrong.

This is very weird. Do you have any idea?

lucasjinreal commented 1 year ago

@csukuangfj Oh, I got the same message as you now:

torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of operator LSTM, LSTMs with projections. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues  [Caused by the value 'input.11 defined in (%input.11 : Float(23, 1, 512, strides=[512, 11776, 1], requires_grad=0, device=cpu) = onnx::Transpose[perm=[1, 0, 2]](%196), scope: aural.modeling.encoders.rnn.RNN::

So this LSTM variant isn't supported by ONNX export?

csukuangfj commented 1 year ago

LSTM with projection is available only for torch >= 1.8.1, I think.

It is not unusual that it is not supported by ONNX. Even ncnn does not support LSTM with projection.


The current lstm-transducer model does not support exporting to ONNX because of its usage of projection.

lucasjinreal commented 1 year ago

@csukuangfj what did you modify in your forked ncnn version?

csukuangfj commented 1 year ago

@csukuangfj what did you modify in your forked ncnn version?

For instance, I added support for LSTMs with projections.

Please have a look at the source code if you are interested.

lucasjinreal commented 1 year ago

@csukuangfj I have a customized ncnn as well, so is it OK if I just merge your LSTM implementation into my branch?

csukuangfj commented 1 year ago

@csukuangfj I have a customized ncnn as well, so is it OK if I just merge your LSTM implementation into my branch?

Yes, but I suggest that you test it with your existing models before you merge.

lucasjinreal commented 1 year ago

@csukuangfj thanks, do you have any idea when ONNX will support LSTM with projection? Or is the projection in the model really necessary?

lucasjinreal commented 1 year ago

@csukuangfj Hi, I got this error when running: layer prim::TupleUnpack not exists or registered. Do you know why?

7767517
267 379
Input                    in0                      0 1 in0
Input                    in1                      0 1 in1
Input                    in2                      0 1 in2
prim::TupleUnpack        pnnx_10                  1 2 in2 3 4
Split                    splitncnn_1              1 12 4 5 6 7 8 9 10 11 12 13 14 15 16
Split                    splitncnn_0              1 12 3 17 18 19 20 21 22 23 24 25 26 27 28
ExpandDims               unsqueeze_96             1 1 in0 29 -23303=1,0
Convolution              conv_15                  1 1 29 30 0=8 1=3 11=3 12=1 13=1 14=0 2=1 3=1 4=0 5=1 6=72
Split                    splitncnn_2              1 2 30 31 32
BinaryOp                 sub_0                    1 1 31 33 0=1 1=1 2=1.000000e+00
Sigmoid                  sigmoid_0                1 1 33 34
BinaryOp                 mul_1                    2 1 32 34 35 0=2
Convolution              conv_16                  1 1 35 36 0=32 1=3 11=3 12=1 13=2 14=0 2=1 3=2 4=0 5=1 6=2304
Split                    splitncnn_3              1 2 36 37 38
BinaryOp                 sub_2                    1 1 37 39 0=1 1=1 2=1.000000e+00
Sigmoid                  sigmoid_1                1 1 39 40
BinaryOp                 mul_3                    2 1 38 40 41 0=2
Convolution              conv_17                  1 1 41 42 0=128 1=3 11=3 12=1 13=2 14=0 2=1 3=2 4=0 5=1 6=36864
Split                    splitncnn_4              1 2 42 43 44
BinaryOp                 sub_4                    1 1 43 45 0=1 1=1 2=1.000000e+00
Sigmoid                  sigmoid_2                1 1 45 46
BinaryOp                 mul_5                    2 1 44 46 47 0=2
Permute                  permute_93               1 1 47 48 0=2
Reshape                  reshape_55               1 1 48 49 0=2304 1=-1
InnerProduct             linear_18                1 1 49 50 0=512 1=1 2=1179648
Split                    splitncnn_5              1 3 50 51 52 53
BinaryOp                 mul_6                    2 1 51 52 54 0=2
Reduction                mean_80                  1 1 54 55 0=3 1=0 -23303=1,-1 4=1 5=1
BinaryOp                 add_7                    1 1 55 56 0=0 1=1 2=2.500000e-01
BinaryOp                 pow_8                    1 1 56 57 0=6 1=1 2=-5.000000e-01
BinaryOp                 mul_9                    2 1 53 57 58 0=2
BinaryOp                 sub_10                   1 1 in1 59 0=1 1=1 2=3.000000e+00
BinaryOp                 div_11                   1 1 59 60 0=3 1=1 2=2.000000e+00
aten::floor              pnnx_60                  1 1 60 61
BinaryOp                 sub_12                   1 1 61 62 0=1 1=1 2=1.000000e+00
BinaryOp                 div_13                   1 1 62 63 0=3 1=1 2=2.000000e+00
aten::floor              pnnx_65                  1 1 63 out1
Crop                     slice_57                 1 1 16 65 -23310=1,1 -23311=1,0 -23309=1,0
Crop                     slice_56                 1 1 28 66 -23310=1,1 -23311=1,0 -23309=1,0
Split                    splitncnn_6              1 2 58 67 68
LSTM                     lstm_43                  3 3 68 66 65 69 70 71 0=1024 1=2097152 2=0
BinaryOp                 add_14                   2 1 69 67 72 0=0
Split                    splitncnn_7              1 2 72 73 74
InnerProduct             linear_19                1 1 74 75 0=2048 1=1 2=1048576
Split                    splitncnn_8              1 2 75 76 77
BinaryOp                 sub_15                   1 1 76 78 0=1 1=1 2=1.000000e+00
Sigmoid                  sigmoid_3                1 1 78 79
BinaryOp                 mul_16                   2 1 77 79 80 0=2
InnerProduct             linear_20                1 1 80 81 0=512 1=1 2=1048576
BinaryOp                 add_17                   2 1 73 81 82 0=0
Split                    splitncnn_9              1 3 82 83 84 85
BinaryOp                 mul_18                   2 1 83 84 86 0=2
Reduction                mean_81                  1 1 86 87 0=3 1=0 -23303=1,-1 4=1 5=1
BinaryOp                 add_19                   1 1 87 88 0=0 1=1 2=2.500000e-01
BinaryOp                 pow_20                   1 1 88 89 0=6 1=1 2=-5.000000e-01
BinaryOp                 mul_21                   2 1 85 89 90 0=2
Split                    splitncnn_10             1 2 90 91 92
Crop                     slice_59                 1 1 15 93 -23310=1,2 -23311=1,0 -23309=1,1
Crop                     slice_58                 1 1 27 94 -23310=1,2 -23311=1,0 -23309=1,1
LSTM                     lstm_44                  3 3 92 94 93 95 96 97 0=1024 1=2097152 2=0
BinaryOp                 add_22                   2 1 95 91 98 0=0
Split                    splitncnn_11             1 2 98 99 100
InnerProduct             linear_21                1 1 100 101 0=2048 1=1 2=1048576
Split                    splitncnn_12             1 2 101 102 103
BinaryOp                 sub_23                   1 1 102 104 0=1 1=1 2=1.000000e+00
Sigmoid                  sigmoid_4                1 1 104 105
BinaryOp                 mul_24                   2 1 103 105 106 0=2
InnerProduct             linear_22                1 1 106 107 0=512 1=1 2=1048576
BinaryOp                 add_25                   2 1 99 107 108 0=0
Split                    splitncnn_13             1 3 108 109 110 111
BinaryOp                 mul_26                   2 1 109 110 112 0=2
Reduction                mean_82                  1 1 112 113 0=3 1=0 -23303=1,-1 4=1 5=1
BinaryOp                 add_27                   1 1 113 114 0=0 1=1 2=2.500000e-01
BinaryOp                 pow_28                   1 1 114 115 0=6 1=1 2=-5.000000e-01
BinaryOp                 mul_29                   2 1 111 115 116 0=2
Split                    splitncnn_14             1 2 116 117 118
Crop                     slice_61                 1 1 14 119 -23310=1,3 -23311=1,0 -23309=1,2
Crop                     slice_60                 1 1 26 120 -23310=1,3 -23311=1,0 -23309=1,2
LSTM                     lstm_45                  3 3 118 120 119 121 122 123 0=1024 1=2097152 2=0
BinaryOp                 add_30                   2 1 121 117 124 0=0
Split                    splitncnn_15             1 2 124 125 126
InnerProduct             linear_23                1 1 126 127 0=2048 1=1 2=1048576
Split                    splitncnn_16             1 2 127 128 129
BinaryOp                 sub_31                   1 1 128 130 0=1 1=1 2=1.000000e+00
Sigmoid                  sigmoid_5                1 1 130 131
BinaryOp                 mul_32                   2 1 129 131 132 0=2
InnerProduct             linear_24                1 1 132 133 0=512 1=1 2=1048576
BinaryOp                 add_33                   2 1 125 133 134 0=0
Split                    splitncnn_17             1 3 134 135 136 137
BinaryOp                 mul_34                   2 1 135 136 138 0=2
Reduction                mean_83                  1 1 138 139 0=3 1=0 -23303=1,-1 4=1 5=1
BinaryOp                 add_35                   1 1 139 140 0=0 1=1 2=2.500000e-01
BinaryOp                 pow_36                   1 1 140 141 0=6 1=1 2=-5.000000e-01
BinaryOp                 mul_37                   2 1 137 141 142 0=2
Split                    splitncnn_18             1 2 142 143 144
Crop                     slice_63                 1 1 13 145 -23310=1,4 -23311=1,0 -23309=1,3
Crop                     slice_62                 1 1 25 146 -23310=1,4 -23311=1,0 -23309=1,3
LSTM                     lstm_46                  3 3 144 146 145 147 148 149 0=1024 1=2097152 2=0
BinaryOp                 add_38                   2 1 147 143 150 0=0
Split                    splitncnn_19             1 2 150 151 152
InnerProduct             linear_25                1 1 152 153 0=2048 1=1 2=1048576
Split                    splitncnn_20             1 2 153 154 155
BinaryOp                 sub_39                   1 1 154 156 0=1 1=1 2=1.000000e+00
Sigmoid                  sigmoid_6                1 1 156 157
BinaryOp                 mul_40                   2 1 155 157 158 0=2
InnerProduct             linear_26                1 1 158 159 0=512 1=1 2=1048576
BinaryOp                 add_41                   2 1 151 159 160 0=0
Split                    splitncnn_21             1 3 160 161 162 163
BinaryOp                 mul_42                   2 1 161 162 164 0=2
Reduction                mean_84                  1 1 164 165 0=3 1=0 -23303=1,-1 4=1 5=1
BinaryOp                 add_43                   1 1 165 166 0=0 1=1 2=2.500000e-01
BinaryOp                 pow_44                   1 1 166 167 0=6 1=1 2=-5.000000e-01
BinaryOp                 mul_45                   2 1 163 167 168 0=2
Split                    splitncnn_22             1 2 168 169 170
Crop                     slice_65                 1 1 12 171 -23310=1,5 -23311=1,0 -23309=1,4
Crop                     slice_64                 1 1 24 172 -23310=1,5 -23311=1,0 -23309=1,4
LSTM                     lstm_47                  3 3 170 172 171 173 174 175 0=1024 1=2097152 2=0
BinaryOp                 add_46                   2 1 173 169 176 0=0
Split                    splitncnn_23             1 2 176 177 178
InnerProduct             linear_27                1 1 178 179 0=2048 1=1 2=1048576
Split                    splitncnn_24             1 2 179 180 181
BinaryOp                 sub_47                   1 1 180 182 0=1 1=1 2=1.000000e+00
Sigmoid                  sigmoid_7                1 1 182 183
BinaryOp                 mul_48                   2 1 181 183 184 0=2
InnerProduct             linear_28                1 1 184 185 0=512 1=1 2=1048576
BinaryOp                 add_49                   2 1 177 185 186 0=0
Split                    splitncnn_25             1 3 186 187 188 189
BinaryOp                 mul_50                   2 1 187 188 190 0=2
Reduction                mean_85                  1 1 190 191 0=3 1=0 -23303=1,-1 4=1 5=1
BinaryOp                 add_51                   1 1 191 192 0=0 1=1 2=2.500000e-01
BinaryOp                 pow_52                   1 1 192 193 0=6 1=1 2=-5.000000e-01
BinaryOp                 mul_53                   2 1 189 193 194 0=2
Split                    splitncnn_26             1 2 194 195 196
Crop                     slice_67                 1 1 11 197 -23310=1,6 -23311=1,0 -23309=1,5
Crop                     slice_66                 1 1 23 198 -23310=1,6 -23311=1,0 -23309=1,5
LSTM                     lstm_48                  3 3 196 198 197 199 200 201 0=1024 1=2097152 2=0
BinaryOp                 add_54                   2 1 199 195 202 0=0
Split                    splitncnn_27             1 2 202 203 204
InnerProduct             linear_29                1 1 204 205 0=2048 1=1 2=1048576
Split                    splitncnn_28             1 2 205 206 207
BinaryOp                 sub_55                   1 1 206 208 0=1 1=1 2=1.000000e+00
Sigmoid                  sigmoid_8                1 1 208 209
BinaryOp                 mul_56                   2 1 207 209 210 0=2
InnerProduct             linear_30                1 1 210 211 0=512 1=1 2=1048576
BinaryOp                 add_57                   2 1 203 211 212 0=0
Split                    splitncnn_29             1 3 212 213 214 215
BinaryOp                 mul_58                   2 1 213 214 216 0=2
Reduction                mean_86                  1 1 216 217 0=3 1=0 -23303=1,-1 4=1 5=1
BinaryOp                 add_59                   1 1 217 218 0=0 1=1 2=2.500000e-01
BinaryOp                 pow_60                   1 1 218 219 0=6 1=1 2=-5.000000e-01
BinaryOp                 mul_61                   2 1 215 219 220 0=2
Split                    splitncnn_30             1 2 220 221 222
Crop                     slice_69                 1 1 10 223 -23310=1,7 -23311=1,0 -23309=1,6
Crop                     slice_68                 1 1 22 224 -23310=1,7 -23311=1,0 -23309=1,6
LSTM                     lstm_49                  3 3 222 224 223 225 226 227 0=1024 1=2097152 2=0
BinaryOp                 add_62                   2 1 225 221 228 0=0
Split                    splitncnn_31             1 2 228 229 230
InnerProduct             linear_31                1 1 230 231 0=2048 1=1 2=1048576
Split                    splitncnn_32             1 2 231 232 233
BinaryOp                 sub_63                   1 1 232 234 0=1 1=1 2=1.000000e+00
Sigmoid                  sigmoid_9                1 1 234 235
BinaryOp                 mul_64                   2 1 233 235 236 0=2
InnerProduct             linear_32                1 1 236 237 0=512 1=1 2=1048576
BinaryOp                 add_65                   2 1 229 237 238 0=0
Split                    splitncnn_33             1 3 238 239 240 241
BinaryOp                 mul_66                   2 1 239 240 242 0=2
Reduction                mean_87                  1 1 242 243 0=3 1=0 -23303=1,-1 4=1 5=1
BinaryOp                 add_67                   1 1 243 244 0=0 1=1 2=2.500000e-01
BinaryOp                 pow_68                   1 1 244 245 0=6 1=1 2=-5.000000e-01
BinaryOp                 mul_69                   2 1 241 245 246 0=2
Split                    splitncnn_34             1 2 246 247 248
Crop                     slice_71                 1 1 9 249 -23310=1,8 -23311=1,0 -23309=1,7
Crop                     slice_70                 1 1 21 250 -23310=1,8 -23311=1,0 -23309=1,7
LSTM                     lstm_50                  3 3 248 250 249 251 252 253 0=1024 1=2097152 2=0
BinaryOp                 add_70                   2 1 251 247 254 0=0
Split                    splitncnn_35             1 2 254 255 256
InnerProduct             linear_33                1 1 256 257 0=2048 1=1 2=1048576
Split                    splitncnn_36             1 2 257 258 259
BinaryOp                 sub_71                   1 1 258 260 0=1 1=1 2=1.000000e+00
Sigmoid                  sigmoid_10               1 1 260 261
BinaryOp                 mul_72                   2 1 259 261 262 0=2
InnerProduct             linear_34                1 1 262 263 0=512 1=1 2=1048576
BinaryOp                 add_73                   2 1 255 263 264 0=0
Split                    splitncnn_37             1 3 264 265 266 267
BinaryOp                 mul_74                   2 1 265 266 268 0=2
Reduction                mean_88                  1 1 268 269 0=3 1=0 -23303=1,-1 4=1 5=1
BinaryOp                 add_75                   1 1 269 270 0=0 1=1 2=2.500000e-01
BinaryOp                 pow_76                   1 1 270 271 0=6 1=1 2=-5.000000e-01
BinaryOp                 mul_77                   2 1 267 271 272 0=2
Split                    splitncnn_38             1 2 272 273 274
Crop                     slice_73                 1 1 8 275 -23310=1,9 -23311=1,0 -23309=1,8
Crop                     slice_72                 1 1 20 276 -23310=1,9 -23311=1,0 -23309=1,8
LSTM                     lstm_51                  3 3 274 276 275 277 278 279 0=1024 1=2097152 2=0
BinaryOp                 add_78                   2 1 277 273 280 0=0
Split                    splitncnn_39             1 2 280 281 282
InnerProduct             linear_35                1 1 282 283 0=2048 1=1 2=1048576
Split                    splitncnn_40             1 2 283 284 285
BinaryOp                 sub_79                   1 1 284 286 0=1 1=1 2=1.000000e+00
Sigmoid                  sigmoid_11               1 1 286 287
BinaryOp                 mul_80                   2 1 285 287 288 0=2
InnerProduct             linear_36                1 1 288 289 0=512 1=1 2=1048576
BinaryOp                 add_81                   2 1 281 289 290 0=0
Split                    splitncnn_41             1 3 290 291 292 293
BinaryOp                 mul_82                   2 1 291 292 294 0=2
Reduction                mean_89                  1 1 294 295 0=3 1=0 -23303=1,-1 4=1 5=1
BinaryOp                 add_83                   1 1 295 296 0=0 1=1 2=2.500000e-01
BinaryOp                 pow_84                   1 1 296 297 0=6 1=1 2=-5.000000e-01
BinaryOp                 mul_85                   2 1 293 297 298 0=2
Split                    splitncnn_42             1 2 298 299 300
Crop                     slice_75                 1 1 7 301 -23310=1,10 -23311=1,0 -23309=1,9
Crop                     slice_74                 1 1 19 302 -23310=1,10 -23311=1,0 -23309=1,9
LSTM                     lstm_52                  3 3 300 302 301 303 304 305 0=1024 1=2097152 2=0
BinaryOp                 add_86                   2 1 303 299 306 0=0
Split                    splitncnn_43             1 2 306 307 308
InnerProduct             linear_37                1 1 308 309 0=2048 1=1 2=1048576
Split                    splitncnn_44             1 2 309 310 311
BinaryOp                 sub_87                   1 1 310 312 0=1 1=1 2=1.000000e+00
Sigmoid                  sigmoid_12               1 1 312 313
BinaryOp                 mul_88                   2 1 311 313 314 0=2
InnerProduct             linear_38                1 1 314 315 0=512 1=1 2=1048576
BinaryOp                 add_89                   2 1 307 315 316 0=0
Split                    splitncnn_45             1 3 316 317 318 319
BinaryOp                 mul_90                   2 1 317 318 320 0=2
Reduction                mean_90                  1 1 320 321 0=3 1=0 -23303=1,-1 4=1 5=1
BinaryOp                 add_91                   1 1 321 322 0=0 1=1 2=2.500000e-01
BinaryOp                 pow_92                   1 1 322 323 0=6 1=1 2=-5.000000e-01
BinaryOp                 mul_93                   2 1 319 323 324 0=2
Split                    splitncnn_46             1 2 324 325 326
Crop                     slice_77                 1 1 6 327 -23310=1,11 -23311=1,0 -23309=1,10
Crop                     slice_76                 1 1 18 328 -23310=1,11 -23311=1,0 -23309=1,10
LSTM                     lstm_53                  3 3 326 328 327 329 330 331 0=1024 1=2097152 2=0
BinaryOp                 add_94                   2 1 329 325 332 0=0
Split                    splitncnn_47             1 2 332 333 334
InnerProduct             linear_39                1 1 334 335 0=2048 1=1 2=1048576
Split                    splitncnn_48             1 2 335 336 337
BinaryOp                 sub_95                   1 1 336 338 0=1 1=1 2=1.000000e+00
Sigmoid                  sigmoid_13               1 1 338 339
BinaryOp                 mul_96                   2 1 337 339 340 0=2
InnerProduct             linear_40                1 1 340 341 0=512 1=1 2=1048576
BinaryOp                 add_97                   2 1 333 341 342 0=0
Split                    splitncnn_49             1 3 342 343 344 345
BinaryOp                 mul_98                   2 1 343 344 346 0=2
Reduction                mean_91                  1 1 346 347 0=3 1=0 -23303=1,-1 4=1 5=1
BinaryOp                 add_99                   1 1 347 348 0=0 1=1 2=2.500000e-01
BinaryOp                 pow_100                  1 1 348 349 0=6 1=1 2=-5.000000e-01
BinaryOp                 mul_101                  2 1 345 349 350 0=2
Split                    splitncnn_50             1 2 350 351 352
Crop                     slice_79                 1 1 5 353 -23310=1,12 -23311=1,0 -23309=1,11
Crop                     slice_78                 1 1 17 354 -23310=1,12 -23311=1,0 -23309=1,11
LSTM                     lstm_54                  3 3 352 354 353 355 356 357 0=1024 1=2097152 2=0
BinaryOp                 add_102                  2 1 355 351 358 0=0
Split                    splitncnn_51             1 2 358 359 360
InnerProduct             linear_41                1 1 360 361 0=2048 1=1 2=1048576
Split                    splitncnn_52             1 2 361 362 363
BinaryOp                 sub_103                  1 1 362 364 0=1 1=1 2=1.000000e+00
Sigmoid                  sigmoid_14               1 1 364 365
BinaryOp                 mul_104                  2 1 363 365 366 0=2
InnerProduct             linear_42                1 1 366 367 0=512 1=1 2=1048576
BinaryOp                 add_105                  2 1 359 367 368 0=0
Split                    splitncnn_53             1 3 368 369 370 371
BinaryOp                 mul_106                  2 1 369 370 372 0=2
Reduction                mean_92                  1 1 372 373 0=3 1=0 -23303=1,-1 4=1 5=1
BinaryOp                 add_107                  1 1 373 374 0=0 1=1 2=2.500000e-01
BinaryOp                 pow_108                  1 1 374 375 0=6 1=1 2=-5.000000e-01
BinaryOp                 mul_109                  2 1 371 375 out0 0=2
Concat                   cat_0                    12 1 71 97 123 149 175 201 227 253 279 305 331 357 out3 0=0
Concat                   cat_1                    12 1 70 96 122 148 174 200 226 252 278 304 330 356 out2 0=0

This is my encoder's ncnn param file.