Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Thunder seems to use way more memory when `litgpt.Config.parallel_residual=True` #1175

Open crcrpar opened 1 month ago

crcrpar commented 1 month ago

🐛 Bug

As input sequences get longer, Thunder tends to use more memory than eager and torch.compile.

Let's take litgpt's stablecode-completion-alpha-3b as an example; its sequence length (Config.block_size) is 16384.

As the following table shows, Thunder's memory consumption grows faster with sequence length (memory in GB):

| Sequence length | Thunder | torch.compile | Diff |
| ---: | ---: | ---: | ---: |
| 16384 | 77.02 | 71.73 | 5.29 |
| 12288 | 61.90 | 57.94 | 3.96 |
| 8192 | 46.85 | 44.20 | 2.65 |
| 4096 | 31.78 | 30.46 | 1.32 |
| 2048 | 24.28 | 23.63 | 0.65 |
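
For reference, the gap is almost linear in sequence length. A quick check from the table values (a minimal sketch; same units as the table):

# Per-token overhead of Thunder vs. torch.compile, computed from the table above.
# The per-1k-token difference is roughly constant (~0.33 in the table's units).
rows = [(16384, 5.29), (12288, 3.96), (8192, 2.65), (4096, 1.32), (2048, 0.65)]
for block_size, diff in rows:
    print(f"block_size={block_size:6d}  diff={diff:5.2f}  per 1k tokens={diff / block_size * 1024:.3f}")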

To Reproduce

Apply a diff like the one below and run commands such as:

python thunder/benchmarks/benchmark_litgpt.py --model_name stablecode-completion-alpha-3b --warmup_iters 0 --max_iters 3 --compile eager --dump_memory_snapshot false --block_size 2048
@@ -227,6 +269,7 @@ class Benchmark_litGPT:
         fsdp_bucket_params: float | None = None,
         checkpoint_activations: bool = False,
         n_layers: int | None = None,
+        block_size: int | None = None,
         profiler_start: int = 15,
         profiler_stop: int = 15,
         skip_data_sync: bool = False,
@@ -360,6 +403,8 @@ class Benchmark_litGPT:

         if n_layers is not None:
             self.config.n_layer = n_layers
+        if block_size is not None:
+            self.config.block_size = block_size

         # Initialize the model
         t0 = time.perf_counter()

Environment

pjnl-20240919

nvMelissa commented 1 month ago

Add to Q4 planning

crcrpar commented 1 month ago

The cause seems to be parallel_residual=True, as discussed in https://github.com/Lightning-AI/lightning-thunder/issues/246#issuecomment-2302121789
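
For context, parallel_residual changes how litgpt's Block combines the attention and MLP branches. A simplified paraphrase (a sketch, not the exact litgpt code):

# Simplified paraphrase of litgpt.model.Block.forward (not the exact code).
def block_forward(x, cos, sin, *, norm_1, norm_2, attn, mlp, parallel_residual):
    attn_out = attn(norm_1(x), cos, sin)
    if parallel_residual:
        # Both branches read a normalization of x and are summed with x in one step.
        return attn_out + mlp(norm_2(x)) + x
    # Sequential: the MLP consumes the attention residual.
    h = attn_out + x
    return mlp(norm_2(h)) + h

The traces below show how this choice changes what Thunder saves for backward.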

crcrpar commented 1 month ago

Script to run litgpt.model.Block with the "stablecode-completion-alpha-3b" config, whose parallel_residual defaults to True.

import argparse
import gc

import torch
from litgpt import Config
from litgpt.model import Block

import thunder

def init_model(config: Config, compiler: str, device: torch.device, dtype: torch.dtype) -> torch.nn.Module:
    model = Block(config, 0).to(device=device, dtype=dtype)
    print(model)
    match compiler:
        case "eager":
            return model
        case "thunder":
            return thunder.jit(model)
        case "torch.compile":
            return torch.compile(model)

def print_memory_stats(header):
    stats = torch.cuda.memory_stats()
    print("{}| current active: {:.3f}, allocated peak: {:.3f}, current allocated: {:.3f}".format(
        header,
        stats["active_bytes.all.current"] / 1e9,
        stats["allocated_bytes.all.peak"] / 1e9,
        stats["allocated_bytes.all.current"] / 1e9,
    ))

def get_batch(
    config: Config,
    device: torch.device,
    dtype: torch.dtype,
    args: argparse.Namespace,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    with device:
        return (
            torch.randn(size=(args.batch_size, config.block_size, config.n_embd), dtype=dtype, requires_grad=True),
            torch.randn(size=(config.block_size, config.rope_n_elem), dtype=dtype),
            torch.randn(size=(config.block_size, config.rope_n_elem), dtype=dtype),
        )

def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--n-layer",
        "-N",
        type=int,
        default=8,
        help="number of layers to fit the model on RTX Ada 6000. Default value is 32.",
    )
    parser.add_argument("--batch-size", "-B", type=int, default=1, help="Batch size.")
    parser.add_argument("--block-size", "-S", type=int, default=16384, help="Sequence length.")
    parser.add_argument(
        "--compiler",
        "-C",
        type=str,
        default="eager",
        help="Deep Learning Compiler to use.",
        choices=("eager", "torch.compile", "thunder"),
    )
    parser.add_argument("--n-iter", "-I", type=int, default=3, help="Number of iterations")
    parser.add_argument(
        "--model-name",
        "-M",
        type=str,
        default="stablecode-completion-alpha-3b",
        choices=(
            "stablecode-completion-alpha-3b",
            "Llama-3-8B",
        ),
    )
    parser.add_argument("--mem-snapshot", action="store_true", default=False)
    parser.add_argument("--dtype", default="bfloat16", choices=("float32", "bfloat16"))
    parser.add_argument("--dump-traces", action="store_true", default=False)
    parser.add_argument("--disable-parallel-residual", action="store_true", default=False)
    args = parser.parse_args()

    if args.mem_snapshot:
        torch.cuda.memory._record_memory_history()

    print("*" * 80)
    print(f"* {args.n_layer=}, {args.block_size=}, {args.batch_size=}, {args.compiler=}, {args.dtype=}")
    print("*" * 80)

    device = torch.device("cuda")
    dtype = getattr(torch, args.dtype)

    model_name = args.model_name
    config = Config.from_name(model_name)
    config.n_layer = args.n_layer
    config.block_size = args.block_size
    config.parallel_residual = not args.disable_parallel_residual

    print(f"#####\n{config}")

    model = init_model(config, args.compiler, device, dtype)
    print_memory_stats("model")
    optimizer = torch.optim.AdamW(model.parameters())

    print_memory_stats("model and data")
    for i in range(args.n_iter):
        optimizer.zero_grad()
        print(f"  Iter: {i + 1}")

        x, cos, sin = get_batch(config, device, dtype, args)
        out = model(x, cos, sin)
        print_memory_stats("    model, data, and forward results")

        loss = out.mean()
        loss.backward()
        print_memory_stats("    model, data, forward, and backward results")

        optimizer.step()
        print_memory_stats("    model, data, forward, backward, and optimizer results")

    if args.compiler == "thunder":
        compile_data = thunder.compile_data(model)
        print(f"Used executors: {[e.name for e in compile_data.executors_list]}")

    if args.compiler == "thunder" and args.dump_traces:
        from thunder.examine.memory_caculation import get_alloc_memory
        extrace = thunder.last_traces(model)[-1]
        file_name = f"trace_block_of_{model_name}_{args.compiler}_{args.n_layer}.py"
        if not config.parallel_residual:
            file_name = f"trace_block_of_{model_name}_{args.compiler}_{args.n_layer}_no_parallel_residual.py"
        with open(file_name, "w") as f:
            f.write(f"{extrace}\n")
            f.write(f"{thunder.last_backward_traces(model)[-1]}\n")

    del x, model, optimizer
    gc.collect()
    print_memory_stats("After del'ing data, model, and optimizer")
    file_name = f"block_of_{model_name}_{args.compiler}_{args.n_layer}.pickle"
    if args.mem_snapshot:
        torch.cuda.memory._dump_snapshot(file_name)
        print(f"Saving snapshot into {file_name}...")
        torch.cuda.memory._record_memory_history(enabled=None)

if __name__ == "__main__":
    main()
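
Example invocations to compare the two settings (the file name repro_block.py is just a placeholder):

python repro_block.py --compiler thunder --block-size 16384
python repro_block.py --compiler thunder --block-size 16384 --disable-parallel-residual
python repro_block.py --compiler torch.compile --block-size 16384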

Trace with parallel_residual=True:

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(x, cos, sin, t_attn_attn_bias, t_attn_attn_weight, t_attn_proj_bias, t_attn_proj_weight, t_mlp_fc_bias, t_mlp_fc_weight, t_mlp_proj_bias, t_mlp_proj_weight, t_norm_1_bias, t_norm_1_weight, t_norm_2_bias, t_norm_2_weight):
  # x: "cuda:0 bf16[1, 16384, 2560]"
  # cos: "cuda:0 bf16[16384, 20]"
  # sin: "cuda:0 bf16[16384, 20]"
  # t_attn_attn_bias: "cuda:0 bf16[7680]"
  # t_attn_attn_weight: "cuda:0 bf16[7680, 2560]"
  # t_attn_proj_bias: "cuda:0 bf16[2560]"
  # t_attn_proj_weight: "cuda:0 bf16[2560, 2560]"
  # t_mlp_fc_bias: "cuda:0 bf16[10240]"
  # t_mlp_fc_weight: "cuda:0 bf16[10240, 2560]"
  # t_mlp_proj_bias: "cuda:0 bf16[2560]"
  # t_mlp_proj_weight: "cuda:0 bf16[2560, 10240]"
  # t_norm_1_bias: "cuda:0 bf16[2560]"
  # t_norm_1_weight: "cuda:0 bf16[2560]"
  # t_norm_2_bias: "cuda:0 bf16[2560]"
  # t_norm_2_weight: "cuda:0 bf16[2560]"
  [t4, t8, t20, t118] = nvFusion0(x, t_norm_1_weight, t_norm_1_bias, t_norm_2_weight, t_norm_2_bias)
    # t0 = prims.convert_element_type(x, dtypes.float32)  # t0: "cuda:0 f32[1, 16384, 2560]"
    # (t3, t4) = prims.var_mean(t0, (2,), correction=0)
    # t5 = prims.broadcast_in_dim(t3, [1, 16384, 1], [0, 1])  # t5: "cuda:0 f32[1, 16384, 1]"
    # t6 = prims.broadcast_in_dim(t4, [1, 16384, 1], [0, 1])  # t6: "cuda:0 f32[1, 16384, 1]"
    # t7 = prims.add(t5, 1e-05)  # t7: "cuda:0 f32[1, 16384, 1]"
    # t9 = prims.broadcast_in_dim(t6, (1, 16384, 2560), (0, 1, 2))  # t9: "cuda:0 f32[1, 16384, 2560]"
    # t8 = prims.rsqrt(t7)  # t8: "cuda:0 f32[1, 16384, 1]"
    # t11 = prims.sub(t0, t9)  # t11: "cuda:0 f32[1, 16384, 2560]"
    # t12 = prims.broadcast_in_dim(t8, (1, 16384, 2560), (0, 1, 2))  # t12: "cuda:0 f32[1, 16384, 2560]"
    # t13 = prims.mul(t11, t12)  # t13: "cuda:0 f32[1, 16384, 2560]"
    # t14 = prims.broadcast_in_dim(t_norm_1_weight, (1, 16384, 2560), (2,))  # t14: "cuda:0 bf16[1, 16384, 2560]"
    # t15 = prims.convert_element_type(t14, dtypes.float32)  # t15: "cuda:0 f32[1, 16384, 2560]"
    # t16 = prims.mul(t13, t15)  # t16: "cuda:0 f32[1, 16384, 2560]"
    # t17 = prims.broadcast_in_dim(t_norm_1_bias, (1, 16384, 2560), (2,))  # t17: "cuda:0 bf16[1, 16384, 2560]"
    # t18 = prims.convert_element_type(t17, dtypes.float32)  # t18: "cuda:0 f32[1, 16384, 2560]"
    # t19 = prims.add(t16, t18)  # t19: "cuda:0 f32[1, 16384, 2560]"
    # t20 = prims.convert_element_type(t19, dtypes.bfloat16)  # t20: "cuda:0 bf16[1, 16384, 2560]"
    # t112 = prims.broadcast_in_dim(t_norm_2_weight, (1, 16384, 2560), (2,))  # t112: "cuda:0 bf16[1, 16384, 2560]"
    # t113 = prims.convert_element_type(t112, dtypes.float32)  # t113: "cuda:0 f32[1, 16384, 2560]"
    # t114 = prims.mul(t13, t113)  # t114: "cuda:0 f32[1, 16384, 2560]"
    # t115 = prims.broadcast_in_dim(t_norm_2_bias, (1, 16384, 2560), (2,))  # t115: "cuda:0 bf16[1, 16384, 2560]"
    # t116 = prims.convert_element_type(t115, dtypes.float32)  # t116: "cuda:0 f32[1, 16384, 2560]"
    # t117 = prims.add(t114, t116)  # t117: "cuda:0 f32[1, 16384, 2560]"
    # t118 = prims.convert_element_type(t117, dtypes.bfloat16)  # t118: "cuda:0 bf16[1, 16384, 2560]"
  t21 = torch.nn.functional.linear(t20, t_attn_attn_weight, t_attn_attn_bias)  # t21: "cuda:0 bf16[1, 16384, 7680]"
    # t21 = ltorch.linear(t20, t_attn_attn_weight, t_attn_attn_bias)  # t21: "cuda:0 bf16[1, 16384, 7680]"
      # t21 = prims.linear(t20, t_attn_attn_weight, t_attn_attn_bias)  # t21: "cuda:0 bf16[1, 16384, 7680]"
  t119 = torch.nn.functional.linear(t118, t_mlp_fc_weight, t_mlp_fc_bias)  # t119: "cuda:0 bf16[1, 16384, 10240]"
    # t119 = ltorch.linear(t118, t_mlp_fc_weight, t_mlp_fc_bias)  # t119: "cuda:0 bf16[1, 16384, 10240]"
      # t119 = prims.linear(t118, t_mlp_fc_weight, t_mlp_fc_bias)  # t119: "cuda:0 bf16[1, 16384, 10240]"
  [t39, t84, t87, t121, t130, t135] = TorchCompile0(t21, cos, sin, t119)
    # t22 = prims.reshape(t21, (1, 16384, 32, 3, 80))  # t22: "cuda:0 bf16[1, 16384, 32, 3, 80]"
    # t23 = prims.transpose(t22, (0, 2, 3, 1, 4))  # t23: "cuda:0 bf16[1, 32, 3, 16384, 80]"
    # (t24, t25, t26) = ltorch.split(t23, (1, 1, 1), 2)
      # t24 = prims.slice_prim(t23, [0, 0, 0, 0, 0], [1, 32, 1, 16384, 80], [1, 1, 1, 1, 1])  # t24: "cuda:0 bf16[1, 32, 1, 16384, 80]"
      # t25 = prims.slice_prim(t23, [0, 0, 1, 0, 0], [1, 32, 2, 16384, 80], [1, 1, 1, 1, 1])  # t25: "cuda:0 bf16[1, 32, 1, 16384, 80]"
      # t26 = prims.slice_prim(t23, [0, 0, 2, 0, 0], [1, 32, 3, 16384, 80], [1, 1, 1, 1, 1])  # t26: "cuda:0 bf16[1, 32, 1, 16384, 80]"
    # t27 = prims.reshape(t24, (1, 32, 16384, 80))  # t27: "cuda:0 bf16[1, 32, 16384, 80]"
    # t33 = prims.reshape(t25, (1, 32, 16384, 80))  # t33: "cuda:0 bf16[1, 32, 16384, 80]"
    # t39 = prims.reshape(t26, (1, 32, 16384, 80))  # t39: "cuda:0 bf16[1, 32, 16384, 80]"
    # t40 = prims.slice_prim(t27, [0, 0, 0, 0], [1, 32, 16384, 20], [1, 1, 1, 1])  # t40: "cuda:0 bf16[1, 32, 16384, 20]"
    # t41 = prims.slice_prim(t40, [0, 0, 0, 0], [1, 32, 16384, 10], [1, 1, 1, 1])  # t41: "cuda:0 bf16[1, 32, 16384, 10]"
    # t42 = prims.slice_prim(t40, [0, 0, 0, 10], [1, 32, 16384, 20], [1, 1, 1, 1])  # t42: "cuda:0 bf16[1, 32, 16384, 10]"
    # t43 = prims.convert_element_type(t42, dtypes.float32)  # t43: "cuda:0 f32[1, 32, 16384, 10]"
    # t44 = prims.neg(t43)  # t44: "cuda:0 f32[1, 32, 16384, 10]"
    # t45 = prims.convert_element_type(t44, dtypes.bfloat16)  # t45: "cuda:0 bf16[1, 32, 16384, 10]"
    # t46 = prims.cat([t45, t41], -1)  # t46: "cuda:0 bf16[1, 32, 16384, 20]"
    # t47 = prims.broadcast_in_dim(cos, (1, 32, 16384, 20), (2, 3))  # t47: "cuda:0 bf16[1, 32, 16384, 20]"
    # t48 = prims.convert_element_type(t40, dtypes.float32)  # t48: "cuda:0 f32[1, 32, 16384, 20]"
    # t49 = prims.convert_element_type(t47, dtypes.float32)  # t49: "cuda:0 f32[1, 32, 16384, 20]"
    # t50 = ltorch.mul(t48, t49)  # t50: "cuda:0 f32[1, 32, 16384, 20]"
      # t50 = prims.mul(t48, t49)  # t50: "cuda:0 f32[1, 32, 16384, 20]"
    # t51 = prims.convert_element_type(t50, dtypes.bfloat16)  # t51: "cuda:0 bf16[1, 32, 16384, 20]"
    # t52 = prims.broadcast_in_dim(sin, (1, 32, 16384, 20), (2, 3))  # t52: "cuda:0 bf16[1, 32, 16384, 20]"
    # t53 = prims.convert_element_type(t46, dtypes.float32)  # t53: "cuda:0 f32[1, 32, 16384, 20]"
    # t54 = prims.convert_element_type(t52, dtypes.float32)  # t54: "cuda:0 f32[1, 32, 16384, 20]"
    # t55 = ltorch.mul(t53, t54)  # t55: "cuda:0 f32[1, 32, 16384, 20]"
      # t55 = prims.mul(t53, t54)  # t55: "cuda:0 f32[1, 32, 16384, 20]"
    # t56 = prims.convert_element_type(t55, dtypes.bfloat16)  # t56: "cuda:0 bf16[1, 32, 16384, 20]"
    # t59 = ltorch.add(t50, t55, alpha=None)  # t59: "cuda:0 f32[1, 32, 16384, 20]"
      # t59 = prims.add(t50, t55)  # t59: "cuda:0 f32[1, 32, 16384, 20]"
    # t60 = prims.convert_element_type(t59, dtypes.bfloat16)  # t60: "cuda:0 bf16[1, 32, 16384, 20]"
    # t61 = prims.slice_prim(t33, [0, 0, 0, 0], [1, 32, 16384, 20], [1, 1, 1, 1])  # t61: "cuda:0 bf16[1, 32, 16384, 20]"
    # t62 = prims.slice_prim(t61, [0, 0, 0, 0], [1, 32, 16384, 10], [1, 1, 1, 1])  # t62: "cuda:0 bf16[1, 32, 16384, 10]"
    # t63 = prims.slice_prim(t61, [0, 0, 0, 10], [1, 32, 16384, 20], [1, 1, 1, 1])  # t63: "cuda:0 bf16[1, 32, 16384, 10]"
    # t64 = prims.convert_element_type(t63, dtypes.float32)  # t64: "cuda:0 f32[1, 32, 16384, 10]"
    # t65 = prims.neg(t64)  # t65: "cuda:0 f32[1, 32, 16384, 10]"
    # t66 = prims.convert_element_type(t65, dtypes.bfloat16)  # t66: "cuda:0 bf16[1, 32, 16384, 10]"
    # t68 = prims.cat([t66, t62], -1)  # t68: "cuda:0 bf16[1, 32, 16384, 20]"
    # t70 = prims.convert_element_type(t61, dtypes.float32)  # t70: "cuda:0 f32[1, 32, 16384, 20]"
    # t72 = ltorch.mul(t70, t49)  # t72: "cuda:0 f32[1, 32, 16384, 20]"
      # t72 = prims.mul(t70, t49)  # t72: "cuda:0 f32[1, 32, 16384, 20]"
    # t73 = prims.convert_element_type(t72, dtypes.bfloat16)  # t73: "cuda:0 bf16[1, 32, 16384, 20]"
    # t75 = prims.convert_element_type(t68, dtypes.float32)  # t75: "cuda:0 f32[1, 32, 16384, 20]"
    # t77 = ltorch.mul(t75, t54)  # t77: "cuda:0 f32[1, 32, 16384, 20]"
      # t77 = prims.mul(t75, t54)  # t77: "cuda:0 f32[1, 32, 16384, 20]"
    # t78 = prims.convert_element_type(t77, dtypes.bfloat16)  # t78: "cuda:0 bf16[1, 32, 16384, 20]"
    # t81 = ltorch.add(t72, t77, alpha=None)  # t81: "cuda:0 f32[1, 32, 16384, 20]"
      # t81 = prims.add(t72, t77)  # t81: "cuda:0 f32[1, 32, 16384, 20]"
    # t82 = prims.convert_element_type(t81, dtypes.bfloat16)  # t82: "cuda:0 bf16[1, 32, 16384, 20]"
    # t83 = prims.slice_prim(t27, [0, 0, 0, 20], [1, 32, 16384, 80], [1, 1, 1, 1])  # t83: "cuda:0 bf16[1, 32, 16384, 60]"
    # t84 = prims.cat([t60, t83], -1)  # t84: "cuda:0 bf16[1, 32, 16384, 80]"
    # t85 = prims.slice_prim(t33, [0, 0, 0, 20], [1, 32, 16384, 80], [1, 1, 1, 1])  # t85: "cuda:0 bf16[1, 32, 16384, 60]"
    # t87 = prims.cat([t82, t85], -1)  # t87: "cuda:0 bf16[1, 32, 16384, 80]"
    # t120 = prims.convert_element_type(t119, dtypes.float32)  # t120: "cuda:0 f32[1, 16384, 10240]"
    # t121 = ltorch.true_divide(t120, 1.4142135623730951)  # t121: "cuda:0 f32[1, 16384, 10240]"
      # t121 = prims.div(t120, 1.4142135623730951)  # t121: "cuda:0 f32[1, 16384, 10240]"
    # t122 = prims.convert_element_type(t121, dtypes.bfloat16)  # t122: "cuda:0 bf16[1, 16384, 10240]"
    # t124 = prims.erf(t121)  # t124: "cuda:0 f32[1, 16384, 10240]"
    # t125 = prims.convert_element_type(t124, dtypes.bfloat16)  # t125: "cuda:0 bf16[1, 16384, 10240]"
    # t127 = ltorch.mul(0.5, t124)  # t127: "cuda:0 f32[1, 16384, 10240]"
      # t127 = prims.mul(0.5, t124)  # t127: "cuda:0 f32[1, 16384, 10240]"
    # t128 = prims.convert_element_type(t127, dtypes.bfloat16)  # t128: "cuda:0 bf16[1, 16384, 10240]"
    # t130 = ltorch.add(0.5, t127, alpha=None)  # t130: "cuda:0 f32[1, 16384, 10240]"
      # t130 = prims.add(0.5, t127)  # t130: "cuda:0 f32[1, 16384, 10240]"
    # t131 = prims.convert_element_type(t130, dtypes.bfloat16)  # t131: "cuda:0 bf16[1, 16384, 10240]"
    # t134 = ltorch.mul(t120, t130)  # t134: "cuda:0 f32[1, 16384, 10240]"
      # t134 = prims.mul(t120, t130)  # t134: "cuda:0 f32[1, 16384, 10240]"
    # t135 = prims.convert_element_type(t134, dtypes.bfloat16)  # t135: "cuda:0 bf16[1, 16384, 10240]"
  del t21
  (t88, t89, t90, t91) = cudnn_sdpa_fwd(t84, t87, t39, None, 0.0, True, scale=0.11180339887498948)
  t136 = torch.nn.functional.linear(t135, t_mlp_proj_weight, t_mlp_proj_bias)  # t136: "cuda:0 bf16[1, 16384, 2560]"
    # t136 = ltorch.linear(t135, t_mlp_proj_weight, t_mlp_proj_bias)  # t136: "cuda:0 bf16[1, 16384, 2560]"
      # t136 = prims.linear(t135, t_mlp_proj_weight, t_mlp_proj_bias)  # t136: "cuda:0 bf16[1, 16384, 2560]"
  [t93] = nvFusion1(t88)
    # t92 = prims.transpose(t88, (0, 2, 1, 3))  # t92: "cuda:0 bf16[1, 16384, 32, 80]"
    # t93 = prims.reshape(t92, (1, 16384, 2560))  # t93: "cuda:0 bf16[1, 16384, 2560]"
  t94 = torch.nn.functional.linear(t93, t_attn_proj_weight, t_attn_proj_bias)  # t94: "cuda:0 bf16[1, 16384, 2560]"
    # t94 = ltorch.linear(t93, t_attn_proj_weight, t_attn_proj_bias)  # t94: "cuda:0 bf16[1, 16384, 2560]"
      # t94 = prims.linear(t93, t_attn_proj_weight, t_attn_proj_bias)  # t94: "cuda:0 bf16[1, 16384, 2560]"
  [t144] = nvFusion2(t136, t94, x)
    # t137 = prims.convert_element_type(t136, dtypes.float32)  # t137: "cuda:0 f32[1, 16384, 2560]"
    # t138 = prims.convert_element_type(t94, dtypes.float32)  # t138: "cuda:0 f32[1, 16384, 2560]"
    # t139 = prims.add(t137, t138)  # t139: "cuda:0 f32[1, 16384, 2560]"
    # t142 = prims.convert_element_type(x, dtypes.float32)  # t142: "cuda:0 f32[1, 16384, 2560]"
    # t143 = prims.add(t139, t142)  # t143: "cuda:0 f32[1, 16384, 2560]"
    # t144 = prims.convert_element_type(t143, dtypes.bfloat16)  # t144: "cuda:0 bf16[1, 16384, 2560]"
  del t136, t94
  return {'output': t144, 'flat_args': [x, cos, sin, t_attn_attn_bias, t_attn_attn_weight, t_attn_proj_bias, t_attn_proj_weight, t_mlp_fc_bias, t_mlp_fc_weight, t_mlp_proj_bias, t_mlp_proj_weight, t_norm_1_bias, t_norm_1_weight, t_norm_2_bias, t_norm_2_weight], 'flat_output': (t144,)}, ((cos, sin, t118, t119, t121, t130, t135, t20, t39, t4, t8, t84, t87, t88, t89, t90, t91, t93, t_attn_attn_weight, t_attn_proj_weight, t_mlp_fc_weight, t_mlp_proj_weight, t_norm_1_weight, t_norm_2_weight, x), ())

Trace with parallel_residual=False:

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(x, cos, sin, t_attn_attn_bias, t_attn_attn_weight, t_attn_proj_bias, t_attn_proj_weight, t_mlp_fc_bias, t_mlp_fc_weight, t_mlp_proj_bias, t_mlp_proj_weight, t_norm_1_bias, t_norm_1_weight, t_norm_2_bias, t_norm_2_weight):
  # x: "cuda:0 bf16[1, 16384, 2560]"
  # cos: "cuda:0 bf16[16384, 20]"
  # sin: "cuda:0 bf16[16384, 20]"
  # t_attn_attn_bias: "cuda:0 bf16[7680]"
  # t_attn_attn_weight: "cuda:0 bf16[7680, 2560]"
  # t_attn_proj_bias: "cuda:0 bf16[2560]"
  # t_attn_proj_weight: "cuda:0 bf16[2560, 2560]"
  # t_mlp_fc_bias: "cuda:0 bf16[10240]"
  # t_mlp_fc_weight: "cuda:0 bf16[10240, 2560]"
  # t_mlp_proj_bias: "cuda:0 bf16[2560]"
  # t_mlp_proj_weight: "cuda:0 bf16[2560, 10240]"
  # t_norm_1_bias: "cuda:0 bf16[2560]"
  # t_norm_1_weight: "cuda:0 bf16[2560]"
  # t_norm_2_bias: "cuda:0 bf16[2560]"
  # t_norm_2_weight: "cuda:0 bf16[2560]"
  [t4, t8, t20] = nvFusion0(x, t_norm_1_weight, t_norm_1_bias)
    # t0 = prims.convert_element_type(x, dtypes.float32)  # t0: "cuda:0 f32[1, 16384, 2560]"
    # (t3, t4) = prims.var_mean(t0, (2,), correction=0)
    # t5 = prims.broadcast_in_dim(t3, [1, 16384, 1], [0, 1])  # t5: "cuda:0 f32[1, 16384, 1]"
    # t6 = prims.broadcast_in_dim(t4, [1, 16384, 1], [0, 1])  # t6: "cuda:0 f32[1, 16384, 1]"
    # t7 = prims.add(t5, 1e-05)  # t7: "cuda:0 f32[1, 16384, 1]"
    # t9 = prims.broadcast_in_dim(t6, (1, 16384, 2560), (0, 1, 2))  # t9: "cuda:0 f32[1, 16384, 2560]"
    # t8 = prims.rsqrt(t7)  # t8: "cuda:0 f32[1, 16384, 1]"
    # t11 = prims.sub(t0, t9)  # t11: "cuda:0 f32[1, 16384, 2560]"
    # t12 = prims.broadcast_in_dim(t8, (1, 16384, 2560), (0, 1, 2))  # t12: "cuda:0 f32[1, 16384, 2560]"
    # t13 = prims.mul(t11, t12)  # t13: "cuda:0 f32[1, 16384, 2560]"
    # t14 = prims.broadcast_in_dim(t_norm_1_weight, (1, 16384, 2560), (2,))  # t14: "cuda:0 bf16[1, 16384, 2560]"
    # t15 = prims.convert_element_type(t14, dtypes.float32)  # t15: "cuda:0 f32[1, 16384, 2560]"
    # t16 = prims.mul(t13, t15)  # t16: "cuda:0 f32[1, 16384, 2560]"
    # t17 = prims.broadcast_in_dim(t_norm_1_bias, (1, 16384, 2560), (2,))  # t17: "cuda:0 bf16[1, 16384, 2560]"
    # t18 = prims.convert_element_type(t17, dtypes.float32)  # t18: "cuda:0 f32[1, 16384, 2560]"
    # t19 = prims.add(t16, t18)  # t19: "cuda:0 f32[1, 16384, 2560]"
    # t20 = prims.convert_element_type(t19, dtypes.bfloat16)  # t20: "cuda:0 bf16[1, 16384, 2560]"
  t21 = torch.nn.functional.linear(t20, t_attn_attn_weight, t_attn_attn_bias)  # t21: "cuda:0 bf16[1, 16384, 7680]"
    # t21 = ltorch.linear(t20, t_attn_attn_weight, t_attn_attn_bias)  # t21: "cuda:0 bf16[1, 16384, 7680]"
      # t21 = prims.linear(t20, t_attn_attn_weight, t_attn_attn_bias)  # t21: "cuda:0 bf16[1, 16384, 7680]"
  [t39, t84, t87] = TorchCompile0(t21, cos, sin)
    # t22 = prims.reshape(t21, (1, 16384, 32, 3, 80))  # t22: "cuda:0 bf16[1, 16384, 32, 3, 80]"
    # t23 = prims.transpose(t22, (0, 2, 3, 1, 4))  # t23: "cuda:0 bf16[1, 32, 3, 16384, 80]"
    # (t24, t25, t26) = ltorch.split(t23, (1, 1, 1), 2)
      # t24 = prims.slice_prim(t23, [0, 0, 0, 0, 0], [1, 32, 1, 16384, 80], [1, 1, 1, 1, 1])  # t24: "cuda:0 bf16[1, 32, 1, 16384, 80]"
      # t25 = prims.slice_prim(t23, [0, 0, 1, 0, 0], [1, 32, 2, 16384, 80], [1, 1, 1, 1, 1])  # t25: "cuda:0 bf16[1, 32, 1, 16384, 80]"
      # t26 = prims.slice_prim(t23, [0, 0, 2, 0, 0], [1, 32, 3, 16384, 80], [1, 1, 1, 1, 1])  # t26: "cuda:0 bf16[1, 32, 1, 16384, 80]"
    # t27 = prims.reshape(t24, (1, 32, 16384, 80))  # t27: "cuda:0 bf16[1, 32, 16384, 80]"
    # t33 = prims.reshape(t25, (1, 32, 16384, 80))  # t33: "cuda:0 bf16[1, 32, 16384, 80]"
    # t39 = prims.reshape(t26, (1, 32, 16384, 80))  # t39: "cuda:0 bf16[1, 32, 16384, 80]"
    # t40 = prims.slice_prim(t27, [0, 0, 0, 0], [1, 32, 16384, 20], [1, 1, 1, 1])  # t40: "cuda:0 bf16[1, 32, 16384, 20]"
    # t41 = prims.slice_prim(t40, [0, 0, 0, 0], [1, 32, 16384, 10], [1, 1, 1, 1])  # t41: "cuda:0 bf16[1, 32, 16384, 10]"
    # t42 = prims.slice_prim(t40, [0, 0, 0, 10], [1, 32, 16384, 20], [1, 1, 1, 1])  # t42: "cuda:0 bf16[1, 32, 16384, 10]"
    # t43 = prims.convert_element_type(t42, dtypes.float32)  # t43: "cuda:0 f32[1, 32, 16384, 10]"
    # t44 = prims.neg(t43)  # t44: "cuda:0 f32[1, 32, 16384, 10]"
    # t45 = prims.convert_element_type(t44, dtypes.bfloat16)  # t45: "cuda:0 bf16[1, 32, 16384, 10]"
    # t46 = prims.cat([t45, t41], -1)  # t46: "cuda:0 bf16[1, 32, 16384, 20]"
    # t47 = prims.broadcast_in_dim(cos, (1, 32, 16384, 20), (2, 3))  # t47: "cuda:0 bf16[1, 32, 16384, 20]"
    # t48 = prims.convert_element_type(t40, dtypes.float32)  # t48: "cuda:0 f32[1, 32, 16384, 20]"
    # t49 = prims.convert_element_type(t47, dtypes.float32)  # t49: "cuda:0 f32[1, 32, 16384, 20]"
    # t50 = ltorch.mul(t48, t49)  # t50: "cuda:0 f32[1, 32, 16384, 20]"
      # t50 = prims.mul(t48, t49)  # t50: "cuda:0 f32[1, 32, 16384, 20]"
    # t51 = prims.convert_element_type(t50, dtypes.bfloat16)  # t51: "cuda:0 bf16[1, 32, 16384, 20]"
    # t52 = prims.broadcast_in_dim(sin, (1, 32, 16384, 20), (2, 3))  # t52: "cuda:0 bf16[1, 32, 16384, 20]"
    # t53 = prims.convert_element_type(t46, dtypes.float32)  # t53: "cuda:0 f32[1, 32, 16384, 20]"
    # t54 = prims.convert_element_type(t52, dtypes.float32)  # t54: "cuda:0 f32[1, 32, 16384, 20]"
    # t55 = ltorch.mul(t53, t54)  # t55: "cuda:0 f32[1, 32, 16384, 20]"
      # t55 = prims.mul(t53, t54)  # t55: "cuda:0 f32[1, 32, 16384, 20]"
    # t56 = prims.convert_element_type(t55, dtypes.bfloat16)  # t56: "cuda:0 bf16[1, 32, 16384, 20]"
    # t59 = ltorch.add(t50, t55, alpha=None)  # t59: "cuda:0 f32[1, 32, 16384, 20]"
      # t59 = prims.add(t50, t55)  # t59: "cuda:0 f32[1, 32, 16384, 20]"
    # t60 = prims.convert_element_type(t59, dtypes.bfloat16)  # t60: "cuda:0 bf16[1, 32, 16384, 20]"
    # t61 = prims.slice_prim(t33, [0, 0, 0, 0], [1, 32, 16384, 20], [1, 1, 1, 1])  # t61: "cuda:0 bf16[1, 32, 16384, 20]"
    # t62 = prims.slice_prim(t61, [0, 0, 0, 0], [1, 32, 16384, 10], [1, 1, 1, 1])  # t62: "cuda:0 bf16[1, 32, 16384, 10]"
    # t63 = prims.slice_prim(t61, [0, 0, 0, 10], [1, 32, 16384, 20], [1, 1, 1, 1])  # t63: "cuda:0 bf16[1, 32, 16384, 10]"
    # t64 = prims.convert_element_type(t63, dtypes.float32)  # t64: "cuda:0 f32[1, 32, 16384, 10]"
    # t65 = prims.neg(t64)  # t65: "cuda:0 f32[1, 32, 16384, 10]"
    # t66 = prims.convert_element_type(t65, dtypes.bfloat16)  # t66: "cuda:0 bf16[1, 32, 16384, 10]"
    # t68 = prims.cat([t66, t62], -1)  # t68: "cuda:0 bf16[1, 32, 16384, 20]"
    # t70 = prims.convert_element_type(t61, dtypes.float32)  # t70: "cuda:0 f32[1, 32, 16384, 20]"
    # t72 = ltorch.mul(t70, t49)  # t72: "cuda:0 f32[1, 32, 16384, 20]"
      # t72 = prims.mul(t70, t49)  # t72: "cuda:0 f32[1, 32, 16384, 20]"
    # t73 = prims.convert_element_type(t72, dtypes.bfloat16)  # t73: "cuda:0 bf16[1, 32, 16384, 20]"
    # t75 = prims.convert_element_type(t68, dtypes.float32)  # t75: "cuda:0 f32[1, 32, 16384, 20]"
    # t77 = ltorch.mul(t75, t54)  # t77: "cuda:0 f32[1, 32, 16384, 20]"
      # t77 = prims.mul(t75, t54)  # t77: "cuda:0 f32[1, 32, 16384, 20]"
    # t78 = prims.convert_element_type(t77, dtypes.bfloat16)  # t78: "cuda:0 bf16[1, 32, 16384, 20]"
    # t81 = ltorch.add(t72, t77, alpha=None)  # t81: "cuda:0 f32[1, 32, 16384, 20]"
      # t81 = prims.add(t72, t77)  # t81: "cuda:0 f32[1, 32, 16384, 20]"
    # t82 = prims.convert_element_type(t81, dtypes.bfloat16)  # t82: "cuda:0 bf16[1, 32, 16384, 20]"
    # t83 = prims.slice_prim(t27, [0, 0, 0, 20], [1, 32, 16384, 80], [1, 1, 1, 1])  # t83: "cuda:0 bf16[1, 32, 16384, 60]"
    # t84 = prims.cat([t60, t83], -1)  # t84: "cuda:0 bf16[1, 32, 16384, 80]"
    # t85 = prims.slice_prim(t33, [0, 0, 0, 20], [1, 32, 16384, 80], [1, 1, 1, 1])  # t85: "cuda:0 bf16[1, 32, 16384, 60]"
    # t87 = prims.cat([t82, t85], -1)  # t87: "cuda:0 bf16[1, 32, 16384, 80]"
  del t21
  (t88, t89, t90, t91) = cudnn_sdpa_fwd(t84, t87, t39, None, 0.0, True, scale=0.11180339887498948)
  [t93] = nvFusion1(t88)
    # t92 = prims.transpose(t88, (0, 2, 1, 3))  # t92: "cuda:0 bf16[1, 16384, 32, 80]"
    # t93 = prims.reshape(t92, (1, 16384, 2560))  # t93: "cuda:0 bf16[1, 16384, 2560]"
  t94 = torch.nn.functional.linear(t93, t_attn_proj_weight, t_attn_proj_bias)  # t94: "cuda:0 bf16[1, 16384, 2560]"
    # t94 = ltorch.linear(t93, t_attn_proj_weight, t_attn_proj_bias)  # t94: "cuda:0 bf16[1, 16384, 2560]"
      # t94 = prims.linear(t93, t_attn_proj_weight, t_attn_proj_bias)  # t94: "cuda:0 bf16[1, 16384, 2560]"
  [t98, t105, t110, t122] = nvFusion2(t94, x, t_norm_2_weight, t_norm_2_bias)
    # t95 = prims.convert_element_type(t94, dtypes.float32)  # t95: "cuda:0 f32[1, 16384, 2560]"
    # t96 = prims.convert_element_type(x, dtypes.float32)  # t96: "cuda:0 f32[1, 16384, 2560]"
    # t97 = prims.add(t95, t96)  # t97: "cuda:0 f32[1, 16384, 2560]"
    # t98 = prims.convert_element_type(t97, dtypes.bfloat16)  # t98: "cuda:0 bf16[1, 16384, 2560]"
    # (t104, t105) = prims.var_mean(t97, (2,), correction=0)
    # t106 = prims.broadcast_in_dim(t104, [1, 16384, 1], [0, 1])  # t106: "cuda:0 f32[1, 16384, 1]"
    # t107 = prims.broadcast_in_dim(t105, [1, 16384, 1], [0, 1])  # t107: "cuda:0 f32[1, 16384, 1]"
    # t109 = prims.add(t106, 1e-05)  # t109: "cuda:0 f32[1, 16384, 1]"
    # t111 = prims.broadcast_in_dim(t107, (1, 16384, 2560), (0, 1, 2))  # t111: "cuda:0 f32[1, 16384, 2560]"
    # t110 = prims.rsqrt(t109)  # t110: "cuda:0 f32[1, 16384, 1]"
    # t113 = prims.sub(t97, t111)  # t113: "cuda:0 f32[1, 16384, 2560]"
    # t114 = prims.broadcast_in_dim(t110, (1, 16384, 2560), (0, 1, 2))  # t114: "cuda:0 f32[1, 16384, 2560]"
    # t115 = prims.mul(t113, t114)  # t115: "cuda:0 f32[1, 16384, 2560]"
    # t116 = prims.broadcast_in_dim(t_norm_2_weight, (1, 16384, 2560), (2,))  # t116: "cuda:0 bf16[1, 16384, 2560]"
    # t117 = prims.convert_element_type(t116, dtypes.float32)  # t117: "cuda:0 f32[1, 16384, 2560]"
    # t118 = prims.mul(t115, t117)  # t118: "cuda:0 f32[1, 16384, 2560]"
    # t119 = prims.broadcast_in_dim(t_norm_2_bias, (1, 16384, 2560), (2,))  # t119: "cuda:0 bf16[1, 16384, 2560]"
    # t120 = prims.convert_element_type(t119, dtypes.float32)  # t120: "cuda:0 f32[1, 16384, 2560]"
    # t121 = prims.add(t118, t120)  # t121: "cuda:0 f32[1, 16384, 2560]"
    # t122 = prims.convert_element_type(t121, dtypes.bfloat16)  # t122: "cuda:0 bf16[1, 16384, 2560]"
  t123 = torch.nn.functional.linear(t122, t_mlp_fc_weight, t_mlp_fc_bias)  # t123: "cuda:0 bf16[1, 16384, 10240]"
    # t123 = ltorch.linear(t122, t_mlp_fc_weight, t_mlp_fc_bias)  # t123: "cuda:0 bf16[1, 16384, 10240]"
      # t123 = prims.linear(t122, t_mlp_fc_weight, t_mlp_fc_bias)  # t123: "cuda:0 bf16[1, 16384, 10240]"
  [t139] = nvFusion3(t123)
    # t124 = prims.convert_element_type(t123, dtypes.float32)  # t124: "cuda:0 f32[1, 16384, 10240]"
    # t125 = prims.div(t124, 1.4142135623730951)  # t125: "cuda:0 f32[1, 16384, 10240]"
    # t128 = prims.erf(t125)  # t128: "cuda:0 f32[1, 16384, 10240]"
    # t131 = prims.mul(0.5, t128)  # t131: "cuda:0 f32[1, 16384, 10240]"
    # t134 = prims.add(0.5, t131)  # t134: "cuda:0 f32[1, 16384, 10240]"
    # t138 = prims.mul(t124, t134)  # t138: "cuda:0 f32[1, 16384, 10240]"
    # t139 = prims.convert_element_type(t138, dtypes.bfloat16)  # t139: "cuda:0 bf16[1, 16384, 10240]"
  t140 = torch.nn.functional.linear(t139, t_mlp_proj_weight, t_mlp_proj_bias)  # t140: "cuda:0 bf16[1, 16384, 2560]"
    # t140 = ltorch.linear(t139, t_mlp_proj_weight, t_mlp_proj_bias)  # t140: "cuda:0 bf16[1, 16384, 2560]"
      # t140 = prims.linear(t139, t_mlp_proj_weight, t_mlp_proj_bias)  # t140: "cuda:0 bf16[1, 16384, 2560]"
  [t144] = nvFusion4(t98, t140)
    # t142 = prims.convert_element_type(t98, dtypes.float32)  # t142: "cuda:0 f32[1, 16384, 2560]"
    # t141 = prims.convert_element_type(t140, dtypes.float32)  # t141: "cuda:0 f32[1, 16384, 2560]"
    # t143 = prims.add(t141, t142)  # t143: "cuda:0 f32[1, 16384, 2560]"
    # t144 = prims.convert_element_type(t143, dtypes.bfloat16)  # t144: "cuda:0 bf16[1, 16384, 2560]"
  del t98, t140
  return {'output': t144, 'flat_args': [x, cos, sin, t_attn_attn_bias, t_attn_attn_weight, t_attn_proj_bias, t_attn_proj_weight, t_mlp_fc_bias, t_mlp_fc_weight, t_mlp_proj_bias, t_mlp_proj_weight, t_norm_1_bias, t_norm_1_weight, t_norm_2_bias, t_norm_2_weight], 'flat_output': (t144,)}, ((cos, sin, t105, t110, t122, t123, t139, t20, t39, t4, t8, t84, t87, t88, t89, t90, t91, t93, t94, t_attn_attn_weight, t_attn_proj_weight, t_mlp_fc_weight, t_mlp_proj_weight, t_norm_1_weight, t_norm_2_weight, x), ())

crcrpar commented 1 month ago

With parallel_residual=True, the forward trace saves these intermediate tensors for backward:

(cos, sin, t118, t119, t121, t130, t135, t20, t39, t4, t8, t84, t87, t88, t89, t90, t91, t93, t_attn_attn_weight, t_attn_proj_weight, t_mlp_fc_weight, t_mlp_proj_weight, t_norm_1_weight, t_norm_2_weight, x)

while with parallel_residual=False it saves:

(cos, sin, t105, t110, t122, t123, t139, t20, t39, t4, t8, t84, t87, t88, t89, t90, t91, t93, t94, t_attn_attn_weight, t_attn_proj_weight, t_mlp_fc_weight, t_mlp_proj_weight, t_norm_1_weight, t_norm_2_weight, x)

The differences are

(
    `t118: "cuda:0 bf16[1, 16384, 2560]"`,
    `t119: "cuda:0 bf16[1, 16384, 10240]"`,
    `t121: "cuda:0 f32[1, 16384, 10240]"`,
    `t130: "cuda:0 f32[1, 16384, 10240]"`,
    `t135: "cuda:0 bf16[1, 16384, 10240]"`,
)

vs


(
    `t105: "cuda:0 bf16[1, 16384]"`,
    `t110: "cuda:0 f32[1, 16384, 1]"`,
    `t122: "cuda:0 bf16[1, 16384, 2560]"`,
    `t123: "cuda:0 bf16[1, 16384, 10240]"`,
    `t139: "cuda:0 bf16[1, 16384, 10240]"`,
)
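
A rough per-block estimate of the saved-tensor sizes, using only the shapes and dtypes listed above (a sketch that ignores allocator rounding and the tensors common to both lists):

# Bytes saved for backward per Block, from the shapes/dtypes above.
import math

BF16, F32 = 2, 4  # bytes per element

def nbytes(shape, itemsize):
    return math.prod(shape) * itemsize

# parallel_residual=True: t118, t119, t121, t130, t135
parallel = [((1, 16384, 2560), BF16), ((1, 16384, 10240), BF16),
            ((1, 16384, 10240), F32), ((1, 16384, 10240), F32),
            ((1, 16384, 10240), BF16)]
# parallel_residual=False: t105, t110, t122, t123, t139
sequential = [((1, 16384), BF16), ((1, 16384, 1), F32),
              ((1, 16384, 2560), BF16), ((1, 16384, 10240), BF16),
              ((1, 16384, 10240), BF16)]

p = sum(nbytes(shape, size) for shape, size in parallel)
q = sum(nbytes(shape, size) for shape, size in sequential)
print(f"parallel_residual=True : {p / 1e9:.2f} GB")        # ~2.10
print(f"parallel_residual=False: {q / 1e9:.2f} GB")        # ~0.76
print(f"difference per block   : {(p - q) / 1e9:.2f} GB")  # ~1.34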
crcrpar commented 1 month ago

https://gist.github.com/crcrpar/ce52789c933ca7013049c6eb1ba06366 has the AOT forward and backward traces. The backward's arguments are as follows:

def forward(
    self,
    primals_1: "bf16[2560][1]cuda:0",
    primals_3: "bf16[1, 16384, 2560][41943040, 2560, 1]cuda:0",
    primals_6: "bf16[16384, 20][20, 1]cuda:0",
    primals_7: "bf16[16384, 20][20, 1]cuda:0",
    primals_10: "bf16[2560][1]cuda:0",
    getitem_1: "f32[1, 16384, 1][16384, 1, 1]cuda:0",
    rsqrt: "f32[1, 16384, 1][16384, 1, 1]cuda:0",
    view: "bf16[16384, 2560][2560, 1]cuda:0",
    view_5: "bf16[1, 32, 16384, 80][125829120, 240, 7680, 1]cuda:0",
    cat_2: "bf16[1, 32, 16384, 80][41943040, 1310720, 80, 1]cuda:0",
    cat_3: "bf16[1, 32, 16384, 80][41943040, 1310720, 80, 1]cuda:0",
    getitem_5: "bf16[1, 32, 16384, 80][41943040, 1310720, 80, 1]cuda:0",
    getitem_6: "f32[1, 32, 16384][524288, 16384, 1]cuda:0",
    getitem_11: "i64[][]cuda:0",
    getitem_12: "i64[][]cuda:0",
    view_7: "bf16[16384, 2560][2560, 1]cuda:0",
    view_9: "bf16[16384, 2560][2560, 1]cuda:0",
    addmm_2: "bf16[16384, 10240][10240, 1]cuda:0",
    view_11: "bf16[16384, 10240][10240, 1]cuda:0",
    permute_6: "bf16[2560, 10240][10240, 1]cuda:0",
    permute_10: "bf16[10240, 2560][2560, 1]cuda:0",
    permute_14: "bf16[2560, 2560][2560, 1]cuda:0",
    permute_20: "bf16[7680, 2560][2560, 1]cuda:0",
    tangents_1: "bf16[1, 16384, 2560][41943040, 2560, 1]cuda:0",
):
    ...

getitem_1 is the second output of var_mean. getitem_5, getitem_6, getitem_11, and getitem_12 are outputs of torch.ops.aten._scaled_dot_product_cudnn_attention.default.
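
Comparing just the [*, 10240]-shaped tensors each backward keeps alive (Thunder's t119/t121/t130/t135 vs. addmm_2 and view_11 in the AOT argument list above), a rough shape-only sketch:

# Per-block MLP-width (10240) activations kept for backward, from the shapes above.
BF16, F32 = 2, 4   # bytes per element
n = 16384 * 10240  # elements per [*, 16384, 10240] / [16384, 10240] tensor

thunder_parallel = n * (BF16 + F32 + F32 + BF16)  # t119, t121, t130, t135
torch_compile = n * (BF16 + BF16)                 # addmm_2, view_11

print(f"thunder (parallel_residual=True): {thunder_parallel / 1e9:.2f} GB")  # ~2.01
print(f"torch.compile (AOT backward)    : {torch_compile / 1e9:.2f} GB")     # ~0.67

The ~1.3 GB gap per block corresponds to t121 and t130, the two float32 [1, 16384, 10240] intermediates that appear only in Thunder's saved list.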