Garethlomax opened this issue 5 years ago
The issue appears to be a result of PyTorch's computational graph structure for backpropagation. The hidden state tensor needs to be detached so the graph (and the memory it holds) does not keep growing across iterations.
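A minimal sketch of the detach approach (model, shapes, and loss here are hypothetical stand-ins, not from this repo): carrying the hidden state forward without detaching keeps every past iteration's graph alive.

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=64, hidden_size=128, batch_first=True)

h = torch.zeros(1, 8, 128)  # (num_layers, batch, hidden_size)
c = torch.zeros(1, 8, 128)

for step in range(100):
    x = torch.randn(8, 16, 64)  # stand-in batch of sequences
    out, (h, c) = lstm(x, (h, c))
    loss = out.mean()  # stand-in loss
    lstm.zero_grad()
    loss.backward()
    # Detach the carried state so the next iteration starts a fresh graph
    # instead of extending the old one (which keeps all past activations alive).
    h = h.detach()
    c = c.detach()
```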
For debugging, a tool to visualize the computational graph of the LSTM:
https://github.com/szagoruyko/functional-zoo/blob/master/visualize.py
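A rough sketch of how the linked script might be used, assuming it has been downloaded as `visualize.py` and that its `make_dot(var, params=None)` signature still holds (both are assumptions here):

```python
import torch
import torch.nn as nn
from visualize import make_dot  # the linked functional-zoo script

lstm = nn.LSTM(input_size=64, hidden_size=128, batch_first=True)
x = torch.randn(8, 16, 64)
out, _ = lstm(x)

# Build a graphviz Digraph of the backward graph from a scalar output.
dot = make_dot(out.mean(), params=dict(lstm.named_parameters()))
dot.render('lstm_graph')  # requires graphviz installed
```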
Progress made with cuDNN memory. Will still explore detach approaches at a later date; will also look into garbage collection inefficiencies.
Gradient averaging is another potential remedy (sketched below): https://gchlebus.github.io/2018/06/05/gradient-averaging.html
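A hedged sketch of gradient averaging (accumulating scaled gradients over several small batches before one optimizer step); the model, data, and step count are hypothetical:

```python
import torch
import torch.nn as nn

model = nn.Linear(64, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
accum_steps = 4  # number of small batches averaged into one update

optimizer.zero_grad()
for i in range(100):
    x = torch.randn(8, 64)   # stand-in batch
    y = torch.randn(8, 1)
    loss = nn.functional.mse_loss(model(x), y)
    # Scale so the accumulated gradient equals the average over accum_steps batches.
    (loss / accum_steps).backward()
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```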
Truncated backpropagation through time (TBPTT) is another option, sketched below:
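A minimal TBPTT sketch (sequence length, truncation window, and loss are assumptions): the long sequence is processed in chunks of `k` steps, and the state is detached at each chunk boundary so gradients flow back at most `k` steps.

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=64, hidden_size=128, batch_first=True)
optimizer = torch.optim.Adam(lstm.parameters())
seq = torch.randn(8, 512, 64)  # stand-in long sequence (batch, time, features)
k = 32                         # truncation length
state = None

for t in range(0, seq.size(1), k):
    chunk = seq[:, t:t + k]
    out, state = lstm(chunk, state)
    loss = out.pow(2).mean()  # stand-in loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Truncate: keep the state values but cut the graph at the chunk boundary.
    state = tuple(s.detach() for s in state)
```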
Running a snippet to check garbage collection shows a number of large (4096, 4096) tensors without a good explanation of origin. They may be temporaries that have not been deleted because doing so is too costly during training, or they may be due to a memory leak.
Useful: https://discuss.pytorch.org/t/how-to-debug-causes-of-gpu-memory-leaks/6741/12
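A sketch of the kind of GC inspection snippet used in the linked thread: it walks Python's garbage collector and prints the shape of every live tensor, so large unexplained allocations (e.g. 4096x4096) can be spotted.

```python
import gc
import torch

for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
            print(type(obj), obj.size())
    except Exception:
        pass  # some objects raise on attribute access; skip them
```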
From the above thread, Update 2: "Finally I solved the memory problem! I realized that in each iteration I put the input data in a new tensor, and pytorch generates a new computation graph. That causes the used RAM to grow forever. Then I use a placeholder tensor and copy the data to this tensor, and the RAM always stays at a low level."
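A hedged sketch of that quoted fix (shapes and loop are hypothetical): preallocate one input tensor once and copy each batch into it with `copy_()`, rather than allocating a fresh tensor every iteration.

```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
placeholder = torch.empty(8, 16, 64, device=device)  # allocated once

for step in range(100):
    batch = torch.randn(8, 16, 64)  # stands in for loaded data
    placeholder.copy_(batch)        # reuse the same storage every iteration
    # ... forward/backward using `placeholder` as the model input ...
```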