Closed shirazb closed 5 years ago
We appreciate your interest in PyTorch. Please use https://discuss.pytorch.org/ for such questions.
Hi,
My master's thesis is on making neural nets use less memory. One technique I am looking at is memory checkpointing. I can solve for the optimal policy (including multiple recomputations), given the memory budget and per-operator compute/memory costs.
I am attempting to implement memory checkpointing as done in `torch.utils.checkpoint`, except allowing for multiple recomputations. However, there are a couple of things from the implementation that I'm not quite sure I understand. Apologies if anything is obvious; I have been using PyTorch for <2 days.
- Why detach the inputs in the backward before recomputing the forward? Would you otherwise get something like the gradient of the forward being accumulated twice into the input's gradient? Also, does this detaching not duplicate the input, resulting in a higher-than-necessary memory cost? (My reading of the pattern in question is sketched just below.)
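For reference, this is my mental model of the pattern (a condensed sketch from reading `torch.utils.checkpoint`, not the actual source, which also handles multiple inputs/outputs, RNG state, etc.):

```python
import torch

class CheckpointSketch(torch.autograd.Function):
    """Condensed sketch of the pattern in torch.utils.checkpoint."""

    @staticmethod
    def forward(ctx, run_function, x):
        ctx.run_function = run_function
        ctx.save_for_backward(x)
        with torch.no_grad():
            # Forward without building a graph: the intermediate
            # activations inside run_function are freed immediately.
            return run_function(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # detach() shares x's storage, so no data is duplicated; it only
        # creates a new tensor object cut off from x's history.
        detached = x.detach()
        detached.requires_grad_(True)
        with torch.enable_grad():
            # Recompute, this time recording a graph, but one that stops
            # at `detached` rather than extending back into the original
            # graph behind x (which backward() would otherwise walk into,
            # accumulating those gradients a second time).
            output = ctx.run_function(detached)
        torch.autograd.backward(output, grad_output)
        return None, detached.grad
```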
I was trying to implement multiple recomputations by trivially building on the existing checkpoint function, for example by creating a module that performs `1 -> drop(2) -> 3`, where the module `2` itself performs `2a -> drop(2b)`. `drop` is like a higher-order model whose forward simply performs `checkpoint(child_model, x)`. Thus, running `2` should drop `2a` and `2b` in the forward pass; and in the backward pass, recompute `2a`, recompute `2b` and drop it, then recompute `2b` a second time, this time actually saving it. Obviously, that's not so smart, but it's a simple example (see the sketch of the wrapper just below).
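For concreteness, here is roughly the construction I mean (a minimal sketch: the `Drop` class and the `nn.Linear` placeholder blocks are mine, not anything from `torch.utils.checkpoint`):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Drop(nn.Module):
    """Higher-order module whose forward just checkpoints its child."""
    def __init__(self, child):
        super().__init__()
        self.child = child

    def forward(self, x):
        # Checkpoint the child: its internal activations are not stored
        # in the forward pass and are recomputed in the backward pass.
        return checkpoint(self.child, x)

# Placeholder blocks standing in for the sub-networks 1, 2a, 2b and 3.
block1  = nn.Linear(16, 16)
block2a = nn.Linear(16, 16)
block2b = nn.Linear(16, 16)
block3  = nn.Linear(16, 16)

block2 = nn.Sequential(block2a, Drop(block2b))        # 2 = 2a -> drop(2b)
model  = nn.Sequential(block1, Drop(block2), block3)  # 1 -> drop(2) -> 3

x = torch.randn(4, 16, requires_grad=True)
y = model(x)  # forward: intend 2a's and 2b's outputs to be dropped
# y.sum().backward() is exactly where I expect the problems described
# in the chain of events below.
```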
I believe the following chain of events will occur in practice:

1. `1`'s forward is performed, whose output is propagated to `drop(2)`'s forward function.
2. `drop(2)`'s forward invokes `CheckpointFunction`'s forward, which saves (checkpoints) the input, and performs `2`'s forward without tracking gradients. The intent is that the outputs of both `2a` and `2b` will be dropped at this stage.
3. `2a` performs its forward and the output is passed to `drop(2b)`'s forward.
   - a. If I set up this `drop` wrapper not to detach its input before passing it to the child model, the `CheckpointFunction` forward will throw an error about none of the inputs having `requires_grad=True`.
   - b. If I do detach the input, I suspect this will not work either: it will save the input to `2b`, thus checkpointing it, which, as mentioned in step 2, I do not want to happen! On the other hand, maybe it will be freed because of how everything is set up and autograd's reference counting? Maybe, as `drop(2b)` saves its input, and `2a` was run with no grad, you get an unreachable cycle between them that will get garbage collected? I really lack the autograd understanding to know (see the <2 days of PyTorch). I also do not know how to profile this to observe whether it drops the tensor or not (a rough idea of how I might measure it is sketched below).
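For the record, the rough measurement I have in mind is something like the following (`peak_mem_mb` is a made-up helper, and it assumes a CUDA device):

```python
import torch

def peak_mem_mb(model, x):
    # Made-up helper: one forward + backward on the GPU, reporting the
    # peak memory held by the caching allocator, in MiB.
    model = model.cuda()
    x = x.cuda().requires_grad_(True)
    torch.cuda.reset_peak_memory_stats()
    model(x).sum().backward()
    return torch.cuda.max_memory_allocated() / 1024 ** 2

# Comparing the same network built with and without the drop wrappers
# should reveal whether the inner activations are actually being freed.
```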
If my analysis is right, I will have to implement this `drop` operator from scratch such that it avoids this behaviour, correct?

Thank you for making it this far, and sorry if the above explanations are not great; it would be easier with diagrams. Any help on this would be greatly appreciated.
@shirazb have you found the answer?
Hey @knsong,
I updated my answer on the PyTorch forums here:
@shirazb Did you ever consider the swap-in/swap-out method, e.g. "Comprehensive techniques of multi-GPU memory optimization for deep learning acceleration", using PyTorch? Recently I have also been working on improving training throughput by reducing the memory footprint of training a model. Do you have any suggestions about that?
@knsong
To be perfectly frank, this was uni work that I haven't looked at for a number of months. I have looked at swapping, but not for a while, and not for PyTorch specifically. I am not familiar with that paper in particular.
I do recall there may have been some work [1] on implementing something similar in PyTorch to TensorFlow Large Model Support (TFLMS) [2], which was a swapping solution for TensorFlow, but I never really looked into it as I was focusing on checkpointing.
--
With regards to your more general question about how to reduce the memory footprint, the single-recomputation `O(sqrt(N))` policy already implemented in `torch.utils.checkpoint` [3] should give significant savings at only a modest cost. See, for example, Chen's paper for experimental results of the same policy implemented in MXNet [4].
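For instance, `torch.utils.checkpoint.checkpoint_sequential` already packages this for `nn.Sequential` models; choosing roughly sqrt(N) segments gives the `O(sqrt(N))` memory behaviour. A minimal sketch, with arbitrary placeholder layers and sizes:

```python
import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# A deep sequential net; each block is an arbitrary placeholder layer.
N = 100
model = nn.Sequential(*[nn.Linear(32, 32) for _ in range(N)])

# Splitting into ~sqrt(N) checkpointed segments keeps only the segment
# boundary activations live during the forward pass; each segment's
# internals are recomputed once during the backward pass.
segments = int(math.sqrt(N))
x = torch.randn(8, 32, requires_grad=True)
out = checkpoint_sequential(model, segments, x)
out.sum().backward()
```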
For a good, centralised resource that links to many of the other good resources on Deep Learning memory optimisations, I would suggest my thesis [5]. Note that it was rather hastily written, though, so it may contain some errors...
I will list some sections below that I think may be of interest to you. Section 5.4 may be of particular interest as it discusses both checkpointing and swapping together (more detail below).
Section 2.5 introduces checkpointing, multiple recomputations, and dynamic programming checkpointing that solves for the optimal policy that minimises compute given a memory constraint. It also briefly discusses the checkpointing approaches in the Automatic Differentiation (AD) community, who have been doing it since the early 90's at least! If you're not familiar with AD, Baydin et al. [6] have an excellent paper that explains what it is and how it is used in ML. Section 2.2 of my thesis gives a much abridged version of this, but I really recommend their paper.
Chapter 3 details the main contributions of my thesis: generalising the aforementioned dynamic programming checkpointing technique to take into account the precise costs of each layer, as the original paper assumes uniform costs. There is a lot of rigmarole in this, but Section 3.2 may be of interest to you, as it discusses how to implement multiple recomputations in PyTorch. The problem is that it may not make sense without reading many of the previous sections of the thesis. Also, it requires a little knowledge of Autograd internals. Section 2.3 gives some introduction to this, which was based on a great video I found on YouTube by Elliot Waite [7].
Section 4.1 explores the compute-memory tradeoff checkpointing provides (with multiple recomputations allowed), and shows that single-recomputation `O(sqrt(N))` is pretty good, especially for uniform-cost networks, and scales well to very deep networks.
Section 5.2 discusses the limitation of checkpointing to sequences and the difficulties of trying to generalise it to more arbitrary graphs. It mentions a couple of recent papers that have made headway in this.
Section 5.4 discusses the role of checkpointing in light of other memory optimisations; that is, CPU-GPU swapping and the clever choice of convolutional operator implementation. It references many papers that will explain them in more detail, but does not give a full survey. Some could have PyTorch implementations (I've forgotten), but most are Proof-of-Concept systems. It also discusses more literature from the AD community on combining swapping and checkpointing that looks very promising. Essentially, they solve a more general checkpointing problem where there is a hierarchy of memory pools, so swapping to CPU memory is simply treated as checkpointing to the second layer of memory.
--
Hope that gave you some ideas, not sure what you were after. Certainly, in terms of swapping techniques already implemented in PyTorch specifically, I am not myself aware of much.
--
Bibliography
[1] TFLMS in PyTorch https://discuss.pytorch.org/t/thoughts-on-use-of-cpu-ram-as-a-swap-for-gpu/13193
[2] TFLMS paper https://arxiv.org/pdf/1807.02037.pdf
[3] `torch.utils.checkpoint` docs https://pytorch.org/docs/stable/checkpoint.html
[4] `O(sqrt(N))` checkpointing paper https://arxiv.org/pdf/1604.06174.pdf
[5] My thesis - 'Precise Optimal Checkpointing for Memory-Efficient Deep Learning' https://github.com/shirazb/final-year-thesis/tree/master/out
[6] AD in ML survey paper https://arxiv.org/pdf/1502.05767.pdf
[7] Autograd Internals YouTube Video https://www.youtube.com/watch?v=MswxJw-8PvE