cybertronai / gradient-checkpointing

Make huge neural nets fit in memory
MIT License

Some high-level description of the 'memory' policy would be very helpful #46

Open wangli1426 opened 5 years ago

wangli1426 commented 5 years ago

I am trying to understand the heuristic algorithm used in the 'memory' policy, but I could not follow the whole logic, especially the if statement linked below.

https://github.com/cybertronai/gradient-checkpointing/blob/43444e0523495c9f305f2c32d81eeea2328a1b18/memory_saving_gradients.py#L143

Any explanation or guidance would be highly appreciated.

Thanks.

yaroslavvb commented 5 years ago

The high-level goal is to find nodes that separate the graph into two parts.

From https://medium.com/tensorflow/fitting-larger-networks-into-memory-583e3c758ff9

[figure from the article: a computation graph split into two parts by a bottleneck node]

A bottleneck is a graph separator consisting of a single node, and that line is a heuristic that tries to find such nodes.

The first part of the heuristic walks forward and backward from the candidate point, collecting the ops it reaches, and checks whether the two sets overlap. No overlap suggests that x3 is a separator, as in the diagram below.

set(b_inp).intersection(f_inp)

[screenshot: diagram in which x3 separates the graph, with no overlap between the forward- and backward-reachable sets]

The second part, len(b_inp)+len(f_inp) >= len(ts_all), was added to handle an edge case in which the backward and forward traversals didn't capture all the tensors; in that case we abort and don't count the node as a separator/bottleneck. I don't remember which case triggered it, but it was necessary to make the tests pass, so one of the tests exercises this condition.
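
Putting the two parts together, here is a minimal sketch of the test (a paraphrase, not the repo's exact code; the function name is hypothetical, and f_inp / b_inp / ts_all are assumed to hold the forward-reachable tensors, the backward-reachable tensors, and all tensors under consideration):

```python
def is_bottleneck(f_inp, b_inp, ts_all):
    # Part 1: any overlap between the backward- and forward-reachable
    # sets means the candidate does not cleanly split the graph.
    if set(b_inp).intersection(f_inp):
        return False
    # Part 2: if the two walks together missed some tensors, the
    # traversal was incomplete, so abort rather than mislabel the node.
    if len(b_inp) + len(f_inp) < len(ts_all):
        return False
    return True
```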

wangli1426 commented 5 years ago

Thanks for the explanation. It is very clear and helpful.

May I ask what's the purpose of the following code block? Thanks in advance.

https://github.com/cybertronai/gradient-checkpointing/blob/43444e0523495c9f305f2c32d81eeea2328a1b18/memory_saving_gradients.py#L123

yaroslavvb commented 5 years ago

It saves the new ops added to the graph inside this block into bwd_ops.
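
For illustration, here is a minimal sketch of one way such a capture context manager can work, using a snapshot-and-diff over the graph's op list (TF1-style graph mode and the name capture_new_ops are assumptions, not the repo's actual implementation):

```python
import contextlib
import tensorflow as tf  # TF1-style graph mode assumed

@contextlib.contextmanager
def capture_new_ops():
    # Hypothetical helper: snapshot the graph's op count, run the
    # block, then collect everything appended after the snapshot.
    graph = tf.get_default_graph()
    n_before = len(graph.get_operations())
    captured = []
    yield captured
    # Ops are appended in creation order, so the tail of the list is
    # exactly what the block created.
    captured.extend(graph.get_operations()[n_before:])

# Usage: collect the ops created while building the gradient subgraph.
x = tf.placeholder(tf.float32, [None, 4])
w = tf.Variable(tf.ones([4, 1]))
y = tf.reduce_sum(tf.matmul(x, w))

with capture_new_ops() as bwd_ops:
    grads = tf.gradients([y], [w])

print(len(bwd_ops))  # number of ops added by tf.gradients
```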

wangli1426 commented 5 years ago

Got it. Thanks.