burchim / EfficientConformer

[ASRU 2021] Efficient Conformer: Progressive Downsampling and Grouped Attention for Automatic Speech Recognition
https://arxiv.org/abs/2109.01163
Apache License 2.0

An error when I try to train a Transducer model. #24

Closed: ZhengWenrui closed this issue 1 year ago

ZhengWenrui commented 1 year ago

When I try to train a Transducer model, I can't run main.py successfully. I get the following error:

Traceback (most recent call last):
  File "main.py", line 222, in <module>
    main(0, args)
  File "main.py", line 121, in main
    model.fit(dataset_train,
  File "/data3/zwr/Procedure/EfficientConformer/models/model.py", line 344, in fit
    raise e
  File "/data3/zwr/Procedure/EfficientConformer/models/model.py", line 241, in fit
    loss_mini = self.criterion(batch, pred)
  File "/data3/zwr/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data3/zwr/Procedure/EfficientConformer/models/losses.py", line 36, in forward
    loss = warp_rnnt.rnnt_loss(
  File "/data3/zwr/py38/lib/python3.8/site-packages/warp_rnnt/__init__.py", line 130, in rnnt_loss
    costs = RNNTLoss.apply(log_probs, labels, frames_lengths, labels_lengths, blank, fastemit_lambda)
  File "/data3/zwr/py38/lib/python3.8/site-packages/warp_rnnt/__init__.py", line 13, in forward
    costs, ctx.grads = core.rnnt_loss(
RuntimeError: rnnt_loss status 1
Segmentation fault (core dumped)

I tried a different version of warp-rnnt, but it didn't help. What can I do to solve the problem? Thank you very much!
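One thing worth ruling out before blaming the library version is malformed loss inputs: an out-of-range length or token id can make a native RNN-T kernel read out of bounds. The helper below is a hypothetical sketch (`check_rnnt_inputs` is not part of EfficientConformer, warp-rnnt or torchaudio) that checks the shape conventions used in this thread: logits of shape (B, T, U + 1, V), padded targets of shape (B, U), and one frame length and one target length per utterance. It works on plain Python values so it does not depend on any particular library version.

```python
def check_rnnt_inputs(logits_shape, targets_shape, logit_lengths, target_lengths, targets=None):
    """Return a list of problems found in RNN-T loss inputs (empty list = looks consistent).

    Hypothetical helper. Expects plain Python values, e.g. tuple(logits.shape),
    tuple(y.shape), logits_len.tolist(), y_len.tolist() and optionally y.tolist().
    """
    problems = []
    if len(logits_shape) != 4:
        problems.append(f"logits must be 4-D (B, T, U + 1, V), got {len(logits_shape)}-D")
        return problems
    batch, max_t, max_u1, vocab = logits_shape

    # Targets are padded to (B, U) and the logits' third axis must be U + 1.
    if len(targets_shape) != 2 or targets_shape[0] != batch:
        problems.append(f"targets must have shape (B={batch}, U), got {tuple(targets_shape)}")
    elif targets_shape[1] + 1 != max_u1:
        problems.append(f"logits axis 2 is {max_u1}, expected targets axis 1 + 1 = {targets_shape[1] + 1}")

    # One length per utterance, each within the padded size.
    if len(logit_lengths) != batch or len(target_lengths) != batch:
        problems.append(f"expected {batch} logit lengths and {batch} target lengths")
    for i, (t_len, u_len) in enumerate(zip(logit_lengths, target_lengths)):
        if not 0 < t_len <= max_t:
            problems.append(f"item {i}: logit length {t_len} not in [1, {max_t}]")
        if not 0 <= u_len <= max_u1 - 1:
            problems.append(f"item {i}: target length {u_len} not in [0, {max_u1 - 1}]")

    # Token ids outside the vocabulary can trigger out-of-bounds reads in native code.
    if targets is not None:
        for i, (row, u_len) in enumerate(zip(targets, target_lengths)):
            bad = [tok for tok in row[:u_len] if not 0 <= tok < vocab]
            if bad:
                problems.append(f"item {i}: token ids {bad} not in [0, {vocab - 1}]")
    return problems
```

For example, inside the loss's `forward` one could print `check_rnnt_inputs(tuple(logits.shape), tuple(y.shape), logits_len.tolist(), y_len.tolist(), y.tolist())` for the failing batch; an empty list only means these particular checks passed, not that the crash is explained.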

burchim commented 1 year ago

Hi, warp-rnnt may not work with the latest versions of PyTorch; I was using PyTorch 1.8 at the time. I would advise you to use torchaudio.transforms.RNNTLoss instead, replacing the LossRNNT class in models/losses.py with:

import torchaudio


class LossRNNT(torchaudio.transforms.RNNTLoss):

    def __init__(self, blank=0, clamp=-1.0, reduction="mean", verbose=False):
        # Configure the parent torchaudio loss (blank index, gradient clamping, reduction)
        super().__init__(blank=blank, clamp=clamp, reduction=reduction)
        self.verbose = verbose

    def forward(self, batch, pred):

        # Unpack batch
        x, y, x_len, y_len = batch

        # Unpack Outputs (B, T, U + 1, V) and (B,)
        logits, logits_len, _ = pred

        # Verbose
        if self.verbose:
            print("logits:")
            print(logits.size(), logits_len)
            print("y:")
            print(y.size(), y_len)

        # Compute Loss: torchaudio expects int32 targets and int32 lengths
        loss = super().forward(
            logits=logits,
            targets=y.int(),
            logit_lengths=logits_len.int(),
            target_lengths=y_len.int()
        )

        return loss
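Whichever backend ends up working (warp-rnnt or torchaudio), it can help to compare it against a slow reference on a toy input. Below is a minimal pure-Python sketch of the standard RNN-T forward (alpha) recursion for a single utterance; `rnnt_loss_reference` and `log_add` are illustrative names, not part of either library. It assumes the inputs are already log-softmax-normalized, laid out as (T, U + 1, V) like one item of the logits above, with the blank index at 0.

```python
import math


def log_add(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def rnnt_loss_reference(log_probs, targets, blank=0):
    """Negative log-likelihood of `targets` for ONE utterance (illustrative sketch).

    log_probs[t][u][v]: log-probability of token v at frame t after emitting
    u labels, as a (T, U + 1, V) nested list. targets: list of U label ids.
    """
    T = len(log_probs)
    U = len(targets)
    # alpha[t][u]: log-probability of having emitted the first u labels by frame t
    alpha = [[-math.inf] * (U + 1) for _ in range(T)]
    alpha[0][0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t == 0 and u == 0:
                continue
            # Arrive by emitting blank at (t - 1, u) or label u at (t, u - 1)
            from_blank = alpha[t - 1][u] + log_probs[t - 1][u][blank] if t > 0 else -math.inf
            from_label = alpha[t][u - 1] + log_probs[t][u - 1][targets[u - 1]] if u > 0 else -math.inf
            alpha[t][u] = log_add(from_blank, from_label)
    # Every alignment ends with a blank at the last frame after all U labels
    return -(alpha[T - 1][U] + log_probs[T - 1][U][blank])
```

As a sanity check, with uniform probabilities over V = 2 tokens, T = 2 frames and a single target label there are two alignments, each of probability (1/2)^3, so the loss is log 4 (about 1.386). Feeding all-zero logits of shape (1, 2, 2, 2) with target [[1]] to the torchaudio loss should give the same value if everything is wired correctly.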

ZhengWenrui commented 1 year ago

I am using PyTorch 1.8.1. I replaced LossRNNT as suggested, but it still didn't work. Could you tell me which version of warp_rnnt you used?
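When comparing environments it helps to post exact installed versions. As far as I know, torchaudio.transforms.RNNTLoss first shipped with torchaudio 0.10 (the release paired with PyTorch 1.10), so on a PyTorch 1.8.1 setup the replacement class above may fail as soon as it is defined; printing the versions makes that kind of mismatch easy to spot. A minimal sketch using the standard library's importlib.metadata (Python 3.8+); `report_versions` is a hypothetical helper, and the distribution names are assumptions that may need adjusting to how the packages were installed (for example warp-rnnt vs warp_rnnt).

```python
from importlib import metadata


def report_versions(packages=("torch", "torchaudio", "warp_rnnt")):
    """Map each distribution name to its installed version, or None if not found."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in report_versions().items():
        print(f"{name}: {version or 'not installed'}")
```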