Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Thunder saves too many tensors for backward for a Transformer's residual connection pattern #1368

Open · IvanYashchuk opened this issue 1 month ago

IvanYashchuk commented 1 month ago

🐛 Bug

To Reproduce

Here's a problematic pattern where Thunder's rematerialization algorithm is not effective:

from torch.nn import Linear
import torch
import thunder

block_size = 16384
n_embd = 2560
intermediate_size = 10240
with torch.device("cuda"):
    fc = Linear(n_embd, intermediate_size, bias=False).to(torch.float16)
    proj = Linear(intermediate_size, n_embd, bias=False).to(torch.float16)
fc_weight = fc.weight
proj_weight = proj.weight

def mlp(x: torch.Tensor, fc_weight, proj_weight) -> torch.Tensor:
    x = torch.nn.functional.linear(x, fc_weight)
    x = torch.exp(x)  # stand-in for torch.nn.functional.gelu(x, approximate="none")
    return torch.nn.functional.linear(x, proj_weight)

def f(x, fc_weight, proj_weight):
    # start of transformer block
    x_normed = torch.exp(x)
    attention_output = x_normed

    x = attention_output + x
    nx = torch.exp(x)
    x = mlp(nx, fc_weight, proj_weight) + x
    # end of transformer block
    x = torch.exp(x)
    return x

tf = torch.compile(f)
jf = thunder.jit(f)

x = torch.randn(1, block_size, n_embd, device="cuda", requires_grad=True, dtype=torch.float16)

tout = tf(x, fc_weight, proj_weight)
print(f"{len(tout.grad_fn.saved_tensors)=}")
print(f"Saved tensors size torch.compile: {sum([t.numel() * t.element_size() for t in tout.grad_fn.saved_tensors if t is not None]) / 2**20:.2f} MiB")

jout = jf(x, fc_weight, proj_weight)
print(f"{len(jout.grad_fn.saved_tensors)=}")
print(f"Saved tensors size Thunder: {sum([t.numel() * t.element_size() for t in jout.grad_fn.saved_tensors if t is not None]) / 2**20:.2f} MiB")

Script output:

len(tout.grad_fn.saved_tensors)=6
Saved tensors size torch.compile: 660.00 MiB
len(jout.grad_fn.saved_tensors)=8
Saved tensors size Thunder: 1060.00 MiB
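
For context, a quick back-of-the-envelope size check (an illustration added here, not part of the measurement above): in fp16, each [1, 16384, 2560] activation is 80 MiB, each [1, 16384, 10240] activation is 320 MiB, and each Linear weight is 50 MiB. The 400 MiB gap between the two backends is numerically consistent with Thunder saving one extra activation of each shape.

# Per-tensor sizes implied by the shapes in the repro (fp16 = 2 bytes per element).
MiB = 2**20
narrow = 16384 * 2560 * 2 / MiB    # [1, 16384, 2560]  activation -> 80.0 MiB
wide   = 16384 * 10240 * 2 / MiB   # [1, 16384, 10240] activation -> 320.0 MiB
weight = 10240 * 2560 * 2 / MiB    # each Linear weight           -> 50.0 MiB
print(narrow, wide, weight)        # 80.0 320.0 50.0
print(1060 - 660)                  # 400 MiB gap == wide + narrow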

The current thunder.examine.make_trace_dot doesn't print out all output variables, so the graph is difficult to read, but here is the joint forward and backward graph before rematerialization is applied: [graph image attached]

The graph was created with a breakpoint set before this line: https://github.com/Lightning-AI/lightning-thunder/blob/9c916d9df73f3920b51e5951303a76b25ab2d4d4/thunder/core/rematerialization.py#L609
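
The final traces can also be dumped as text instead of a graph; a minimal sketch, assuming the thunder.last_traces / thunder.last_backward_traces API and running it after the jf(x, fc_weight, proj_weight) call from the repro:

# Print the last (executed) forward and backward traces for the jitted function.
# Assumes thunder.last_traces / thunder.last_backward_traces are available.
fwd_trace = thunder.last_traces(jf)[-1]
bwd_trace = thunder.last_backward_traces(jf)[-1]
print(fwd_trace)   # shows what gets packed into saved_for_backward
print(bwd_trace)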

Resolving this bug would help us resolve https://github.com/Lightning-AI/lightning-thunder/issues/246.

cc @riccardofelluga

riccardofelluga commented 1 month ago

Thanks @IvanYashchuk for writing down what we spoke about yesterday and for the repro snippet. I'll take a look at it since I'm working in a similar area.

IvanYashchuk commented 1 month ago

Using thunder.jit(f, nv_enable_linear=True, nv_enable_matmul=True) helps Thunder save fewer tensors because it simplifies the task for rematerialization, leaving only one producer -> one consumer pair. With this change rematerialization decides to save only two intermediates, the outputs of the two linear layers; the sketch below shows the modified call, followed by the new output and traces:
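
Only the thunder.jit line changes relative to the repro above; the measurement code is the same:

# Enable nvFuser for linear/matmul so the whole block forms a single fusion
# with one producer -> one consumer pair for the rematerialization pass.
jf = thunder.jit(f, nv_enable_linear=True, nv_enable_matmul=True)
jout = jf(x, fc_weight, proj_weight)
print(f"{len(jout.grad_fn.saved_tensors)=}")
print(f"Saved tensors size Thunder: {sum(t.numel() * t.element_size() for t in jout.grad_fn.saved_tensors if t is not None) / 2**20:.2f} MiB")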

len(jout.grad_fn.saved_tensors)=5
Saved tensors size Thunder: 580.00 MiB
def computation(x, fc_weight, proj_weight):
  # x: "cuda:0 f16[1, 16384, 2560]"
  # fc_weight: "cuda:0 f16[10240, 2560]"
  # proj_weight: "cuda:0 f16[2560, 10240]"
  [t10, t14, t21] = nvFusion0(x, fc_weight, proj_weight)
    # t22 = prims.convert_element_type(x, dtypes.float32)  # t22: "cuda:0 f32[1, 16384, 2560]"
    # t23 = prims.exp(t22)  # t23: "cuda:0 f32[1, 16384, 2560]"
    # t27 = prims.add(t23, t22)  # t27: "cuda:0 f32[1, 16384, 2560]"
    # t30 = prims.exp(t27)  # t30: "cuda:0 f32[1, 16384, 2560]"
    # nx = prims.convert_element_type(t30, dtypes.float16)  # nx: "cuda:0 f16[1, 16384, 2560]"
    # t10 = prims.linear(nx, fc_weight, None)  # t10: "cuda:0 f16[1, 16384, 10240]"
    # t33 = prims.convert_element_type(t10, dtypes.float32)  # t33: "cuda:0 f32[1, 16384, 10240]"
    # t34 = prims.exp(t33)  # t34: "cuda:0 f32[1, 16384, 10240]"
    # t13 = prims.convert_element_type(t34, dtypes.float16)  # t13: "cuda:0 f16[1, 16384, 10240]"
    # t14 = prims.linear(t13, proj_weight, None)  # t14: "cuda:0 f16[1, 16384, 2560]"
    # t37 = prims.convert_element_type(t14, dtypes.float32)  # t37: "cuda:0 f32[1, 16384, 2560]"
    # t39 = prims.add(t37, t27)  # t39: "cuda:0 f32[1, 16384, 2560]"
    # t42 = prims.exp(t39)  # t42: "cuda:0 f32[1, 16384, 2560]"
    # t21 = prims.convert_element_type(t42, dtypes.float16)  # t21: "cuda:0 f16[1, 16384, 2560]"
  return {'output': t21, 'flat_args': [x, fc_weight, proj_weight], 'flat_output': (t21,)}, ((fc_weight, proj_weight, t10, t14, x), ())
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t44, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  fc_weight, proj_weight, t10, t14, x, = C0
  clear_mutable_collection(C0)
  del C0
  [t62, t72, t89] = nvFusion0(x, t14, t44, proj_weight, t10, fc_weight)
    # t22 = prims.convert_element_type(x, dtypes.float32)  # t22: "cuda:0 f32[1, 16384, 2560]"
    # t23 = prims.exp(t22)  # t23: "cuda:0 f32[1, 16384, 2560]"
    # t27 = prims.add(t23, t22)  # t27: "cuda:0 f32[1, 16384, 2560]"
    # t37 = prims.convert_element_type(t14, dtypes.float32)  # t37: "cuda:0 f32[1, 16384, 2560]"
    # t39 = prims.add(t37, t27)  # t39: "cuda:0 f32[1, 16384, 2560]"
    # t42 = prims.exp(t39)  # t42: "cuda:0 f32[1, 16384, 2560]"
    # t50 = prims.convert_element_type(t44, dtypes.float32)  # t50: "cuda:0 f32[1, 16384, 2560]"
    # t51 = prims.mul(t50, t42)  # t51: "cuda:0 f32[1, 16384, 2560]"
    # t52 = prims.convert_element_type(t51, dtypes.float16)  # t52: "cuda:0 f16[1, 16384, 2560]"
    # t56 = prims.reshape(t52, (16384, 2560))  # t56: "cuda:0 f16[16384, 2560]"
    # t57 = prims.matmul(t56, proj_weight)  # t57: "cuda:0 f16[16384, 10240]"
    # t33 = prims.convert_element_type(t10, dtypes.float32)  # t33: "cuda:0 f32[1, 16384, 10240]"
    # t58 = prims.reshape(t57, (1, 16384, 10240))  # t58: "cuda:0 f16[1, 16384, 10240]"
    # t34 = prims.exp(t33)  # t34: "cuda:0 f32[1, 16384, 10240]"
    # t63 = prims.convert_element_type(t58, dtypes.float32)  # t63: "cuda:0 f32[1, 16384, 10240]"
    # t64 = prims.mul(t63, t34)  # t64: "cuda:0 f32[1, 16384, 10240]"
    # t65 = prims.convert_element_type(t64, dtypes.float16)  # t65: "cuda:0 f16[1, 16384, 10240]"
    # t66 = prims.reshape(t65, (16384, 10240))  # t66: "cuda:0 f16[16384, 10240]"
    # t67 = prims.matmul(t66, fc_weight)  # t67: "cuda:0 f16[16384, 2560]"
    # t68 = prims.reshape(t67, (1, 16384, 2560))  # t68: "cuda:0 f16[1, 16384, 2560]"
    # t30 = prims.exp(t27)  # t30: "cuda:0 f32[1, 16384, 2560]"
    # t73 = prims.convert_element_type(t68, dtypes.float32)  # t73: "cuda:0 f32[1, 16384, 2560]"
    # t74 = prims.mul(t73, t30)  # t74: "cuda:0 f32[1, 16384, 2560]"
    # t78 = prims.add(t51, t74)  # t78: "cuda:0 f32[1, 16384, 2560]"
    # t84 = prims.mul(t78, t23)  # t84: "cuda:0 f32[1, 16384, 2560]"
    # nx = prims.convert_element_type(t30, dtypes.float16)  # nx: "cuda:0 f16[1, 16384, 2560]"
    # t13 = prims.convert_element_type(t34, dtypes.float16)  # t13: "cuda:0 f16[1, 16384, 10240]"
    # t88 = prims.add(t78, t84)  # t88: "cuda:0 f32[1, 16384, 2560]"
    # t71 = prims.reshape(nx, (16384, 2560))  # t71: "cuda:0 f16[16384, 2560]"
    # t70 = prims.transpose(t66, (1, 0))  # t70: "cuda:0 f16[10240, 16384]"
    # t61 = prims.reshape(t13, (16384, 10240))  # t61: "cuda:0 f16[16384, 10240]"
    # t60 = prims.transpose(t56, (1, 0))  # t60: "cuda:0 f16[2560, 16384]"
    # t89 = prims.convert_element_type(t88, dtypes.float16)  # t89: "cuda:0 f16[1, 16384, 2560]"
    # t72 = prims.matmul(t70, t71)  # t72: "cuda:0 f16[10240, 2560]"
    # t62 = prims.matmul(t60, t61)  # t62: "cuda:0 f16[2560, 10240]"
  del x, t14, t44, proj_weight, t10, fc_weight
  return (t89, t72, t62)
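
For reference, the 580 MiB figure matches the five tensors packed into saved_for_backward by the forward trace above (fc_weight, proj_weight, t10, t14, x); using the per-tensor sizes worked out earlier:

# fc_weight (50 MiB) + proj_weight (50 MiB) + t10 (320 MiB) + t14 (80 MiB) + x (80 MiB)
MiB = 2**20
saved_bytes = 2 * (10240 * 2560 * 2) + (16384 * 10240 * 2) + 2 * (16384 * 2560 * 2)
print(saved_bytes / MiB)  # 580.0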