snakers4 opened this issue 2 years ago
@burchim By the way, since you used this loss, did you encounter anything of this sort in your work?
Hi @snakers4! Yes, I had a similar problem with 4 GPU devices where the RNN-T loss was properly computed on the first device but came out as 0 on the others. I don't remember the exact cause, but it had something to do with tensor devices, maybe the frame / label lengths.
I also recently experimented with replacing it with the official `torchaudio.transforms.RNNTLoss` from torchaudio 0.10.0. It was working very well, but I didn't try a full training run with it.
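For reference, here is a minimal sketch of how that loss can be called, based on the torchaudio 0.10 API; the toy shapes, random tensors, and blank index of 0 below are assumptions for illustration, not values from this thread:

```python
import torch
import torchaudio

# Toy dimensions; real values come from the encoder / predictor / joint network.
batch, time, target_len, num_classes = 2, 50, 10, 32

rnnt_loss = torchaudio.transforms.RNNTLoss(blank=0, reduction="mean")

# Joint network output: (batch, max time, max target length + 1, num classes).
logits = torch.randn(batch, time, target_len + 1, num_classes, requires_grad=True)
# Targets are int32 and must not contain the blank label.
targets = torch.randint(1, num_classes, (batch, target_len), dtype=torch.int32)
# Per-utterance lengths, also int32.
logit_lengths = torch.full((batch,), time, dtype=torch.int32)
target_lengths = torch.full((batch,), target_len, dtype=torch.int32)

loss = rnnt_loss(logits, targets, logit_lengths, target_lengths)
loss.backward()
```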
Thanks for the heads up about the torchaudio loss! I remember seeing it some time ago, but I totally forgot about it.
@burchim By the way, did you get `RuntimeError: input length mismatch` when migrating from warp-rnnt to torchaudio?
Yes, this means that the logit / target lengths tensors do not match the logits / targets tensors, for instance if your logit lengths are longer than the time dimension of your logits tensor.
It was because I used the target lengths instead of the logit lengths, stupid error.
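For anyone else hitting this, a sketch of the pitfall; the shapes and blank index are made up, only the argument order matters here:

```python
import torch
import torchaudio

batch, time, target_len, num_classes = 2, 50, 10, 32
logits = torch.randn(batch, time, target_len + 1, num_classes)
targets = torch.randint(1, num_classes, (batch, target_len), dtype=torch.int32)
logit_lengths = torch.full((batch,), time, dtype=torch.int32)          # lengths along the time axis of `logits`
target_lengths = torch.full((batch,), target_len, dtype=torch.int32)   # lengths along the label axis of `targets`

loss_fn = torchaudio.transforms.RNNTLoss(blank=0)

# Wrong: passing the target lengths in the logit_lengths slot makes the lengths
# inconsistent with the time dimension of `logits` and triggers
# "RuntimeError: input length mismatch".
# loss = loss_fn(logits, targets, target_lengths, target_lengths)

# Correct argument order: logits, targets, logit_lengths, target_lengths.
loss = loss_fn(logits, targets, logit_lengths, target_lengths)
```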
@snakers4 You may find https://github.com/danpovey/fast_rnnt useful.
@1ytic Hi,
So far I have been able to use the loss with DDP on a single GPU, and it behaves more or less as expected.
But when I use more than 1 device, the following happens:
- GPU-0: loss is calculated properly
- GPU-1: loss is close to zero for each batch

I checked the input tensors, devices, tensor values, etc. - so far everything seems to be identical for GPU-0 and the other GPUs.
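In case it helps with debugging, here is a rough per-rank sketch I would use to confirm that identical inputs give different losses on different ranks; the `warp_rnnt.rnnt_loss` argument names are taken from the PyTorch binding's README (please double-check them), and the helper itself is hypothetical:

```python
import torch
import torch.distributed as dist
import warp_rnnt  # the `warp_rnnt` PyTorch binding


def debug_rnnt_loss(log_probs, labels, frames_lengths, labels_lengths):
    """Compute the warp-rnnt loss and print per-rank input stats next to it."""
    rank = dist.get_rank() if dist.is_initialized() else 0
    loss = warp_rnnt.rnnt_loss(
        log_probs,       # (N, T, U+1, V) log-softmax outputs on this rank's device
        labels,          # (N, U) int labels, no blanks
        frames_lengths,  # (N,) encoder output lengths
        labels_lengths,  # (N,) label lengths
        reduction="mean",
        blank=0,
    )
    print(
        f"rank={rank} device={log_probs.device} "
        f"log_probs_sum={log_probs.sum().item():.4f} "
        f"frames={frames_lengths.tolist()} labels={labels_lengths.tolist()} "
        f"loss={loss.item():.4f}"
    )
    return loss
```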