wangli1426 opened this issue 5 years ago
The high-level goal is to find nodes that separate the graph into two parts.
From https://medium.com/tensorflow/fitting-larger-networks-into-memory-583e3c758ff9
A bottleneck is a graph separator that consists of a single node, and the if statement you linked is a heuristic that tries to find such nodes.
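To make the definition concrete, here is a brute-force sketch on a hypothetical toy graph: a node counts as a single-node separator if deleting it leaves no path from the input to the output. The graph format and the names GRAPH, reaches and single_node_separators are illustrative assumptions, not part of memory_saving_gradients.py, and this is a reference check rather than the heuristic the code actually uses.

```python
from collections import deque

# Hypothetical toy graph: op name -> input tensors.
# Each op produces exactly one tensor with the same name; "x" is a placeholder input.
GRAPH = {
    "a": ["x"],
    "b": ["a"],
    "c": ["b", "a"],  # skip connection: a bypasses b
    "d": ["c"],
    "y": ["d"],
}

def successors(graph):
    """Map each tensor to the ops (and thus tensors) that consume it."""
    succ = {}
    for op, inputs in graph.items():
        for inp in inputs:
            succ.setdefault(inp, []).append(op)
    return succ

def reaches(graph, src, dst, removed):
    """Breadth-first search from src to dst that never enters `removed`."""
    succ = successors(graph)
    seen, queue = {src}, deque([src])
    while queue:
        node = queue.popleft()
        if node == dst:
            return True
        for nxt in succ.get(node, []):
            if nxt != removed and nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return False

def single_node_separators(graph, src, dst):
    """Nodes whose removal cuts every src -> dst path (brute force)."""
    candidates = [n for n in graph if n not in (src, dst)]
    return sorted(n for n in candidates if not reaches(graph, src, dst, n))

print(single_node_separators(GRAPH, "x", "y"))  # ['a', 'c', 'd']
```

Here b is not a separator because the skip connection a -> c lets data flow around it.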
The first part of the heuristic looks at the ops reached when walking backward and forward from the candidate node and checks whether the inputs collected on the two sides overlap:
set(b_inp).intersection(f_inp)
No overlap suggests that the candidate is a separator, as x3 is in the diagram below.
The second part, len(b_inp)+len(f_inp) >= len(ts_all), was added to handle an edge case where the backward and forward traversals didn't capture all the tensors; in that case we abort and don't count the node as a separator/bottleneck. I don't remember which case triggered it, but it was needed to make the tests pass, so one of the tests exercises this condition.
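The two checks can be sketched on the same kind of hypothetical toy graph (op name -> input tensors, one output tensor per op). This is a minimal illustration under stated assumptions: the walks below are plain transitive traversals, and backward_walk_ops, forward_walk_ops and find_bottlenecks are illustrative stand-ins, not the graph-walking calls the real code makes, which may stop at different places. Only b_inp, f_inp and ts_all mirror names from the if statement.

```python
# Hypothetical toy graph: op name -> input tensors; each op outputs one tensor of the same name.
GRAPH = {
    "a": ["x"],
    "b": ["a"],
    "c": ["b", "a"],  # skip connection: a bypasses b
    "d": ["c"],
    "y": ["d"],
}

def consumers(graph):
    """Map each tensor to the set of ops that take it as input."""
    out = {}
    for op, inputs in graph.items():
        for inp in inputs:
            out.setdefault(inp, set()).add(op)
    return out

def backward_walk_ops(graph, t):
    """Ops reachable backward from the op producing t, including that op."""
    seen, stack = set(), [t]
    while stack:
        op = stack.pop()
        if op in seen or op not in graph:  # placeholders like "x" have no producing op
            continue
        seen.add(op)
        stack.extend(graph[op])
    return seen

def forward_walk_ops(graph, t):
    """Ops reachable forward from t, excluding the op that produces t."""
    cons = consumers(graph)
    seen, stack = set(), list(cons.get(t, ()))
    while stack:
        op = stack.pop()
        if op in seen:
            continue
        seen.add(op)
        stack.extend(cons.get(op, ()))
    return seen

def find_bottlenecks(graph, ts_all):
    bottlenecks = []
    for t in sorted(ts_all):
        b = backward_walk_ops(graph, t)
        f = forward_walk_ops(graph, t)
        # candidate tensors consumed by ops on each side of t
        b_inp = {i for op in b for i in graph[op]} & ts_all
        f_inp = {i for op in f for i in graph[op]} & ts_all
        # part 1: no candidate feeds both sides (no shortcut around t)
        # part 2: the two sides together account for every candidate
        if not (b_inp & f_inp) and len(b_inp) + len(f_inp) >= len(ts_all):
            bottlenecks.append(t)
    return bottlenecks

ts_all = {"a", "b", "c", "d"}  # all tensors except the input x and the output y
print(find_bottlenecks(GRAPH, ts_all))  # ['a', 'c', 'd']
```

On this graph b is rejected by the first check, because a is consumed both by b (behind it) and by c (after it). The second check never fires here: once the first check passes, b_inp and f_inp are disjoint subsets of ts_all, so the sum condition only asks that together they cover every candidate.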
Thanks for the explanation. It is very clear and helpful.
May I ask what's the purpose of the following code block? Thanks in advance.
It collects the new ops that are added to the graph inside this block into bwd_ops.
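The idea of capturing the ops created inside a block can be sketched with a context manager. This is a hypothetical illustration: ToyGraph and capture_ops are made-up names over a plain list of op names, not the helper the repository uses, which the thread does not show.

```python
import contextlib

class ToyGraph:
    """Hypothetical stand-in for a computation graph that records created ops in order."""
    def __init__(self):
        self.ops = []

    def add_op(self, name):
        self.ops.append(name)
        return name

@contextlib.contextmanager
def capture_ops(graph):
    """Yield a list that, once the block exits, holds the ops added inside it."""
    start = len(graph.ops)
    captured = []
    try:
        yield captured
    finally:
        # everything appended after `start` was created inside the with-block
        captured.extend(graph.ops[start:])

g = ToyGraph()
g.add_op("forward/matmul")
with capture_ops(g) as bwd_ops:
    g.add_op("gradients/matmul_grad")
    g.add_op("gradients/add_grad")
print(bwd_ops)  # ['gradients/matmul_grad', 'gradients/add_grad']
```

Ops created before the block (here "forward/matmul") are not included, so the list ends up holding only the ops built while computing the gradients.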
Got it. Thanks.
I am trying to understand the heuristic algorithm used in the memory policy. However, I could not fully understand the whole logic, especially the following if statement: https://github.com/cybertronai/gradient-checkpointing/blob/43444e0523495c9f305f2c32d81eeea2328a1b18/memory_saving_gradients.py#L143
Any explanation or guidance would be highly appreciated.
Thanks.