k2-fsa / k2

FSA/FST algorithms, differentiable, with PyTorch compatibility.
https://k2-fsa.github.io/k2
Apache License 2.0

Small fix for rnn_loss.py #1183

Closed: yfyeung closed this 1 year ago

danpovey commented 1 year ago

This is another fix for a memory issue in rnnt_loss.py, similar to #1177: the problem is that a too-large zero tensor is allocated in the backward pass, which can lead to OOM errors when training with large batch sizes. I found the problem using a build of PyTorch compiled with debug info, by running training with a large batch size (--max-duration 1450) and a single job under gdb --args.
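The stack trace further down shows the failing allocation happening inside at::gather_backward. Below is a minimal pure-Python sketch of why that can be expensive. It uses hypothetical helpers, gather_1d and gather_1d_backward, on plain lists instead of tensors, and assumes the same pattern PyTorch's gather backward follows for dense gradients: allocate zeros with the shape of the input that was gathered from (self), then scatter-add the incoming gradient at the index positions. The zero buffer therefore scales with the input, not with the usually much smaller index.

```python
from typing import List


def gather_1d(src: List[float], index: List[int]) -> List[float]:
    """Forward (hypothetical 1-D helper): out[i] = src[index[i]]."""
    return [src[j] for j in index]


def gather_1d_backward(grad_out: List[float], src_len: int,
                       index: List[int]) -> List[float]:
    """Backward (hypothetical 1-D helper): zeros shaped like src, then scatter-add.

    The buffer always has src_len entries, even when index is tiny; this is
    the kind of "too-large zero tensor" described above when src is large.
    """
    grad_src = [0.0] * src_len  # full-size zero buffer, shaped like the input
    for i, j in enumerate(index):
        grad_src[j] += grad_out[i]  # accumulate: repeated indices add up
    return grad_src


if __name__ == "__main__":
    src = [float(v) for v in range(10)]
    index = [2, 2, 7]
    print(gather_1d(src, index))  # [2.0, 2.0, 7.0]
    print(gather_1d_backward([1.0, 1.0, 1.0], len(src), index))
    # [0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
```

Repeated indices accumulate (position 2 receives 2.0), which is why the backward is a scatter-add rather than a plain scatter.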

 gdb --args python3 ./pruned_transducer_stateless7/train.py --master-port 43021 --world-size 1 --num-workers 0 --num-epochs 24 --full-libri 0 --exp-dir pruned_transducer_stateless7/scaled_adam_exp1117_4job_md1250_full --max-duration 1450 --use-fp16 True --decoder-dim 512 --joiner-dim 512 --start-epoch=1 --base-lr=0.035
(gdb) catch throw
(gdb) r

When it stopped at the point where it was about to issue this error:

RuntimeError: CUDA out of memory. Tried to allocate 4.89 GiB (GPU 0; 31.75 GiB total capacity; 23.75 GiB already allocated; 4.71 GiB free; 25.86 GiB reserved in total by PyTorch)
<snip>
2023-04-28 02:09:35,802 INFO [checkpoint.py:75] Saving checkpoint to pruned_transducer_stateless7/scaled_adam_exp1117_4job_md1250_full/bad-model-0.pt
2023-04-28 02:09:45,045 INFO [train.py:1299] Saving batch to pruned_transducer_stateless7/scaled_adam_exp1117_4job_md1250_full/batch-d9cd8db7-a730-5db6-e534-90be42e420b6.pt
2023-04-28 02:09:45,224 INFO [train.py:1305] features shape: torch.Size([81, 1777, 80])

... I went to stack frame 65 (at::gather_backward), which had the size info:

#62 0x00007fffcb7c63ff in c10::Dispatcher::callWithDispatchKey<at::Tensor, at::Tensor const&, at::Tensor const&, long, at::Tensor const&, bool>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&, long, at::Tensor const&, bool)> const&, c10::DispatchKey, at::Tensor const&, at::Tensor const&, long, at::Tensor const&, bool) const (this=0x7fffd86eada0 <c10::Dispatcher::singleton()::_singleton>, op=...,
    dispatchKey=c10::DispatchKey::AutogradCUDA) at /star-fj/fangjun/open-source/pytorch/aten/src/ATen/core/dispatch/Dispatcher.h:376
#63 0x00007fffcb7204c3 in c10::Dispatcher::call<at::Tensor, at::Tensor const&, at::Tensor const&, long, at::Tensor const&, bool>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&, long, at::Tensor const&, bool)> const&, at::Tensor const&, at::Tensor const&, long, at::Tensor const&, bool) const (this=0x7fffd86eada0 <c10::Dispatcher::singleton()::_singleton>, op=...)
    at /star-fj/fangjun/open-source/pytorch/aten/src/ATen/core/dispatch/Dispatcher.h:387
#64 0x00007fffcb6ae3a4 in c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&, long, at::Tensor const&, bool)>::call(at::Tensor const&, at::Tensor const&, long, at::Tensor const&, bool) const (
    this=0x7fffd86f4d00 <at::gather_backward(at::Tensor const&, at::Tensor const&, long, at::Tensor const&, bool)::op>, args#0=..., args#1=..., args#2=2, args#3=..., args#4=false)
    at /star-fj/fangjun/open-source/pytorch/aten/src/ATen/core/dispatch/Dispatcher.h:327
#65 0x00007fffcb655052 in at::gather_backward (grad=..., self=..., dim=2, index=..., sparse_grad=false) at /star-fj/fangjun/open-source/pytorch/build/aten/src/ATen/Functions.cpp:8184
#66 0x00007fffcd4e3f17 in torch::autograd::generated::GatherBackward::apply (this=0x9b102f40, grads=...) at /star-fj/fangjun/open-source/pytorch/torch/csrc/autograd/generated/Functions.cpp:1563
#67 0x00007fffcdc8ef62 in torch::autograd::Node::operator() (this=0x9b102f40, inputs=...) at /star-fj/fangjun/open-source/pytorch/torch/csrc/autograd/function.h:155
#68 0x00007fffcdc89027 in torch::autograd::call_function (graph_task=std::shared_ptr<torch::autograd::GraphTask> (use count 3, weak count 4) = {...}, func=0x9b102f40, inputBuffer=...)
    at /star-fj/fangjun/open-source/pytorch/torch/csrc/autograd/engine.cpp:676
#69 0x00007fffcdc89582 in torch::autograd::Engine::evaluate_function (this=0x7fffdaef21c0 <torch::autograd::python::PythonEngine::get_python_engine()::engine>,
    graph_task=std::shared_ptr<torch::autograd::GraphTask> (use count 3, weak count 4) = {...}, func=0x9b102f40, inputs=..., cpu_ready_queue=std::shared_ptr<torch::autograd::ReadyQueue> (use count 2, weak count 0) = {...})

and printed the sizes of its index and self arguments:

(gdb) p ((long*) index.impl_.target_.sizes_.BeginX)[0]
$36 = 81
(gdb) p ((long*) index.impl_.target_.sizes_.BeginX)[1]
$37 = 443
(gdb) p ((long*) index.impl_.target_.sizes_.BeginX)[2]
$38 = 5
(gdb) p ((long*) index.impl_.target_.sizes_.BeginX)[3]
$39 = 512
(gdb) p ((long[4]) self.impl_.target_.sizes_.BeginX)
$40 = {2574253384, 2574253416, 2574253424, 81}
(gdb) p ((long*) self.impl_.target_.sizes_.BeginX)[0]
$41 = 81 # B
(gdb) p ((long*) self.impl_.target_.sizes_.BeginX)[1]
$42 = 443 # T
(gdb) p ((long*) self.impl_.target_.sizes_.BeginX)[2]
$43 = 143 # S
(gdb) p ((long*) self.impl_.target_.sizes_.BeginX)[3]
$44 = 512 # C

(I worked out these expressions by looking at the types printed for these variables.) From the size info I figured out which part of the code it related to: it was a torch.gather expression.
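As a sanity check, the sizes above can be compared with the failed allocation in the error message. The sketch below assumes 2 bytes per element, since the run used --use-fp16 True; under that assumption a zero tensor with self's shape (B, T, S, C) = (81, 443, 143, 512) comes to exactly the 4.89 GiB that PyTorch tried to allocate, while a tensor with the index (and gradient) shape (81, 443, 5, 512) needs only about 0.17 GiB. The helper name tensor_gib is hypothetical.

```python
from math import prod

GIB = 1024 ** 3     # PyTorch's OOM message reports sizes in GiB (2**30 bytes)
BYTES_PER_ELEM = 2  # assumption: fp16 tensors, because of --use-fp16 True


def tensor_gib(shape, bytes_per_elem=BYTES_PER_ELEM):
    """Size in GiB of a dense tensor with the given shape."""
    return prod(shape) * bytes_per_elem / GIB


# Shapes read from gdb in frame #65 (at::gather_backward, dim=2).
self_shape = (81, 443, 143, 512)  # (B, T, S, C): the tensor gathered from
index_shape = (81, 443, 5, 512)   # index shape; gather's output and grad match it

if __name__ == "__main__":
    print(f"self-shaped zeros: {tensor_gib(self_shape):.2f} GiB")     # 4.89 GiB
    print(f"index-shaped tensor: {tensor_gib(index_shape):.2f} GiB")  # 0.17 GiB
```

The ratio between the two is just S / 5 = 143 / 5, so the cost of the zero tensor grows with the full S dimension rather than with the 5 positions actually gathered.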