This is another fix for a memory issue in rnnt_loss.py, similar to #1177: the problem is that a too-large zero tensor is used in the backward pass, which can lead to OOM errors in training with large batch sizes. I discovered the problem using a version of PyTorch built with debug info; I ran training with a large batch size (1450) and 1 job, under gdb --args.
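For anyone wanting to reproduce this kind of session, the sketch below is roughly what that looks like. The script path and options are placeholders rather than the exact command I ran, and `catch throw` is just one way of making gdb stop at the point where libtorch is about to raise the OOM as a C++ exception:

```
# placeholder script path / options -- any training run that hits the OOM will do
gdb --args python3 pruned_transducer_stateless7/train.py [training options]

# inside gdb: stop on the C++ exception that becomes the Python RuntimeError, then run
(gdb) catch throw
(gdb) run
```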
When it stopped at the point where it was about to issue this error:
RuntimeError: CUDA out of memory. Tried to allocate 4.89 GiB (GPU 0; 31.75 GiB total capacity; 23.75 GiB already allocated; 4.71 GiB free; 25.86 GiB reserved in total by PyTorch)
<snip>
2023-04-28 02:09:35,802 INFO [checkpoint.py:75] Saving checkpoint to pruned_transducer_stateless7/scaled_adam_exp1117_4job_md1250_full/bad-model-0.pt
2023-04-28 02:09:45,045 INFO [train.py:1299] Saving batch to pruned_transducer_stateless7/scaled_adam_exp1117_4job_md1250_full/batch-d9cd8db7-a730-5db6-e534-90be42e420b6.pt
2023-04-28 02:09:45,224 INFO [train.py:1305] features shape: torch.Size([81, 1777, 80])
... I went to stack frame 65, which had size info, and printed out the size info:
(gdb) p ((long*) index.impl_.target_.sizes_.BeginX)[0]
$36 = 81
(gdb) p ((long*) index.impl_.target_.sizes_.BeginX)[1]
$37 = 443
(gdb) p ((long*) index.impl_.target_.sizes_.BeginX)[2]
$38 = 5
(gdb) p ((long*) index.impl_.target_.sizes_.BeginX)[3]
$39 = 512
(gdb) p ((long[4]) self.impl_.target_.sizes_.BeginX)
$40 = {2574253384, 2574253416, 2574253424, 81}
(gdb) p ((long*) self.impl_.target_.sizes_.BeginX)[0]
$41 = 81 # B
(gdb) p ((long*) self.impl_.target_.sizes_.BeginX)[1]
$42 = 443 # T
(gdb) p ((long*) self.impl_.target_.sizes_.BeginX)[2]
$43 = 143 # S
(gdb) p ((long*) self.impl_.target_.sizes_.BeginX)[3]
$44 = 512 # C
(I figured out these expressions by looking at the types printed out for these variables.)
From the size info I figured out which part of the code it related to: it was a torch.gather expression.
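To make it concrete why that gather blows up: the gradient of torch.gather with respect to its input is built by scatter-adding grad_output into a zeros tensor the same shape as the input, so even though the gathered output is small (only 5 along the S dimension here), the backward pass has to materialize a full (B, T, S, C) buffer, i.e. the "too-large zero tensor" mentioned above. Assuming the tensor is float16, that is 81 * 443 * 143 * 512 * 2 bytes ≈ 4.89 GiB, which matches the failed allocation in the error message. Below is a minimal sketch of that pattern; the shapes come from the gdb session above, but the gather dim, the dtype, and the variable names are my assumptions, not the actual rnnt_loss.py code:

```python
import torch

# Shapes from the gdb session; dim=2 (the S axis) and float16 are assumptions.
B, T, S, C = 81, 443, 143, 512   # shape of `self`, the tensor being gathered from
s_range = 5                      # size of `index` along dim 2

# gather's backward is essentially:
#   grad_self = torch.zeros_like(self); grad_self.scatter_add_(2, index, grad_out)
# so it needs a zeros tensor the full size of `self`:
bytes_needed = B * T * S * C * 2          # 2 bytes per float16 element
print(f"{bytes_needed / 2**30:.2f} GiB")  # -> 4.89 GiB, matching the OOM message

# Tiny runnable repro of the pattern (CPU, small sizes):
src = torch.randn(2, 4, 10, 3, requires_grad=True)
index = torch.randint(0, 10, (2, 4, 5, 3))
out = torch.gather(src, dim=2, index=index)
out.sum().backward()
print(src.grad.shape)  # torch.Size([2, 4, 10, 3]) -- full-size gradient buffer
```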