Hi, warp-rnnt may no longer work with the latest versions of PyTorch; I was using PyTorch 1.8 at the time.
I would advise you to use `torchaudio.transforms.RNNTLoss` instead, for example:
```python
import torchaudio


class LossRNNT(torchaudio.transforms.RNNTLoss):

    def __init__(self, blank=0, clamp=-1, reduction="mean", verbose=False):
        # Initialise the torchaudio parent class (the original called
        # super(RNNTLoss, self), but RNNTLoss is not defined in this scope)
        super().__init__(blank=blank, clamp=clamp, reduction=reduction)
        self.verbose = verbose

    def forward(self, batch, pred):

        # Unpack batch
        x, y, x_len, y_len = batch

        # Unpack outputs: logits (B, T, U + 1, V) and logits_len (B,)
        logits, logits_len, _ = pred

        # Verbose
        if self.verbose:
            print("logits:")
            print(logits.size(), logits_len)
            print("y:")
            print(y.size(), y_len)

        # Compute loss; torchaudio expects int32 targets and lengths
        loss = super().forward(
            logits=logits,
            targets=y.int(),
            logit_lengths=logits_len.int(),
            target_lengths=y_len.int(),
        )

        return loss
```
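To tell whether a failure comes from the inputs or from the warp_rnnt/torchaudio build, it can help to compare the backend against a slow reference on a tiny example. Below is a minimal pure-Python sketch of the standard RNN-T forward (alpha) recursion for a single utterance. The names `rnnt_nll_reference`, `_log_softmax` and `_logaddexp` are hypothetical and are not part of warp_rnnt, torchaudio or this repository. The sketch applies log-softmax to raw scores itself, as `torchaudio.transforms.RNNTLoss` does with its default settings. warp_rnnt's `rnnt_loss` takes `log_probs` instead, as the traceback shows.

```python
import math
from typing import List, Sequence


def _log_softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary row."""
    m = max(row)
    lse = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - lse for v in row]


def _logaddexp(a: float, b: float) -> float:
    """log(exp(a) + exp(b)) that tolerates -inf inputs."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def rnnt_nll_reference(
    logits: Sequence[Sequence[Sequence[float]]],
    targets: Sequence[int],
    blank: int = 0,
) -> float:
    """Negative log-likelihood of `targets` for ONE utterance (hypothetical helper).

    logits: raw scores laid out as [T][U + 1][V] (no batch dimension).
    targets: the U label ids, without blanks or padding.
    """
    T, U = len(logits), len(targets)
    if any(len(row) != U + 1 for row in logits):
        raise ValueError("every time step needs U + 1 rows of scores")
    lp = [[_log_softmax(logits[t][u]) for u in range(U + 1)] for t in range(T)]

    # alpha[t][u]: log-probability of having emitted u labels when reaching frame t
    alpha = [[-math.inf] * (U + 1) for _ in range(T)]
    alpha[0][0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t == 0 and u == 0:
                continue
            # Arrive by emitting blank at (t - 1, u), i.e. advancing one frame
            via_blank = alpha[t - 1][u] + lp[t - 1][u][blank] if t > 0 else -math.inf
            # Arrive by emitting label targets[u - 1] at (t, u - 1)
            via_label = (
                alpha[t][u - 1] + lp[t][u - 1][targets[u - 1]] if u > 0 else -math.inf
            )
            alpha[t][u] = _logaddexp(via_blank, via_label)

    # Every alignment ends with a final blank emitted at (T - 1, U)
    return -(alpha[T - 1][U] + lp[T - 1][U][blank])
```

As a sanity check, take all-zero scores with T = 3, U = 2 and V = 4. There are 6 alignments, each with probability 4^-5, so the function returns log(1024 / 6). To cross-check a backend, one could run it on the same small tensor with per-utterance reduction (e.g. `reduction="none"`) and compare the values.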
I am using PyTorch 1.8.1. I changed `LossRNNT` as suggested, but it didn't work. Can you tell me which version of warp_rnnt you used?
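When comparing environments, it helps to list the exact installed versions. Here is a minimal standard-library sketch (Python 3.8+, matching the py38 environment in the traceback). The helper name `package_versions` is hypothetical. The distribution name `warp_rnnt` is an assumption about how the package is registered, so a package that is not found is simply reported as `None`.

```python
import importlib.metadata
import platform
from typing import Dict, Iterable, Optional


def package_versions(
    names: Iterable[str] = ("torch", "torchaudio", "warp_rnnt"),
) -> Dict[str, Optional[str]]:
    """Return the Python version and installed package versions (None if missing)."""
    versions: Dict[str, Optional[str]] = {"python": platform.python_version()}
    for name in names:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in package_versions().items():
        print(f"{name}: {version or 'not found'}")
```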
When I try to train a Transducer model, main.py fails with the following error:
```
Traceback (most recent call last):
  File "main.py", line 222, in <module>
    main(0, args)
  File "main.py", line 121, in main
    model.fit(dataset_train,
  File "/data3/zwr/Procedure/EfficientConformer/models/model.py", line 344, in fit
    raise e
  File "/data3/zwr/Procedure/EfficientConformer/models/model.py", line 241, in fit
    loss_mini = self.criterion(batch, pred)
  File "/data3/zwr/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data3/zwr/Procedure/EfficientConformer/models/losses.py", line 36, in forward
    loss = warp_rnnt.rnnt_loss(
  File "/data3/zwr/py38/lib/python3.8/site-packages/warp_rnnt/__init__.py", line 130, in rnnt_loss
    costs = RNNTLoss.apply(log_probs, labels, frames_lengths, labels_lengths, blank, fastemit_lambda)
  File "/data3/zwr/py38/lib/python3.8/site-packages/warp_rnnt/__init__.py", line 13, in forward
    costs, ctx.grads = core.rnnt_loss(
RuntimeError: rnnt_loss status 1
Segmentation fault (core dumped)
```
I tried a different version of warp-rnnt, but that didn't help either. What can I do to solve this problem? Thank you very much!
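An opaque backend error such as `rnnt_loss status 1` can have several causes. One cause worth ruling out (an assumption, not a confirmed diagnosis of this traceback) is inconsistent shapes, lengths or labels passed to the loss. The following minimal sketch checks them in plain Python. The helper `check_rnnt_inputs` is hypothetical. It assumes the layout used in `LossRNNT` above: logits `(B, T, U + 1, V)`, targets `(B, U)`, and one length per batch element. It also assumes blank is never a real label.

```python
from typing import Sequence


def check_rnnt_inputs(
    logits_shape: Sequence[int],
    logit_lengths: Sequence[int],
    targets_shape: Sequence[int],
    target_lengths: Sequence[int],
    targets: Sequence[Sequence[int]],
    blank: int = 0,
) -> None:
    """Raise ValueError if RNN-T loss inputs are mutually inconsistent (hypothetical helper)."""
    if len(logits_shape) != 4:
        raise ValueError(f"logits must be 4-D (B, T, U + 1, V), got {tuple(logits_shape)}")
    B, T, U_plus_1, V = logits_shape
    if len(targets_shape) != 2 or targets_shape[0] != B:
        raise ValueError(f"targets must be (B={B}, U), got {tuple(targets_shape)}")
    U = targets_shape[1]
    if U_plus_1 != U + 1:
        raise ValueError(f"logits dim 2 is {U_plus_1}, expected U + 1 = {U + 1}")
    if len(logit_lengths) != B or len(target_lengths) != B:
        raise ValueError("length tensors must have one entry per batch element")
    for b in range(B):
        if not 1 <= logit_lengths[b] <= T:
            raise ValueError(f"logit_lengths[{b}]={logit_lengths[b]} not in [1, {T}]")
        if not 0 <= target_lengths[b] <= U:
            raise ValueError(f"target_lengths[{b}]={target_lengths[b]} not in [0, {U}]")
        # Only the unpadded prefix of each target sequence is checked
        for label in targets[b][: target_lengths[b]]:
            if not 0 <= label < V or label == blank:
                raise ValueError(f"invalid label {label} in targets[{b}] (V={V}, blank={blank})")
```

With PyTorch tensors, one could call it just before the loss as `check_rnnt_inputs(tuple(logits.shape), logits_len.tolist(), tuple(y.shape), y_len.tolist(), y.tolist())`. Converting to Python lists is slow, so this is meant for debugging only.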