Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source to source compiler for PyTorch. It enables using different hardware executors at once; across one or thousands of GPUs.
Apache License 2.0

Thunder's horizontal fusion is memory inefficient for backward functions with activation checkpointing #1337

Open IvanYashchuk opened 1 month ago

IvanYashchuk commented 1 month ago

TL;DR: Thunder's fusion pass needs to change to consider the memory usage of the operations and the intermediate tensors. It should avoid fusing operations that increase peak memory usage. Use memory_peak_efficient_func as the target function for optimization.

Let's take a look at the following code snippet:

import torch
import gc

def memory_peak_efficient_func(t0s, a):
    for t0 in t0s:
        t1 = torch.nn.functional.relu(t0); del t0
        t2 = torch.matmul(t1, t1); del t1
        t3 = torch.nn.functional.relu(t2); del t2
        a = torch.matmul(t3, a); del t3
    return a

N_PARALLEL_PATHS = 10
t0s = [torch.randn(256, 256, device="cuda") for _ in range(N_PARALLEL_PATHS)] # 0.25 MiB * N_PARALLEL_PATHS
a = torch.randn(256, 256, device="cuda") # 0.25 MiB
memory_peak_efficient_func(t0s, a)

# Record peak memory usage
gc.collect(0)
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
before_in_MiB = torch.cuda.max_memory_allocated() / (1024 * 1024)
memory_peak_efficient_func(t0s, a)
max_allocated_in_MiB = torch.cuda.max_memory_allocated() / (1024 * 1024) - before_in_MiB
print(f"Peak memory usage diff: {max_allocated_in_MiB:.2f} MiB")
Peak memory usage diff: 0.75 MiB

The code snippet above is a simple example of a memory-efficient function that processes a list of tensors sequentially. The function avoids unnecessary memory allocations due to the order of operations and deleting intermediate tensors after they are no longer needed.

The code can be further "optimized" by fusing some operations into a single region. Before fusion let's take a look at the differently structured code snippet below:

import torch
import gc

def memory_peak_inefficient_func(t0s, a):
    # Can be fused into a single region
    t1s = []
    for t0 in t0s:
        t1 = torch.nn.functional.relu(t0); del t0
        t1s.append(t1)
    del t0s

    t2s = []
    while t1s:
        t1 = t1s.pop()
        t2 = torch.matmul(t1, t1); del t1
        t2s.append(t2)
    del t1s

    # Can be fused into a single region
    t3s = []
    while t2s:
        t2 = t2s.pop()
        t3 = torch.nn.functional.relu(t2); del t2
        t3s.append(t3)
    del t2s

    while t3s:
        t3 = t3s.pop()
        a = torch.matmul(t3, a); del t3
    del t3s

    return a

N_PARALLEL_PATHS = 10
t0s = [torch.randn(256, 256, device="cuda") for _ in range(N_PARALLEL_PATHS)] # 0.25 MiB * N_PARALLEL_PATHS
a = torch.randn(256, 256, device="cuda") # 0.25 MiB
memory_peak_inefficient_func(t0s, a)

# Record peak memory usage
gc.collect(0)
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
before_in_MiB = torch.cuda.max_memory_allocated() / (1024 * 1024)
memory_peak_inefficient_func(t0s, a)
max_allocated_in_MiB = torch.cuda.max_memory_allocated() / (1024 * 1024) - before_in_MiB
print(f"Peak memory usage diff: {max_allocated_in_MiB:.2f} MiB")
Peak memory usage diff: 2.75 MiB

The code snippet above is a modified version of the previous code snippet computing the same result. It precomputes intermediate tensors and stores them in lists to be used later. This version of the code is less memory-efficient because it stores intermediate tensors in memory, which increases peak memory usage.
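The two measurements can be checked with a back-of-the-envelope replay of the allocations. The sketch below is pure Python and only an approximation: `LiveTracker`, `simulate_sequential`, and `simulate_batched` are hypothetical helpers, every temporary is assumed to be one 0.25 MiB f32[256, 256] buffer, and allocator caching and matmul workspaces are ignored.

```python
TENSOR_MIB = 256 * 256 * 4 / (1024 * 1024)  # one f32[256, 256] tensor = 0.25 MiB


class LiveTracker:
    """Counts live temporary memory and remembers the peak (hypothetical helper)."""

    def __init__(self):
        self.live = 0.0
        self.peak = 0.0

    def alloc(self):
        self.live += TENSOR_MIB
        self.peak = max(self.peak, self.live)

    def free(self):
        self.live -= TENSOR_MIB


def simulate_sequential(n_paths):
    """Replays memory_peak_efficient_func; the caller's inputs are not counted."""
    m = LiveTracker()
    have_prev_result = False
    for _ in range(n_paths):
        m.alloc()  # t1 = relu(t0)
        m.alloc()  # t2 = matmul(t1, t1)
        m.free()   # del t1
        m.alloc()  # t3 = relu(t2)
        m.free()   # del t2
        m.alloc()  # new a = matmul(t3, a)
        if have_prev_result:
            m.free()  # rebinding `a` releases the previous intermediate result
        have_prev_result = True
        m.free()   # del t3
    return m.peak


def simulate_batched(n_paths):
    """Replays memory_peak_inefficient_func with its lists of intermediates."""
    m = LiveTracker()
    for _ in range(n_paths):
        m.alloc()  # every t1 = relu(t0) is kept alive in t1s
    for _ in range(n_paths):
        m.alloc()  # t2 = matmul(t1, t1)
        m.free()   # del t1
    for _ in range(n_paths):
        m.alloc()  # t3 = relu(t2)
        m.free()   # del t2
    have_prev_result = False
    for _ in range(n_paths):
        m.alloc()  # new a = matmul(t3, a)
        if have_prev_result:
            m.free()
        have_prev_result = True
        m.free()   # del t3
    return m.peak


print(simulate_sequential(10))  # 0.75
print(simulate_batched(10))     # 2.75
```

Under these assumptions the sequential version never holds more than three temporaries at once, while the batched version holds one full list of ten plus one extra tensor, which matches the 0.75 MiB and 2.75 MiB measured above.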

Let's apply Thunder on the inefficient code snippet to see if it can optimize the memory usage.

import thunder

jit_memory_peak_inefficient_func = thunder.jit(memory_peak_inefficient_func)
gc.collect(0)
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
before_in_MiB = torch.cuda.max_memory_allocated() / (1024 * 1024)
jit_memory_peak_inefficient_func(t0s, a)
max_allocated_in_MiB = torch.cuda.max_memory_allocated() / (1024 * 1024) - before_in_MiB
print(f"Peak memory usage diff: {max_allocated_in_MiB:.2f} MiB")
Peak memory usage diff: 2.25 MiB

Thunder was able to optimize the memory usage of the inefficient code snippet by fusing the operations into a single region. The peak memory usage decreased from 2.75 MiB to 2.25 MiB, but it is still higher than the memory-efficient version of the code.

What would happen if we apply Thunder on the memory-efficient code snippet?

import thunder

jit_memory_peak_efficient_func = thunder.jit(memory_peak_efficient_func)
jit_memory_peak_efficient_func(t0s, a)
gc.collect(0)
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
before_in_MiB = torch.cuda.max_memory_allocated() / (1024 * 1024)
jit_memory_peak_efficient_func(t0s, a)
max_allocated_in_MiB = torch.cuda.max_memory_allocated() / (1024 * 1024) - before_in_MiB
print(f"Peak memory usage diff: {max_allocated_in_MiB:.2f} MiB")
Peak memory usage diff: 2.25 MiB

The same memory usage as the inefficient code snippet! Thunder was not able to preserve the memory usage of the memory-efficient code snippet because it applies a topological sort to the computation graph that prefers forming fusion groups that are as large as possible. In this case, the memory-efficient code snippet already has the optimal order of operations, and Thunder breaks it by fusing the independent operations into a single region.

Here's the dataflow graph of the execution trace of the memory-efficient code snippet:

```py
from thunder.core.transform_common import unwrap_return_value
from thunder.examine import make_trace_dot

t = unwrap_return_value(thunder.last_traces(jit_memory_peak_efficient_func)[-3])
dot = make_trace_dot(t)
```

[image: dataflow graph of the execution trace]

And here's how it looked before the horizontal fusion:

```py
from thunder.examine import make_trace_dot

t = thunder.last_traces(jit_memory_peak_efficient_func)[0]
dot = make_trace_dot(t)
```

![image](https://github.com/user-attachments/assets/d2b9afbd-7794-4f97-b47e-8e094b26a8d5)

There is a lot more freedom in terms of grouping and reordering of operations, as there are many parallel paths.

Thunder's fusion pass needs to change to consider the memory usage of the operations and the intermediate tensors. It should avoid fusing operations that increase peak memory usage. Here are the lines of code where the main logic for grouping operations is implemented:

https://github.com/Lightning-AI/lightning-thunder/blob/79e59d0c5c5f8aa8ef80eb31f3fe918466d64c1c/thunder/executors/data_dependent_partition.py#L299-L303

We need to resolve this issue because it leads to memory inefficiency in the generated code with activation checkpointing applied. The memory_peak_efficient_func should be the target function for optimization because the same pattern can be found in activation-checkpointed backward functions.

This problem was discovered with Yan's work on enabling PyTorch-native activation checkpointing in https://github.com/Lightning-AI/lightning-thunder/pull/1261. Here are the last backward traces: https://gist.github.com/kiya00/3ae4890e1ae5abf442d475cccadaa9ec#file-ckp_longchat-7b-16k_traces-py-L2116-L2117.

cc @apaz-cli

IvanYashchuk commented 1 month ago

@kiya00, @riccardofelluga: could you please work together on resolving this issue? Let me know if you need any support. Thanks!

kiya00 commented 1 month ago

Hi @riccardofelluga, I see you self-assigned. Do you think it could be split up so I can work on a piece, or is there anything I can help with?

riccardofelluga commented 1 month ago

Yes, we should split it up! I've just assigned it to myself so I can keep track of it.

kiya00 commented 1 month ago

Hi, I noticed the current last_trace of the memory_peak_efficient_func is:

10 input tensors # t0-t9
10 outputs = nvfusion0(t0, ..., t9) # has 10 Relu fused together
10 matmul take the previous 10 outputs one by one
return [ret0, ..., ret9]

If we want memory efficiency, in this case it would look like:

10 input tensors # t0-t9
out0=nvfusion0(t0) # only 1 ReLU is fused
ret0 = matmul(out0); del out0
out1=nvfusion1(t1) # only 1 ReLU is fused
ret1 = matmul(out1); del out1
... # repeat for 10 inputs
return [ret0, ..., ret9]

But that will cause the fusion region to be split. Am I understanding it right?

IvanYashchuk commented 1 month ago

Yes, that's the basic idea. The fusions should be split into individual ones if the computations inside are independent of each other.
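One way to picture "independent of each other" is to group the operations inside a fused region by producer/consumer links and treat each connected group as its own candidate fusion. The sketch below is a hypothetical illustration with made-up op names; `split_independent` is not a Thunder function, and sharing a graph input is assumed not to count as a dependency.

```python
def split_independent(ops):
    """Groups ops that are linked through tensors produced inside the region.

    ops: list of (name, input_tensor_names, output_tensor_names).
    Returns a list of groups, each a list of op names.
    """
    parent = list(range(len(ops)))

    def find(i):
        # Union-find lookup with path halving.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    # Map every tensor produced inside the region to the op that produces it.
    producer = {}
    for idx, (_, _, outputs) in enumerate(ops):
        for tensor in outputs:
            producer[tensor] = idx

    # Join each op with the producers of its inputs; region inputs are skipped.
    for idx, (_, inputs, _) in enumerate(ops):
        for tensor in inputs:
            if tensor in producer:
                parent[find(idx)] = find(producer[tensor])

    groups = {}
    for idx, (name, _, _) in enumerate(ops):
        groups.setdefault(find(idx), []).append(name)
    return list(groups.values())


# Two of the ReLU decompositions (gt followed by where) from one fused region.
fused_region = [
    ("gt_0", ["t0"], ["t1"]),
    ("where_0", ["t1", "t0"], ["t2"]),
    ("gt_1", ["t_0_1"], ["t7"]),
    ("where_1", ["t7", "t_0_1"], ["t8"]),
]
print(split_independent(fused_region))
# [['gt_0', 'where_0'], ['gt_1', 'where_1']]
```

Each resulting group could then become its own small fusion that is scheduled right before its consumer.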

kiya00 commented 1 month ago

I think the easiest way could be to implement a "vertical" merge; otherwise we probably need to investigate whether there are other merge algorithms that can balance memory efficiency and the merge range.

kiya00 commented 1 month ago

Updating with some experimental results: IIUC, dataflow_merge does the "vertical" fusion in place on the Graph, and horizontal_merge does the topo sort (BFS with indegree) and horizontally fuses the ops along the way. So I ran an experiment that skips horizontal_merge and adds a DFS topo sort after dataflow_merge (branch).

The trace w/ horizontal merge:

```
@torch.no_grad()
@no_autocast
def computation(t0, t_0_1, t_0_2, t_0_3, t_0_4, t_0_5, t_0_6, t_0_7, t_0_8, t_0_9, a):
  # t0: "cuda:0 f32[256, 256]"
  # t_0_1: "cuda:0 f32[256, 256]"
  # t_0_2: "cuda:0 f32[256, 256]"
  # t_0_3: "cuda:0 f32[256, 256]"
  # t_0_4: "cuda:0 f32[256, 256]"
  # t_0_5: "cuda:0 f32[256, 256]"
  # t_0_6: "cuda:0 f32[256, 256]"
  # t_0_7: "cuda:0 f32[256, 256]"
  # t_0_8: "cuda:0 f32[256, 256]"
  # t_0_9: "cuda:0 f32[256, 256]"
  # a: "cuda:0 f32[256, 256]"
  [t2, t8, t14, t20, t26, t32, t38, t44, t50, t56] = nvFusion0(t0, t_0_1, t_0_2, t_0_3, t_0_4, t_0_5, t_0_6, t_0_7, t_0_8, t_0_9)
    # t1 = prims.gt(t0, 0.0)  # t1: "cuda:0 b8[256, 256]"
    # t2 = prims.where(t1, t0, 0.0)  # t2: "cuda:0 f32[256, 256]"
    # t7 = prims.gt(t_0_1, 0.0)  # t7: "cuda:0 b8[256, 256]"
    # t8 = prims.where(t7, t_0_1, 0.0)  # t8: "cuda:0 f32[256, 256]"
    # t13 = prims.gt(t_0_2, 0.0)  # t13: "cuda:0 b8[256, 256]"
    # t14 = prims.where(t13, t_0_2, 0.0)  # t14: "cuda:0 f32[256, 256]"
    # t19 = prims.gt(t_0_3, 0.0)  # t19: "cuda:0 b8[256, 256]"
    # t20 = prims.where(t19, t_0_3, 0.0)  # t20: "cuda:0 f32[256, 256]"
    # t25 = prims.gt(t_0_4, 0.0)  # t25: "cuda:0 b8[256, 256]"
    # t26 = prims.where(t25, t_0_4, 0.0)  # t26: "cuda:0 f32[256, 256]"
    # t31 = prims.gt(t_0_5, 0.0)  # t31: "cuda:0 b8[256, 256]"
    # t32 = prims.where(t31, t_0_5, 0.0)  # t32: "cuda:0 f32[256, 256]"
    # t37 = prims.gt(t_0_6, 0.0)  # t37: "cuda:0 b8[256, 256]"
    # t38 = prims.where(t37, t_0_6, 0.0)  # t38: "cuda:0 f32[256, 256]"
    # t43 = prims.gt(t_0_7, 0.0)  # t43: "cuda:0 b8[256, 256]"
    # t44 = prims.where(t43, t_0_7, 0.0)  # t44: "cuda:0 f32[256, 256]"
    # t49 = prims.gt(t_0_8, 0.0)  # t49: "cuda:0 b8[256, 256]"
    # t50 = prims.where(t49, t_0_8, 0.0)  # t50: "cuda:0 f32[256, 256]"
    # t55 = prims.gt(t_0_9, 0.0)  # t55: "cuda:0 b8[256, 256]"
    # t56 = prims.where(t55, t_0_9, 0.0)  # t56: "cuda:0 f32[256, 256]"
  t3 = torch.matmul(t2, t2)  # t3: "cuda:0 f32[256, 256]"
    # t3 = ltorch.matmul(t2, t2)  # t3: "cuda:0 f32[256, 256]"
      # t3 = prims.matmul(t2, t2)  # t3: "cuda:0 f32[256, 256]"
  del t2
  t9 = torch.matmul(t8, t8)  # t9: "cuda:0 f32[256, 256]"
    # t9 = ltorch.matmul(t8, t8)  # t9: "cuda:0 f32[256, 256]"
      # t9 = prims.matmul(t8, t8)  # t9: "cuda:0 f32[256, 256]"
  del t8
  t15 = torch.matmul(t14, t14)  # t15: "cuda:0 f32[256, 256]"
    # t15 = ltorch.matmul(t14, t14)  # t15: "cuda:0 f32[256, 256]"
      # t15 = prims.matmul(t14, t14)  # t15: "cuda:0 f32[256, 256]"
  del t14
  t21 = torch.matmul(t20, t20)  # t21: "cuda:0 f32[256, 256]"
    # t21 = ltorch.matmul(t20, t20)  # t21: "cuda:0 f32[256, 256]"
      # t21 = prims.matmul(t20, t20)  # t21: "cuda:0 f32[256, 256]"
  del t20
  t27 = torch.matmul(t26, t26)  # t27: "cuda:0 f32[256, 256]"
    # t27 = ltorch.matmul(t26, t26)  # t27: "cuda:0 f32[256, 256]"
      # t27 = prims.matmul(t26, t26)  # t27: "cuda:0 f32[256, 256]"
  del t26
  t33 = torch.matmul(t32, t32)  # t33: "cuda:0 f32[256, 256]"
    # t33 = ltorch.matmul(t32, t32)  # t33: "cuda:0 f32[256, 256]"
      # t33 = prims.matmul(t32, t32)  # t33: "cuda:0 f32[256, 256]"
  del t32
  t39 = torch.matmul(t38, t38)  # t39: "cuda:0 f32[256, 256]"
    # t39 = ltorch.matmul(t38, t38)  # t39: "cuda:0 f32[256, 256]"
      # t39 = prims.matmul(t38, t38)  # t39: "cuda:0 f32[256, 256]"
  del t38
  t45 = torch.matmul(t44, t44)  # t45: "cuda:0 f32[256, 256]"
    # t45 = ltorch.matmul(t44, t44)  # t45: "cuda:0 f32[256, 256]"
      # t45 = prims.matmul(t44, t44)  # t45: "cuda:0 f32[256, 256]"
  del t44
  t51 = torch.matmul(t50, t50)  # t51: "cuda:0 f32[256, 256]"
    # t51 = ltorch.matmul(t50, t50)  # t51: "cuda:0 f32[256, 256]"
      # t51 = prims.matmul(t50, t50)  # t51: "cuda:0 f32[256, 256]"
  del t50
  t57 = torch.matmul(t56, t56)  # t57: "cuda:0 f32[256, 256]"
    # t57 = ltorch.matmul(t56, t56)  # t57: "cuda:0 f32[256, 256]"
      # t57 = prims.matmul(t56, t56)  # t57: "cuda:0 f32[256, 256]"
  del t56
  [t5, t11, t17, t23, t29, t35, t41, t47, t53, t59] = nvFusion1(t3, t9, t15, t21, t27, t33, t39, t45, t51, t57)
    # t4 = prims.gt(t3, 0.0)  # t4: "cuda:0 b8[256, 256]"
    # t5 = prims.where(t4, t3, 0.0)  # t5: "cuda:0 f32[256, 256]"
    # t10 = prims.gt(t9, 0.0)  # t10: "cuda:0 b8[256, 256]"
    # t11 = prims.where(t10, t9, 0.0)  # t11: "cuda:0 f32[256, 256]"
    # t16 = prims.gt(t15, 0.0)  # t16: "cuda:0 b8[256, 256]"
    # t17 = prims.where(t16, t15, 0.0)  # t17: "cuda:0 f32[256, 256]"
    # t22 = prims.gt(t21, 0.0)  # t22: "cuda:0 b8[256, 256]"
    # t23 = prims.where(t22, t21, 0.0)  # t23: "cuda:0 f32[256, 256]"
    # t28 = prims.gt(t27, 0.0)  # t28: "cuda:0 b8[256, 256]"
    # t29 = prims.where(t28, t27, 0.0)  # t29: "cuda:0 f32[256, 256]"
    # t34 = prims.gt(t33, 0.0)  # t34: "cuda:0 b8[256, 256]"
    # t35 = prims.where(t34, t33, 0.0)  # t35: "cuda:0 f32[256, 256]"
    # t40 = prims.gt(t39, 0.0)  # t40: "cuda:0 b8[256, 256]"
    # t41 = prims.where(t40, t39, 0.0)  # t41: "cuda:0 f32[256, 256]"
    # t46 = prims.gt(t45, 0.0)  # t46: "cuda:0 b8[256, 256]"
    # t47 = prims.where(t46, t45, 0.0)  # t47: "cuda:0 f32[256, 256]"
    # t52 = prims.gt(t51, 0.0)  # t52: "cuda:0 b8[256, 256]"
    # t53 = prims.where(t52, t51, 0.0)  # t53: "cuda:0 f32[256, 256]"
    # t58 = prims.gt(t57, 0.0)  # t58: "cuda:0 b8[256, 256]"
    # t59 = prims.where(t58, t57, 0.0)  # t59: "cuda:0 f32[256, 256]"
  del t3, t9, t15, t21, t27, t33, t39, t45, t51, t57
  t6 = torch.matmul(t5, a)  # t6: "cuda:0 f32[256, 256]"
    # t6 = ltorch.matmul(t5, a)  # t6: "cuda:0 f32[256, 256]"
      # t6 = prims.matmul(t5, a)  # t6: "cuda:0 f32[256, 256]"
  del t5
  t12 = torch.matmul(t11, t6)  # t12: "cuda:0 f32[256, 256]"
    # t12 = ltorch.matmul(t11, t6)  # t12: "cuda:0 f32[256, 256]"
      # t12 = prims.matmul(t11, t6)  # t12: "cuda:0 f32[256, 256]"
  del t11, t6
  t18 = torch.matmul(t17, t12)  # t18: "cuda:0 f32[256, 256]"
    # t18 = ltorch.matmul(t17, t12)  # t18: "cuda:0 f32[256, 256]"
      # t18 = prims.matmul(t17, t12)  # t18: "cuda:0 f32[256, 256]"
  del t17, t12
  t24 = torch.matmul(t23, t18)  # t24: "cuda:0 f32[256, 256]"
    # t24 = ltorch.matmul(t23, t18)  # t24: "cuda:0 f32[256, 256]"
      # t24 = prims.matmul(t23, t18)  # t24: "cuda:0 f32[256, 256]"
  del t23, t18
  t30 = torch.matmul(t29, t24)  # t30: "cuda:0 f32[256, 256]"
    # t30 = ltorch.matmul(t29, t24)  # t30: "cuda:0 f32[256, 256]"
      # t30 = prims.matmul(t29, t24)  # t30: "cuda:0 f32[256, 256]"
  del t29, t24
  t36 = torch.matmul(t35, t30)  # t36: "cuda:0 f32[256, 256]"
    # t36 = ltorch.matmul(t35, t30)  # t36: "cuda:0 f32[256, 256]"
      # t36 = prims.matmul(t35, t30)  # t36: "cuda:0 f32[256, 256]"
  del t35, t30
  t42 = torch.matmul(t41, t36)  # t42: "cuda:0 f32[256, 256]"
    # t42 = ltorch.matmul(t41, t36)  # t42: "cuda:0 f32[256, 256]"
      # t42 = prims.matmul(t41, t36)  # t42: "cuda:0 f32[256, 256]"
  del t41, t36
  t48 = torch.matmul(t47, t42)  # t48: "cuda:0 f32[256, 256]"
    # t48 = ltorch.matmul(t47, t42)  # t48: "cuda:0 f32[256, 256]"
      # t48 = prims.matmul(t47, t42)  # t48: "cuda:0 f32[256, 256]"
  del t47, t42
  t54 = torch.matmul(t53, t48)  # t54: "cuda:0 f32[256, 256]"
    # t54 = ltorch.matmul(t53, t48)  # t54: "cuda:0 f32[256, 256]"
      # t54 = prims.matmul(t53, t48)  # t54: "cuda:0 f32[256, 256]"
  del t53, t48
  t60 = torch.matmul(t59, t54)  # t60: "cuda:0 f32[256, 256]"
    # t60 = ltorch.matmul(t59, t54)  # t60: "cuda:0 f32[256, 256]"
      # t60 = prims.matmul(t59, t54)  # t60: "cuda:0 f32[256, 256]"
  del t59, t54
  return t60
```
The trace w/o horizontal merge (with DFS topo sort):

```
@torch.no_grad()
@no_autocast
def computation(t0, t_0_1, t_0_2, t_0_3, t_0_4, t_0_5, t_0_6, t_0_7, t_0_8, t_0_9, a):
  # t0: "cuda:0 f32[256, 256]"
  [t2] = nvFusion0(t0)
    # t1 = prims.gt(t0, 0.0)  # t1: "cuda:0 b8[256, 256]"
    # t2 = prims.where(t1, t0, 0.0)  # t2: "cuda:0 f32[256, 256]"
  t3 = torch.matmul(t2, t2)  # t3: "cuda:0 f32[256, 256]"
    # t3 = ltorch.matmul(t2, t2)  # t3: "cuda:0 f32[256, 256]"
      # t3 = prims.matmul(t2, t2)  # t3: "cuda:0 f32[256, 256]"
  del t2
  [t5] = nvFusion1(t3)
    # t4 = prims.gt(t3, 0.0)  # t4: "cuda:0 b8[256, 256]"
    # t5 = prims.where(t4, t3, 0.0)  # t5: "cuda:0 f32[256, 256]"
  del t3
  # t_0_1: "cuda:0 f32[256, 256]"
  [t8] = nvFusion2(t_0_1)
    # t7 = prims.gt(t_0_1, 0.0)  # t7: "cuda:0 b8[256, 256]"
    # t8 = prims.where(t7, t_0_1, 0.0)  # t8: "cuda:0 f32[256, 256]"
  t9 = torch.matmul(t8, t8)  # t9: "cuda:0 f32[256, 256]"
    # t9 = ltorch.matmul(t8, t8)  # t9: "cuda:0 f32[256, 256]"
      # t9 = prims.matmul(t8, t8)  # t9: "cuda:0 f32[256, 256]"
  del t8
  [t11] = nvFusion3(t9)
    # t10 = prims.gt(t9, 0.0)  # t10: "cuda:0 b8[256, 256]"
    # t11 = prims.where(t10, t9, 0.0)  # t11: "cuda:0 f32[256, 256]"
  del t9
  # t_0_2: "cuda:0 f32[256, 256]"
  [t14] = nvFusion4(t_0_2)
    # t13 = prims.gt(t_0_2, 0.0)  # t13: "cuda:0 b8[256, 256]"
    # t14 = prims.where(t13, t_0_2, 0.0)  # t14: "cuda:0 f32[256, 256]"
  t15 = torch.matmul(t14, t14)  # t15: "cuda:0 f32[256, 256]"
    # t15 = ltorch.matmul(t14, t14)  # t15: "cuda:0 f32[256, 256]"
      # t15 = prims.matmul(t14, t14)  # t15: "cuda:0 f32[256, 256]"
  del t14
  [t17] = nvFusion5(t15)
    # t16 = prims.gt(t15, 0.0)  # t16: "cuda:0 b8[256, 256]"
    # t17 = prims.where(t16, t15, 0.0)  # t17: "cuda:0 f32[256, 256]"
  del t15
  # t_0_3: "cuda:0 f32[256, 256]"
  [t20] = nvFusion6(t_0_3)
    # t19 = prims.gt(t_0_3, 0.0)  # t19: "cuda:0 b8[256, 256]"
    # t20 = prims.where(t19, t_0_3, 0.0)  # t20: "cuda:0 f32[256, 256]"
  t21 = torch.matmul(t20, t20)  # t21: "cuda:0 f32[256, 256]"
    # t21 = ltorch.matmul(t20, t20)  # t21: "cuda:0 f32[256, 256]"
      # t21 = prims.matmul(t20, t20)  # t21: "cuda:0 f32[256, 256]"
  del t20
  [t23] = nvFusion7(t21)
    # t22 = prims.gt(t21, 0.0)  # t22: "cuda:0 b8[256, 256]"
    # t23 = prims.where(t22, t21, 0.0)  # t23: "cuda:0 f32[256, 256]"
  del t21
  # t_0_4: "cuda:0 f32[256, 256]"
  [t26] = nvFusion8(t_0_4)
    # t25 = prims.gt(t_0_4, 0.0)  # t25: "cuda:0 b8[256, 256]"
    # t26 = prims.where(t25, t_0_4, 0.0)  # t26: "cuda:0 f32[256, 256]"
  t27 = torch.matmul(t26, t26)  # t27: "cuda:0 f32[256, 256]"
    # t27 = ltorch.matmul(t26, t26)  # t27: "cuda:0 f32[256, 256]"
      # t27 = prims.matmul(t26, t26)  # t27: "cuda:0 f32[256, 256]"
  del t26
  [t29] = nvFusion9(t27)
    # t28 = prims.gt(t27, 0.0)  # t28: "cuda:0 b8[256, 256]"
    # t29 = prims.where(t28, t27, 0.0)  # t29: "cuda:0 f32[256, 256]"
  del t27
  # t_0_5: "cuda:0 f32[256, 256]"
  [t32] = nvFusion10(t_0_5)
    # t31 = prims.gt(t_0_5, 0.0)  # t31: "cuda:0 b8[256, 256]"
    # t32 = prims.where(t31, t_0_5, 0.0)  # t32: "cuda:0 f32[256, 256]"
  t33 = torch.matmul(t32, t32)  # t33: "cuda:0 f32[256, 256]"
    # t33 = ltorch.matmul(t32, t32)  # t33: "cuda:0 f32[256, 256]"
      # t33 = prims.matmul(t32, t32)  # t33: "cuda:0 f32[256, 256]"
  del t32
  [t35] = nvFusion11(t33)
    # t34 = prims.gt(t33, 0.0)  # t34: "cuda:0 b8[256, 256]"
    # t35 = prims.where(t34, t33, 0.0)  # t35: "cuda:0 f32[256, 256]"
  del t33
  # t_0_6: "cuda:0 f32[256, 256]"
  [t38] = nvFusion12(t_0_6)
    # t37 = prims.gt(t_0_6, 0.0)  # t37: "cuda:0 b8[256, 256]"
    # t38 = prims.where(t37, t_0_6, 0.0)  # t38: "cuda:0 f32[256, 256]"
  t39 = torch.matmul(t38, t38)  # t39: "cuda:0 f32[256, 256]"
    # t39 = ltorch.matmul(t38, t38)  # t39: "cuda:0 f32[256, 256]"
      # t39 = prims.matmul(t38, t38)  # t39: "cuda:0 f32[256, 256]"
  del t38
  [t41] = nvFusion13(t39)
    # t40 = prims.gt(t39, 0.0)  # t40: "cuda:0 b8[256, 256]"
    # t41 = prims.where(t40, t39, 0.0)  # t41: "cuda:0 f32[256, 256]"
  del t39
  # t_0_7: "cuda:0 f32[256, 256]"
  [t44] = nvFusion14(t_0_7)
    # t43 = prims.gt(t_0_7, 0.0)  # t43: "cuda:0 b8[256, 256]"
    # t44 = prims.where(t43, t_0_7, 0.0)  # t44: "cuda:0 f32[256, 256]"
  t45 = torch.matmul(t44, t44)  # t45: "cuda:0 f32[256, 256]"
    # t45 = ltorch.matmul(t44, t44)  # t45: "cuda:0 f32[256, 256]"
      # t45 = prims.matmul(t44, t44)  # t45: "cuda:0 f32[256, 256]"
  del t44
  [t47] = nvFusion15(t45)
    # t46 = prims.gt(t45, 0.0)  # t46: "cuda:0 b8[256, 256]"
    # t47 = prims.where(t46, t45, 0.0)  # t47: "cuda:0 f32[256, 256]"
  del t45
  # t_0_8: "cuda:0 f32[256, 256]"
  [t50] = nvFusion16(t_0_8)
    # t49 = prims.gt(t_0_8, 0.0)  # t49: "cuda:0 b8[256, 256]"
    # t50 = prims.where(t49, t_0_8, 0.0)  # t50: "cuda:0 f32[256, 256]"
  t51 = torch.matmul(t50, t50)  # t51: "cuda:0 f32[256, 256]"
    # t51 = ltorch.matmul(t50, t50)  # t51: "cuda:0 f32[256, 256]"
      # t51 = prims.matmul(t50, t50)  # t51: "cuda:0 f32[256, 256]"
  del t50
  [t53] = nvFusion17(t51)
    # t52 = prims.gt(t51, 0.0)  # t52: "cuda:0 b8[256, 256]"
    # t53 = prims.where(t52, t51, 0.0)  # t53: "cuda:0 f32[256, 256]"
  del t51
  # t_0_9: "cuda:0 f32[256, 256]"
  [t56] = nvFusion18(t_0_9)
    # t55 = prims.gt(t_0_9, 0.0)  # t55: "cuda:0 b8[256, 256]"
    # t56 = prims.where(t55, t_0_9, 0.0)  # t56: "cuda:0 f32[256, 256]"
  t57 = torch.matmul(t56, t56)  # t57: "cuda:0 f32[256, 256]"
    # t57 = ltorch.matmul(t56, t56)  # t57: "cuda:0 f32[256, 256]"
      # t57 = prims.matmul(t56, t56)  # t57: "cuda:0 f32[256, 256]"
  del t56
  [t59] = nvFusion19(t57)
    # t58 = prims.gt(t57, 0.0)  # t58: "cuda:0 b8[256, 256]"
    # t59 = prims.where(t58, t57, 0.0)  # t59: "cuda:0 f32[256, 256]"
  del t57
  # a: "cuda:0 f32[256, 256]"
  t6 = torch.matmul(t5, a)  # t6: "cuda:0 f32[256, 256]"
    # t6 = ltorch.matmul(t5, a)  # t6: "cuda:0 f32[256, 256]"
      # t6 = prims.matmul(t5, a)  # t6: "cuda:0 f32[256, 256]"
  del t5
  t12 = torch.matmul(t11, t6)  # t12: "cuda:0 f32[256, 256]"
    # t12 = ltorch.matmul(t11, t6)  # t12: "cuda:0 f32[256, 256]"
      # t12 = prims.matmul(t11, t6)  # t12: "cuda:0 f32[256, 256]"
  del t11, t6
  t18 = torch.matmul(t17, t12)  # t18: "cuda:0 f32[256, 256]"
    # t18 = ltorch.matmul(t17, t12)  # t18: "cuda:0 f32[256, 256]"
      # t18 = prims.matmul(t17, t12)  # t18: "cuda:0 f32[256, 256]"
  del t17, t12
  t24 = torch.matmul(t23, t18)  # t24: "cuda:0 f32[256, 256]"
    # t24 = ltorch.matmul(t23, t18)  # t24: "cuda:0 f32[256, 256]"
      # t24 = prims.matmul(t23, t18)  # t24: "cuda:0 f32[256, 256]"
  del t23, t18
  t30 = torch.matmul(t29, t24)  # t30: "cuda:0 f32[256, 256]"
    # t30 = ltorch.matmul(t29, t24)  # t30: "cuda:0 f32[256, 256]"
      # t30 = prims.matmul(t29, t24)  # t30: "cuda:0 f32[256, 256]"
  del t29, t24
  t36 = torch.matmul(t35, t30)  # t36: "cuda:0 f32[256, 256]"
    # t36 = ltorch.matmul(t35, t30)  # t36: "cuda:0 f32[256, 256]"
      # t36 = prims.matmul(t35, t30)  # t36: "cuda:0 f32[256, 256]"
  del t35, t30
  t42 = torch.matmul(t41, t36)  # t42: "cuda:0 f32[256, 256]"
    # t42 = ltorch.matmul(t41, t36)  # t42: "cuda:0 f32[256, 256]"
      # t42 = prims.matmul(t41, t36)  # t42: "cuda:0 f32[256, 256]"
  del t41, t36
  t48 = torch.matmul(t47, t42)  # t48: "cuda:0 f32[256, 256]"
    # t48 = ltorch.matmul(t47, t42)  # t48: "cuda:0 f32[256, 256]"
      # t48 = prims.matmul(t47, t42)  # t48: "cuda:0 f32[256, 256]"
  del t47, t42
  t54 = torch.matmul(t53, t48)  # t54: "cuda:0 f32[256, 256]"
    # t54 = ltorch.matmul(t53, t48)  # t54: "cuda:0 f32[256, 256]"
      # t54 = prims.matmul(t53, t48)  # t54: "cuda:0 f32[256, 256]"
  del t53, t48
  t60 = torch.matmul(t59, t54)  # t60: "cuda:0 f32[256, 256]"
    # t60 = ltorch.matmul(t59, t54)  # t60: "cuda:0 f32[256, 256]"
      # t60 = prims.matmul(t59, t54)  # t60: "cuda:0 f32[256, 256]"
  del t59, t54
  return t60
```

But when I tried this on longchat-7b-16k, although the sequence of cudnn_sdpa_fwd calls gets separated, the peak memory is even larger. This is the last backward trace: https://gist.github.com/kiya00/4d906a424e0adc7640d3011a0375b5cb
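The difference between the two orderings discussed in this comment can be reproduced on a toy graph. The sketch below is not Thunder's partitioner: `toy_graph`, `bfs_toposort`, and `dfs_toposort` are hypothetical helpers, and the node names only mimic the relu/matmul chain from the issue description. It shows that an indegree-based BFS emits the graph level by level (all independent ReLUs first, which is what makes them easy to fuse horizontally), while a stack-based variant follows one path before starting the next.

```python
from collections import deque


def toy_graph(n_paths):
    """relu_i -> mm_i -> act_i -> acc_i, and acc_i also feeds acc_{i+1}."""
    edges = {}
    for i in range(n_paths):
        edges[f"relu{i}"] = [f"mm{i}"]
        edges[f"mm{i}"] = [f"act{i}"]
        edges[f"act{i}"] = [f"acc{i}"]
        edges[f"acc{i}"] = [f"acc{i + 1}"] if i + 1 < n_paths else []
    return edges


def indegrees(edges):
    """Number of producers each node is still waiting for."""
    indeg = {node: 0 for node in edges}
    for consumers in edges.values():
        for consumer in consumers:
            indeg[consumer] += 1
    return indeg


def bfs_toposort(edges):
    """Kahn's algorithm with a FIFO queue: emits the graph level by level."""
    indeg = indegrees(edges)
    ready = deque(node for node in edges if indeg[node] == 0)
    order = []
    while ready:
        node = ready.popleft()
        order.append(node)
        for consumer in edges[node]:
            indeg[consumer] -= 1
            if indeg[consumer] == 0:
                ready.append(consumer)
    return order


def dfs_toposort(edges):
    """Same bookkeeping with a LIFO stack: follows one path as far as it can."""
    indeg = indegrees(edges)
    ready = [node for node in reversed(list(edges)) if indeg[node] == 0]
    order = []
    while ready:
        node = ready.pop()
        order.append(node)
        for consumer in reversed(edges[node]):
            indeg[consumer] -= 1
            if indeg[consumer] == 0:
                ready.append(consumer)
    return order


graph = toy_graph(3)
print(bfs_toposort(graph))
# ['relu0', 'relu1', 'relu2', 'mm0', 'mm1', 'mm2', 'act0', 'act1', 'act2', 'acc0', 'acc1', 'acc2']
print(dfs_toposort(graph))
# ['relu0', 'mm0', 'act0', 'acc0', 'relu1', 'mm1', 'act1', 'acc1', 'relu2', 'mm2', 'act2', 'acc2']
```

Both orders are valid topological orders of the same graph; they differ only in how long each intermediate has to stay alive.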

IvanYashchuk commented 1 month ago

> when I tried this on longchat-7b-16k, although the sequence of cudnn_sdpa_fwd calls gets separated, the peak memory is even larger. This is the last backward trace: https://gist.github.com/kiya00/4d906a424e0adc7640d3011a0375b5cb

There are quite a lot of operations between when t1140 and t1141 from the recomputed forward piece are produced and when they are actually used. Is there anything in terms of dataflow preventing these tensors from being placed closer to the backward consumer?

kiya00 commented 1 month ago

At the point where they are actually used, t1140 and t1141 are held until t1165 is ready; there's a long path from t583 until t1165 is calculated. These 3 tensors need to wait for each other somehow. Maybe I can try letting the smaller tensors wait for the larger tensor to be calculated?

kiya00 commented 1 month ago

I tried a bottom-up DFS that chooses the smaller output tensor first (trace): t1141 and t1140 are used here, and t1165 is held while t1140 and t1141 are calculated. But the memory doesn't change much (the previous method uses 12.63 GB; now it decreases a bit to 12.45 GB). Here are some results; maybe breaking up the fusion region causes more memory usage:

|   | dfs experiment | w/ horizontal merge (original) |
| -- | -- | -- |
| vicuna-33b-v1.3 | 12.29 GB | 12.29 GB |
| falcon-40b | 19.98 GB | 19.98 GB |
| Gemma-7b | 21.29 GB | 21.31 GB |
| longchat-7b-16k | 12.45 GB | 12.36 GB |
| pythia-12b | 9.65 GB | 9.65 GB |

mruberry commented 1 month ago

@kiya00 @jjsjann123 @IvanYashchuk Would it be possible to let a practitioner or developer control the algorithm used with a developer option? Like maybe the developer option can accept a function whose signature is not guaranteed (but defined somewhere). That could facilitate exploration. We could also think about adding different variants of the algorithm to the repository (maybe in an "experimental" directory) or to a different repository if people want to try them or tweak them.

IvanYashchuk commented 1 month ago

Yes, it will be possible to control that with an option.

IvanYashchuk commented 1 month ago

> At the point where they are actually used, t1140 and t1141 are held until t1165 is ready; there's a long path from t583 until t1165 is calculated. These 3 tensors need to wait for each other somehow. Maybe I can try letting the smaller tensors wait for the larger tensor to be calculated?

The fusion pass is now correctly not blocking the dataflow and it's possible to have an ordering that results in a more efficient peak memory usage. The next step is to add a special sorting of bound symbols that corresponds to a function with better memory usage. Here's the current dataflow graph with Yan's patch on the example function from the issue description:

[image: dataflow graph with the patch applied]

kiya00 commented 1 month ago

I tried adding a memory_efficient_sorting for our toposort_bsym_dag. The jitted function of memory_peak_efficient/inefficient_func now uses 0.75 MiB, but when I tried it on longchat-7b-16k, the peak memory is only slightly reduced to 12.19 GB, compared to the original 12.45 GB.

IvanYashchuk commented 1 month ago

> the jitted function of memory_peak_efficient/inefficient_func now uses 0.75 MiB

Awesome! What's the algorithm?

> when I tried it on longchat-7b-16k, the peak memory is only slightly reduced to 12.19 GB, compared to the original 12.45 GB

What's the expected target? Can it be lower than 12.19GB?

kiya00 commented 1 month ago

> Awesome! What's the algorithm?

I set the selector in toposort_bsym_dag to work bottom-up and always choose the bsym that allocates the most memory among the topologically equal nodes. The accounting assumes the output is allocated by the current bsym, and an input is freed if it has no other consumer than the current bsym.
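A minimal sketch of that kind of selector on a toy graph, to make the accounting concrete. This is an interpretation, not the code from the branch: `memory_aware_toposort`, `peak_live`, and `toy_graph` are hypothetical names, every tensor is assumed to have size 1, graph inputs are assumed to be owned by the caller, and ties are broken by whichever candidate became ready first.

```python
def toy_graph(n_paths):
    """name -> (producer names, output size); names not in the dict are graph inputs."""
    nodes = {}
    for i in range(n_paths):
        nodes[f"relu{i}"] = ([f"t0_{i}"], 1)
        nodes[f"mm{i}"] = ([f"relu{i}"], 1)
        nodes[f"act{i}"] = ([f"mm{i}"], 1)
        prev = f"acc{i - 1}" if i > 0 else "a"
        nodes[f"acc{i}"] = ([f"act{i}", prev], 1)
    return nodes


def memory_aware_toposort(nodes):
    """Bottom-up list scheduling: among ready nodes, place the largest net allocator last."""
    consumers = {name: [] for name in nodes}
    for name, (inputs, _) in nodes.items():
        for producer in dict.fromkeys(inputs):  # de-duplicate repeated inputs
            if producer in nodes:
                consumers[producer].append(name)

    def net_alloc(name):
        inputs, out_size = nodes[name]
        # An input counts as freed only if this node is its sole consumer.
        freed = sum(nodes[p][1] for p in dict.fromkeys(inputs)
                    if p in nodes and consumers[p] == [name])
        return out_size - freed

    unplaced_consumers = {name: len(consumers[name]) for name in nodes}
    ready = [name for name in nodes if unplaced_consumers[name] == 0]
    reversed_order = []
    while ready:
        pick = max(ready, key=net_alloc)
        ready.remove(pick)
        reversed_order.append(pick)
        for producer in dict.fromkeys(nodes[pick][0]):
            if producer in nodes:
                unplaced_consumers[producer] -= 1
                if unplaced_consumers[producer] == 0:
                    ready.append(producer)
    return reversed_order[::-1]


def peak_live(order, nodes):
    """Peak number of live size units when `order` is executed front to back."""
    last_use = {}
    for step, name in enumerate(order):
        for producer in nodes[name][0]:
            if producer in nodes:
                last_use[producer] = step
    live = peak = 0
    for step, name in enumerate(order):
        live += nodes[name][1]
        peak = max(peak, live)
        for producer in dict.fromkeys(nodes[name][0]):
            if producer in nodes and last_use[producer] == step:
                live -= nodes[producer][1]
    return peak


nodes = toy_graph(10)
level_order = [f"{op}{i}" for op in ("relu", "mm", "act", "acc") for i in range(10)]
print(peak_live(level_order, nodes))                   # 11
print(peak_live(memory_aware_toposort(nodes), nodes))  # 3
```

With 0.25 MiB per tensor, 11 and 3 live tensors correspond to the 2.75 MiB and 0.75 MiB measured for the eager functions in the issue description; real traces have non-uniform tensor sizes, so the greedy choice is not guaranteed to be optimal there.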

> What's the expected target? Can it be lower than 12.19GB?

Based on the converter branch, before applying the changes in the branch (trytofix1337), the memory usage of longchat-7b-16k (`python thunder/benchmarks/benchmark_litgpt.py --model_name longchat-7b-16k --compile dynamo+thunder --n_layers=2 --max_iters=2 --warmup_iters=1 --checkpoint_activations=True`) is 12.36 GB; after, it is 12.19 GB. Inductor uses 12.09 GB and eager uses 10.55 GB.

IvanYashchuk commented 4 weeks ago

@kiya00 noticed that torch.compile+FSDP+activation checkpointing doesn't hit this fusion and execution order problem because the graph is broken into smaller pieces by Dynamo. https://github.com/Lightning-AI/lightning-thunder/pull/1370 should be enough to enable benchmarks to work in the multi-GPU case.