k2-fsa / k2

FSA/FST algorithms, differentiable, with PyTorch compatibility.
https://k2-fsa.github.io/k2
Apache License 2.0

Fix for rnnt_loss.py #1177

Closed: yfyeung closed this 1 year ago

danpovey commented 1 year ago

This PR fixes an issue where the backward pass of the simple RNN-T loss uses a lot of memory. This can cause seemingly random failures in which training, after running for some time, ends with a message like this:

        Variable._execution_engine.run_backward(
    RuntimeError: CUDA out of memory. Tried to allocate 5.37 GiB (GPU 0; 31.75 GiB total capacity; 17.47 GiB already allocated; 12.64 GiB free; 17.83 GiB reserved in total by PyTorch)

(Note: it remains a mystery why the failed request is often much smaller than the memory the device reports as free. The 12.64 GiB figure is device_free from cudaMemGetInfo(&device_free, &device_total). Possibly this has to do with other processes using the machine; regardless, the backward pass uses far more memory than it actually needs.)
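To make the arithmetic in that note concrete, here is a small stdlib-only Python sketch (not part of k2 or of this PR) that extracts the figures from an OOM message in the format shown above and derives two quantities from them. The function name summarize_oom and the regular expression are hypothetical illustrations; the pattern assumes the older PyTorch wording with every size in GiB, so messages reported in MiB or phrased differently by newer PyTorch releases will not match.

```python
import re
from typing import Dict

# Hypothetical pattern for the OOM message format shown in the log above.
# Assumption: all sizes are reported in GiB (PyTorch may also use MiB).
_OOM_PATTERN = re.compile(
    r"Tried to allocate (?P<requested>[\d.]+) GiB \(GPU (?P<gpu>\d+); "
    r"(?P<total>[\d.]+) GiB total capacity; "
    r"(?P<allocated>[\d.]+) GiB already allocated; "
    r"(?P<free>[\d.]+) GiB free; "
    r"(?P<reserved>[\d.]+) GiB reserved in total by PyTorch\)"
)


def summarize_oom(message: str) -> Dict[str, float]:
    """Parse the GiB figures from an OOM message and derive a few quantities."""
    m = _OOM_PATTERN.search(message)
    if m is None:
        raise ValueError("message does not match the expected OOM format")
    # Convert every size field to float; the GPU index is not a size, so skip it.
    v = {k: float(x) for k, x in m.groupdict().items() if k != "gpu"}
    return {
        **v,
        # Memory held by PyTorch's caching allocator but not occupied by tensors.
        "cached_unused": round(v["reserved"] - v["allocated"], 2),
        # Negative means the request was smaller than the reported free memory
        # (the cudaMemGetInfo device_free figure).
        "request_minus_free": round(v["requested"] - v["free"], 2),
    }


if __name__ == "__main__":
    msg = (
        "RuntimeError: CUDA out of memory. Tried to allocate 5.37 GiB "
        "(GPU 0; 31.75 GiB total capacity; 17.47 GiB already allocated; "
        "12.64 GiB free; 17.83 GiB reserved in total by PyTorch)"
    )
    print(summarize_oom(msg))
```

For the log above, the request is 7.27 GiB smaller than the reported free memory, and only 0.36 GiB is cached by PyTorch but unused. This is the gap the note describes as a mystery.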