k2-fsa / icefall

https://k2-fsa.github.io/icefall/
Apache License 2.0

PyTorch `>=2.1.0` breaks compatibility with all `conformer_ctc` recipes #1610

Open JinZr opened 2 weeks ago


PyTorch added two parameters, `tgt_is_causal` and `memory_is_causal`, to `TransformerDecoder.forward()`, which passes them on to every decoder layer. See https://github.com/pytorch/pytorch/blame/94b328ee4592605f490d422f57ad4747a92ac339/torch/nn/modules/transformer.py#L498 and https://github.com/pytorch/pytorch/pull/97166
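To tell whether an installed PyTorch build has this change, one heuristic is to inspect the signature of `nn.TransformerDecoder.forward`. This is a sketch that assumes the two parameter names above are what the decoder forwards; `decoder_passes_causal_flags` is a hypothetical helper, not part of PyTorch or icefall.

```python
import inspect

import torch
from torch import nn


def decoder_passes_causal_flags() -> bool:
    """Hypothetical helper: True if nn.TransformerDecoder.forward() declares
    both tgt_is_causal and memory_is_causal, i.e. the decoder may forward
    these keyword arguments to each of its layers."""
    params = inspect.signature(nn.TransformerDecoder.forward).parameters
    return "tgt_is_causal" in params and "memory_is_causal" in params


# Report the installed version alongside the result of the check.
print(torch.__version__, decoder_passes_causal_flags())
```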

The new keyword arguments break all `conformer_ctc` recipes (and possibly other recipes I haven't looked into), causing:

TypeError: TransformerDecoderLayer.forward() got an unexpected keyword argument 'tgt_is_causal'

This error can be bypassed by adding `memory_is_causal` and `tgt_is_causal` parameters to the `forward()` method of the recipes' own `TransformerDecoderLayer` class.
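A minimal sketch of that workaround is below. `PatchedDecoderLayer` is a hypothetical, stripped-down layer (self-attention and cross-attention only, no feed-forward sublayer or dropout), not icefall's actual class. Its `forward()` accepts the two new arguments with defaults and ignores them, relying on `tgt_mask` for causality. It keeps the attribute name `self_attn` (an `nn.MultiheadAttention`) on the assumption that newer `nn.TransformerDecoder` versions read `layers[0].self_attn.batch_first`.

```python
from typing import Optional

import torch
from torch import nn


class PatchedDecoderLayer(nn.Module):
    """Hypothetical minimal decoder layer whose forward() tolerates the
    tgt_is_causal / memory_is_causal keywords added in PyTorch >= 2.1."""

    def __init__(self, d_model: int, nhead: int) -> None:
        super().__init__()
        # Keep the name `self_attn`: nn.TransformerDecoder may inspect it.
        self.self_attn = nn.MultiheadAttention(d_model, nhead)
        self.src_attn = nn.MultiheadAttention(d_model, nhead)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
        memory_key_padding_mask: Optional[torch.Tensor] = None,
        tgt_is_causal: Optional[bool] = None,  # accepted but ignored
        memory_is_causal: bool = False,  # accepted but ignored
    ) -> torch.Tensor:
        # Masked self-attention over the target; tgt_mask enforces causality.
        x = self.self_attn(
            tgt, tgt, tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
            need_weights=False,
        )[0]
        tgt = self.norm1(tgt + x)
        # Cross-attention over the encoder output (memory).
        x = self.src_attn(
            tgt, memory, memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
            need_weights=False,
        )[0]
        return self.norm2(tgt + x)


# Usage: nn.TransformerDecoder clones the layer and calls each clone.
decoder = nn.TransformerDecoder(PatchedDecoderLayer(d_model=16, nhead=4), num_layers=2)
tgt = torch.randn(5, 2, 16)  # (T, N, E): batch_first defaults to False
memory = torch.randn(7, 2, 16)  # (S, N, E)
mask = nn.Transformer.generate_square_subsequent_mask(5)
out = decoder(tgt, memory, tgt_mask=mask)
print(out.shape)
```

On PyTorch versions whose decoder does not pass the new arguments, the defaults should apply, so the same layer works before and after the change. Naming the two parameters explicitly, rather than adding a catch-all `**kwargs`, keeps a misspelled argument an error instead of silently swallowing it.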