pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Implementing multiple recomputations on top of `torch.utils.checkpoint` #23822

Closed shirazb closed 5 years ago

shirazb commented 5 years ago

Hi,

My master's thesis is on making neural nets use less memory. One technique I am looking at is memory checkpointing. I can solve for the optimal policy (including multiple recomputations), given the memory budget and per-operator compute/memory costs.

I am attempting to implement memory checkpointing as done in torch.utils.checkpoint, except allowing for multiple recomputations. However, there are a couple of things in the implementation that I'm not quite sure I understand. Apologies if anything is obvious; I have been using PyTorch for <2 days.

  1. Why detach the inputs in the backward before recomputing the forward? Would you otherwise get something like the gradient of the forward being accumulated twice into the input's gradient? Also, does this detaching not duplicate the input, resulting in a higher-than-necessary memory cost?

  2. I was trying to implement multiple recomputations by trivially building on the existing checkpoint function, for example by creating a module that performs 1 -> drop(2) -> 3, where module 2 itself performs 2a -> drop(2b). drop is like a higher-order model whose forward simply performs checkpoint(child_model, x). Thus, running 2 should drop 2a and 2b in the forward pass; and in the backward pass, recompute 2a, recompute 2b and drop it, then recompute 2b a second time, this time actually saving it. Obviously, that's not so smart, but it's a simple example (a rough sketch is given after this list).

    I believe the following chain of events will occur in practice:

    1. 1's forward is performed, whose output is propagated to drop(2)'s forward function.

    2. drop(2)'s forward invokes CheckpointFunction's forward, which saves (checkpoints) the input, and performs 2's forward without tracking gradient. The intent is that both the outputs of 2a and 2b will be dropped at this stage.

    3. 2a performs its forward and the output is passed to drop(2b)'s forward.

    4. a. If I set up this 'drop' wrapper to not detach its input before passing it to the child model, the CheckpointFunction forward will throw an error about none of the inputs having requires_grad=True.

      b. If I do detach the input, I suspect this will not work either: it will save the input to 2b, thus checkpointing it, which, as mentioned in step 2, I do not want to happen!

      On the other hand, maybe it will be freed because of how everything is set up and autograd's reference counting?

      Maybe, as drop(2b) saves its input, and 2a was run with no grad, you get an unreachable cycle between them that will get garbage collected?

      I really lack the autograd understanding to know (see the <2 days of PyTorch). I also do not know how to profile this to observe if it drops the tensor or not.

    If my analysis is right, I will have to implement this drop operator from scratch such that it avoids this behaviour, correct?
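
To make the setup concrete, here is roughly what I have in mind (a minimal sketch only; the Drop wrapper and the Linear placeholder blocks are stand-ins for the real modules, and this is the no-detach variant from step 4a):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Drop(nn.Module):
    """Hypothetical wrapper: checkpoints its child so the child's internal
    activations are dropped in the forward pass and recomputed in backward."""
    def __init__(self, child):
        super().__init__()
        self.child = child

    def forward(self, x):
        # No-detach variant (step 4a). The detach variant (step 4b) would be
        # checkpoint(self.child, x.detach()) instead.
        return checkpoint(self.child, x)

# Placeholder blocks standing in for 1, 2a, 2b and 3.
block_1, block_2a, block_2b, block_3 = (nn.Linear(64, 64) for _ in range(4))

# Module 2 is 2a -> drop(2b); the whole model is 1 -> drop(2) -> 3.
block_2 = nn.Sequential(block_2a, Drop(block_2b))
model = nn.Sequential(block_1, Drop(block_2), block_3)

x = torch.randn(8, 64, requires_grad=True)
model(x).sum().backward()

# During the outer checkpoint's no-grad forward, the inner checkpoint sees an
# input with requires_grad=False, which is where the complaint in step 4a
# comes from; it is also unclear to me whether the input saved there (the
# output of 2a) is actually freed, which is the question in step 4b.
```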

Thank you for making it this far, and sorry if the above explanations are not great; it would be easier with diagrams. Any help on this would be greatly appreciated.

pbelevich commented 5 years ago

We appreciate your interest in PyTorch. Please use https://discuss.pytorch.org/ for such questions.

knsong commented 4 years ago


@shirazb have you found the answer?

shirazb commented 4 years ago

Hey @knsong,

I updated my answer on the PyTorch forums here:

https://discuss.pytorch.org/t/implementing-multiple-recomputations-on-top-of-torch-utils-checkpoint/52657

knsong commented 4 years ago

@shirazb Did you ever consider the swap-in/swap-out method, e.g. Comprehensive techniques of multi-GPU memory optimization for deep learning acceleration, using PyTorch? Recently I have also been working on improving training throughput by reducing the memory footprint of model training. Do you have any suggestions about that?

shirazb commented 4 years ago

@knsong

To be perfectly frank, this was uni work that I haven't looked at for a number of months. I have looked at swapping, but not for a while, and not for PyTorch specifically. I am not familiar with that paper in particular.

I do recall there may have been some work [1] on implementing something similar in PyTorch to TensorFlow Large Model Support (TFLMS) [2], which was a swapping solution for TensorFlow, but I never really looked into it as I was focusing on checkpointing.

--

With regard to your more general question about how to reduce the memory footprint, the single-recomputation O(sqrt(N)) policy already implemented in torch.utils.checkpoint [3] should give significant savings at only a modest compute cost. See, for example, Chen's paper for experimental results of the same policy implemented in MXNet [4].
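
For a quick illustration (my own toy sketch, not taken from [4]): splitting an N-block nn.Sequential into roughly sqrt(N) segments with the built-in checkpoint_sequential gives this kind of policy.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# 16 toy blocks; ~sqrt(16) = 4 segments, so only segment-boundary activations
# are kept in the forward pass and segments are recomputed during backward.
model = nn.Sequential(*[nn.Sequential(nn.Linear(128, 128), nn.ReLU())
                        for _ in range(16)])

# The input must require grad, otherwise the checkpointed segments see no
# inputs with requires_grad=True and their backward is never run.
x = torch.randn(32, 128, requires_grad=True)

out = checkpoint_sequential(model, 4, x)
out.sum().backward()
```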

For a good, centralised resource that links to many of the other good resources on deep learning memory optimisations, I would suggest my thesis [5]. Note that it was rather hastily written, though, so it may contain some errors...

I will list some sections below that I think may be of interest to you. Section 5.4 may be of particular interest as it discusses both checkpointing and swapping together (more detail below).

Section 2.5 introduces checkpointing, multiple recomputations, and dynamic programming checkpointing that solves for the optimal policy that minimises compute given a memory constraint. It also briefly discusses the checkpointing approaches in the Automatic Differentiation (AD) community, who have been doing it since the early 90's at least! If you're not familiar with AD, Baydin et al. [6] have an excellent paper that explains what it is and how it is used in ML. Section 2.2 of my thesis gives a much abridged version of this, but I really recommend their paper.

Chapter 3 details the main contributions of my thesis: generalising the aforementioned dynamic programming checkpointing technique to take into account the precise costs of each layer, as the original paper assumes uniform costs. There is a lot of rigmarole in this, but Section 3.2 may be of interest to you, as it discusses how to implement multiple recomputations in PyTorch. The caveat is that it may not make sense without reading many of the previous sections of the thesis. It also requires a little knowledge of autograd internals; Section 2.3 gives some introduction to this, based on a great video I found on YouTube by Elliot Waite [7].

Section 4.1 explores the compute-memory tradeoff checkpointing provides (with multiple recomputations allowed); and shows that single-recomputation O(sqrt(N)) is pretty good, especially for uniform cost networks, and scales well to very deep networks.

Section 5.2 discusses the limitation of checkpointing to sequences and the difficulties of trying to generalise it to more arbitrary graphs. It mentions a couple recent papers that have made headway in this.

Section 5.4 discusses the role of checkpointing in light of other memory optimisations; that is, CPU-GPU swapping and the clever choice of convolutional operator implementation. It references many papers that will explain them in more detail, but does not give a full survey. Some could have PyTorch implementations (I've forgotten), but most are Proof-of-Concept systems. It also discusses more literature from the AD community on combining swapping and checkpointing that looks very promising. Essentially, they solve a more general checkpointing problem where there is a hierarchy of memory pools, so swapping to CPU memory is simply treated as checkpointing to the second layer of memory.
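
As a toy illustration of the swapping idea only (my own sketch, not how any of the systems referenced in 5.4 actually do it, and without the copy/compute overlap a real system needs): an activation can be offloaded to CPU in the forward pass and swapped back in only when backward needs it.

```python
import torch

class OffloadedLinear(torch.autograd.Function):
    """Toy linear op (no bias) that stores its input activation in CPU memory
    instead of keeping it on the GPU until backward."""

    @staticmethod
    def forward(ctx, x, weight):
        # Swap the activation out to host memory; the weight stays on-device.
        ctx.save_for_backward(x.to("cpu"), weight)
        ctx.x_device = x.device
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        x_cpu, weight = ctx.saved_tensors
        x = x_cpu.to(ctx.x_device)  # swap the activation back in
        return grad_out @ weight, grad_out.t() @ x

# Usage sketch:
x = torch.randn(8, 64, requires_grad=True)
w = torch.randn(32, 64, requires_grad=True)
OffloadedLinear.apply(x, w).sum().backward()
```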

--

Hope that gave you some ideas, not sure what you were after. Certainly, in terms of swapping techniques already implemented in PyTorch specifically, I am not myself aware of much.

--

Bibliography

[1] TFLMS in PyTorch (forum discussion): https://discuss.pytorch.org/t/thoughts-on-use-of-cpu-ram-as-a-swap-for-gpu/13193
[2] TFLMS paper: https://arxiv.org/pdf/1807.02037.pdf
[3] torch.utils.checkpoint: https://pytorch.org/docs/stable/checkpoint.html
[4] O(sqrt(N)) checkpointing paper: https://arxiv.org/pdf/1604.06174.pdf
[5] My thesis, 'Precise Optimal Checkpointing for Memory-Efficient Deep Learning': https://github.com/shirazb/final-year-thesis/tree/master/out
[6] AD in ML survey paper: https://arxiv.org/pdf/1502.05767.pdf
[7] Autograd internals YouTube video: https://www.youtube.com/watch?v=MswxJw-8PvE