IvanYashchuk opened 1 month ago
@kiya00, @riccardofelluga: could you please work together on resolving this issue? Let me know if you need any support. Thanks!
Hi @riccardofelluga, I see you self-assigned this. Do you think it could be split up so I can work on a piece? Is there anything I can help with?
Yes, we should split it up! I've just assigned it to myself so I can keep track of it.
Hi, I noticed the current last_trace of the `memory_peak_efficient_func` is:

```
10 input tensors                      # t0-t9
10 outputs = nvfusion0(t0, ..., t9)   # 10 ReLUs fused together
10 matmuls consume the previous 10 outputs one by one
return [ret0, ..., ret9]
```
If we want memory efficiency, in this case it means something like:

```
10 input tensors                # t0-t9
out0 = nvfusion0(t0)            # only 1 ReLU is fused
ret0 = matmul(out0); del out0
out1 = nvfusion1(t1)            # only 1 ReLU is fused
ret1 = matmul(out1); del out1
...                             # repeat for all 10 inputs
return [ret0, ..., ret9]
```
But that will cause the fusion regions to be split up, am I understanding it right?
Yes, that's the basic idea. The fusions should be split into individual ones if the computations inside are independent of each other.
I think the easiest way could be to implement a "vertical" merge, otherwise we probably need to investigate if there are other merge algorithms that can balance the memory efficiency and the merge range.
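To make the "vertical" merge idea concrete, here is a stdlib-only toy sketch (an illustration, not Thunder's actual `dataflow_merge`): it fuses a producer with its consumer only when the producer has a single consumer and the consumer a single producer, so independent chains stay in separate groups.

```python
from collections import defaultdict

def vertical_merge(edges):
    """Group ops into vertical chains: fuse a -> b only when a has a
    single consumer (b) and b has a single producer (a)."""
    out_deg, in_deg = defaultdict(int), defaultdict(int)
    for a, b in edges:
        out_deg[a] += 1
        in_deg[b] += 1
    # Union-find over the fusible edges.
    parent = {}
    def find(x):
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x
    for a, b in edges:
        if out_deg[a] == 1 and in_deg[b] == 1:
            parent[find(a)] = find(b)
    groups = defaultdict(set)
    for n in {n for e in edges for n in e}:
        groups[find(n)].add(n)
    return sorted(map(sorted, groups.values()))

# Two independent relu -> matmul chains stay as two separate fusions,
# unlike a horizontal merge that would group both relus together.
edges = [("relu0", "matmul0"), ("relu1", "matmul1")]
print(vertical_merge(edges))  # [['matmul0', 'relu0'], ['matmul1', 'relu1']]
```

A diamond-shaped graph (one producer feeding two branches that rejoin) would stay entirely unfused under this rule, which is exactly the conservative behavior that keeps independent computations separable.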
Updating with some experimental results:
IIUC, `dataflow_merge` does the "vertical" fusion in place on the `Graph`, and `horizontal_merge` does the topo sort (BFS with indegree) and horizontally fuses the ops along the way. So I did an experiment that skips `horizontal_merge` and adds a DFS topo sort after `dataflow_merge` (branch). But when I tried it on longchat-7b-16k, although the sequence of `cudnn_sdpa_fwd` calls got separated, the peak memory is even larger; this is the last backward trace: https://gist.github.com/kiya00/4d906a424e0adc7640d3011a0375b5cb
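For background on why the traversal choice matters, here is a minimal stdlib sketch (a toy graph, not Thunder's `Graph` API) contrasting a BFS-with-indegree order, which interleaves independent chains, with a DFS-based order that keeps each chain contiguous so its intermediate can be freed sooner.

```python
from collections import deque

def bfs_topo(succ, nodes):
    """Kahn's algorithm: pop zero-indegree ops in FIFO order,
    which interleaves independent chains."""
    indeg = {n: 0 for n in nodes}
    for a in succ:
        for b in succ[a]:
            indeg[b] += 1
    q = deque(n for n in nodes if indeg[n] == 0)
    order = []
    while q:
        n = q.popleft()
        order.append(n)
        for b in succ.get(n, []):
            indeg[b] -= 1
            if indeg[b] == 0:
                q.append(b)
    return order

def dfs_topo(succ, nodes):
    """Reversed DFS post-order: each chain comes out contiguous."""
    seen, order = set(), []
    def visit(n):
        if n in seen:
            return
        seen.add(n)
        for b in succ.get(n, []):
            visit(b)
        order.append(n)
    for n in nodes:
        visit(n)
    return order[::-1]

# Two independent relu -> matmul chains.
nodes = ["relu0", "relu1", "matmul0", "matmul1"]
succ = {"relu0": ["matmul0"], "relu1": ["matmul1"]}
print(bfs_topo(succ, nodes))  # ['relu0', 'relu1', 'matmul0', 'matmul1']
print(dfs_topo(succ, nodes))  # ['relu1', 'matmul1', 'relu0', 'matmul0']
```

The BFS order runs both relus before any matmul, so both intermediates are alive at once; the DFS order finishes one relu/matmul pair before starting the next.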
There are quite a few operations between the time when `t1140` and `t1141` from the recomputed forward piece are produced and when they are actually used. Is there anything in terms of dataflow preventing these tensors from being placed closer to their backward consumers?
At the point where they are actually used, `t1140` and `t1141` are held waiting until `t1165` is ready, and there is a long path from `t583` until `t1165` is calculated. These three tensors need to wait for each other somehow; maybe I can try letting the smaller tensors wait for the larger tensor to be calculated?
I tried a bottom-up DFS that prefers the smaller output tensor first (trace): `t1141` and `t1140` are used here, and `t1165` is held waiting for `t1140` and `t1141` to be calculated, but the memory doesn't change much (the previous method uses 12.63 GB; now it decreases slightly to 12.45 GB). Here are some results; maybe breaking the fusion region causes more memory usage:
| model | dfs experiment | w/ horizontal merge (original) |
| --- | --- | --- |
| vicuna-33b-v1.3 | 12.29 GB | 12.29 GB |
| falcon-40b | 19.98 GB | 19.98 GB |
| Gemma-7b | 21.29 GB | 21.31 GB |
| longchat-7b-16k | 12.45 GB | 12.36 GB |
| pythia-12b | 9.65 GB | 9.65 GB |
TL;DR: Thunder's fusion pass needs to change to consider the memory usage of the operations and the intermediate tensors. It should avoid fusing operations that increase peak memory usage. Use `memory_peak_efficient_func` as the target function for optimization.

Let's take a look at the following code snippet:
The code snippet above is a simple example of a memory-efficient function that processes a list of tensors sequentially. It avoids unnecessary memory allocations through the order of its operations and by deleting intermediate tensors once they are no longer needed.
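Based on the trace structure quoted earlier (one ReLU feeding one matmul per input, with the intermediate deleted immediately), a hypothetical sketch of what such a function could look like; the body, shapes, and the weight `w` are assumptions here, not the issue's actual code:

```python
import torch

def memory_peak_efficient_func(ts, w):
    # Hypothetical sketch: one ReLU feeds one matmul per input, and the
    # intermediate is deleted as soon as its matmul has consumed it, so
    # at most one ReLU output is alive at any time.
    rets = []
    for t in ts:
        out = torch.relu(t)
        rets.append(out @ w)
        del out
    return rets

ts = [torch.randn(256, 256) for _ in range(10)]
w = torch.randn(256, 256)
rets = memory_peak_efficient_func(ts, w)
print(len(rets), rets[0].shape)  # 10 torch.Size([256, 256])
```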
The code can be further "optimized" by fusing some operations into a single region. Before fusion, let's take a look at the differently structured code snippet below:
The code snippet above is a modified version of the previous code snippet computing the same result. It precomputes intermediate tensors and stores them in lists to be used later. This version of the code is less memory-efficient because it stores intermediate tensors in memory, which increases peak memory usage.
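The difference between the two orderings can be shown with a back-of-the-envelope model of allocations and frees (pure Python; the 1 MiB sizes and the count of 10 intermediates are made-up numbers chosen only to illustrate the pattern):

```python
def simulate_peak(events):
    """events: +size for an allocation, -size for a free.
    Returns the peak of the running live-bytes total."""
    live = peak = 0
    for delta in events:
        live += delta
        peak = max(peak, live)
    return peak

N, SIZE = 10, 1 << 20  # 10 intermediates of 1 MiB each (illustrative)

# Sequential schedule: relu_i, matmul_i, free relu_i's output, repeat.
sequential = []
for _ in range(N):
    sequential += [+SIZE, -SIZE]  # intermediate allocated, then freed

# Precompute schedule: all intermediates allocated first, freed at the end.
precompute = [+SIZE] * N + [-SIZE] * N

print(simulate_peak(sequential) // SIZE)  # 1 intermediate alive at peak
print(simulate_peak(precompute) // SIZE)  # 10 intermediates alive at peak
```

The model ignores the inputs and final outputs, which live for the whole function under either schedule; only the intermediates distinguish the two.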
Let's apply Thunder on the inefficient code snippet to see if it can optimize the memory usage.
Thunder was able to optimize the memory usage of the inefficient code snippet by fusing the operations into a single region. The peak memory usage decreased from 2.75 MiB to 2.25 MiB, but it is still higher than the memory-efficient version of the code.
What would happen if we apply Thunder on the memory-efficient code snippet?
The same memory usage as the inefficient code snippet! Thunder was not able to optimize the memory usage of the memory-efficient code snippet because it applies topological sorting to the computation graph, preferring to form fusion groups that are as large as possible. In this case, the memory-efficient code snippet already has the optimal order of operations, and Thunder breaks it by fusing operations into a single region.
Here's the dataflow graph of the execution trace of the memory-efficient code snippet:
```py
from thunder.core.transform_common import unwrap_return_value
from thunder.examine import make_trace_dot

t = unwrap_return_value(thunder.last_traces(jit_memory_peak_efficient_func)[-3])
dot = make_trace_dot(t)
```

And here's how it looked before the horizontal fusion:

```py
from thunder.examine import make_trace_dot

t = thunder.last_traces(jit_memory_peak_efficient_func)[0]
dot = make_trace_dot(t)
```

![image](https://github.com/user-attachments/assets/d2b9afbd-7794-4f97-b47e-8e094b26a8d5)

There is a lot more freedom in terms of grouping and reordering of operations, as there are many parallel paths.

Thunder's fusion pass needs to change to consider the memory usage of the operations and the intermediate tensors. It should avoid fusing operations that increase peak memory usage. Here are the lines of code where the main logic for grouping operations is implemented:
https://github.com/Lightning-AI/lightning-thunder/blob/79e59d0c5c5f8aa8ef80eb31f3fe918466d64c1c/thunder/executors/data_dependent_partition.py#L299-L303
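One possible direction, sketched as a toy cost model rather than Thunder's actual implementation: estimate the peak live bytes of a candidate schedule, assuming each tensor is freed right after its last use, and prefer the grouping or ordering with the lower estimate. All names and structures below are illustrative assumptions.

```python
def peak_live_bytes(schedule, sizes):
    """schedule: list of (outputs, inputs) per op; a tensor is freed
    right after the op that uses it last. Returns the peak running total
    of live bytes for tensors listed in `sizes`."""
    last = {}
    for i, (outs, ins) in enumerate(schedule):
        for t in outs + ins:
            last[t] = i
    live = peak = 0
    for i, (outs, ins) in enumerate(schedule):
        for t in outs:
            live += sizes[t]
        peak = max(peak, live)
        for t in set(outs + ins):
            if last[t] == i and t in sizes:
                live -= sizes[t]
    return peak

# Same ops, two orders: interleaved (relu_i then matmul_i) vs grouped
# (all relus first, all matmuls after). Intermediates o_i are 1 MiB;
# inputs and final outputs live for the whole function either way, so
# they are left out of `sizes`.
MiB = 1 << 20
sizes = {f"o{i}": MiB for i in range(10)}
interleaved, grouped_relu, grouped_mm = [], [], []
for i in range(10):
    relu = ([f"o{i}"], [])   # produces intermediate o_i
    mm = ([], [f"o{i}"])     # consumes o_i (ret_i ignored)
    interleaved += [relu, mm]
    grouped_relu.append(relu)
    grouped_mm.append(mm)
print(peak_live_bytes(interleaved, sizes) // MiB)               # 1
print(peak_live_bytes(grouped_relu + grouped_mm, sizes) // MiB) # 10
```

A fusion pass could call such an estimator before committing to a merge and refuse merges whose best achievable schedule raises the estimated peak.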
We need to resolve this issue because it leads to memory inefficiency in the generated code with activation checkpointing applied. The `memory_peak_efficient_func` should be the target function for optimization because the same pattern can be found in activation-checkpointed backward functions.

This problem was discovered through Yan's work on enabling PyTorch-native activation checkpointing in https://github.com/Lightning-AI/lightning-thunder/pull/1261. Here are the last backward traces: https://gist.github.com/kiya00/3ae4890e1ae5abf442d475cccadaa9ec#file-ckp_longchat-7b-16k_traces-py-L2116-L2117.
cc @apaz-cli