The cause seems to be basically `parallel_residual=True`, as in https://github.com/Lightning-AI/lightning-thunder/issues/246#issuecomment-2302121789.
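For context, here is a minimal sketch of the two residual layouts (simplified from litgpt's `Block.forward`; the names and signatures below are illustrative, not the exact implementation):

```python
# Simplified sketch of a litgpt-style transformer block (illustrative only,
# not litgpt's actual code).

def block_parallel(x, cos, sin, norm_1, norm_2, attn, mlp):
    # parallel_residual=True: attention and MLP both read a normalization of
    # the same input x, and their outputs are added to x at the end.
    attn_out = attn(norm_1(x), cos, sin)
    mlp_out = mlp(norm_2(x))
    return mlp_out + attn_out + x


def block_sequential(x, cos, sin, norm_1, norm_2, attn, mlp):
    # parallel_residual=False: the MLP runs on the output of the attention
    # residual, so its intermediates appear later in the graph.
    x = attn(norm_1(x), cos, sin) + x
    return mlp(norm_2(x)) + x
```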
Below is a script that runs `litgpt.model.Block` with the config of "stablecode-completion-alpha-3b", whose `parallel_residual` defaults to `True`.
```python
import argparse
import gc

import torch
from litgpt import Config, GPT
from litgpt.model import Block
import thunder


def init_model(config: Config, compiler: str, device: torch.device, dtype: torch.dtype) -> GPT:
    model = Block(config, 0).to(device=device, dtype=dtype)
    print(model)
    match compiler:
        case "eager":
            return model
        case "thunder":
            return thunder.jit(model)
        case "torch.compile":
            return torch.compile(model)


def print_memory_stats(header):
    stats = torch.cuda.memory_stats()
    print("{}| current active: {:.3f}, allocated peak: {:.3f}, current allocated: {:.3f}".format(
        header,
        stats["active_bytes.all.current"] / 1e9,
        stats["allocated_bytes.all.peak"] / 1e9,
        stats["allocated_bytes.all.current"] / 1e9,
    ))


def get_batch(
    config: Config,
    device: torch.device,
    dtype: torch.dtype,
    args: argparse.Namespace,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    with device:
        return (
            torch.randn(size=(args.batch_size, config.block_size, config.n_embd), dtype=dtype, requires_grad=True),
            torch.randn(size=(config.block_size, config.rope_n_elem), dtype=dtype),
            torch.randn(size=(config.block_size, config.rope_n_elem), dtype=dtype),
        )


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--n-layer",
        "-N",
        type=int,
        default=8,
        help="Number of layers, reduced to fit the model on an RTX 6000 Ada.",
    )
    parser.add_argument("--batch-size", "-B", type=int, default=1, help="Batch size.")
    parser.add_argument("--block-size", "-S", type=int, default=16384, help="Sequence length.")
    parser.add_argument(
        "--compiler",
        "-C",
        type=str,
        default="eager",
        help="Deep learning compiler to use.",
        choices=("eager", "torch.compile", "thunder"),
    )
    parser.add_argument("--n-iter", "-I", type=int, default=3, help="Number of iterations.")
    parser.add_argument(
        "--model-name",
        "-M",
        type=str,
        default="stablecode-completion-alpha-3b",
        choices=(
            "stablecode-completion-alpha-3b",
            "Llama-3-8B",
        ),
    )
    parser.add_argument("--mem-snapshot", action="store_true", default=False)
    parser.add_argument("--dtype", default="bfloat16", choices=("float32", "bfloat16"))
    parser.add_argument("--dump-traces", action="store_true", default=False)
    parser.add_argument("--disable-parallel-residual", action="store_true", default=False)
    args = parser.parse_args()

    if args.mem_snapshot:
        torch.cuda.memory._record_memory_history()

    print("*" * 80)
    print(f"* {args.n_layer=}, {args.block_size=}, {args.batch_size=}, {args.compiler=}, {args.dtype=}")
    print("*" * 80)

    device = torch.device("cuda")
    dtype = getattr(torch, args.dtype)
    model_name = args.model_name
    config = Config.from_name(model_name)
    config.n_layer = args.n_layer
    config.block_size = args.block_size
    config.parallel_residual = not args.disable_parallel_residual
    print(f"#####\n{config}")

    model = init_model(config, args.compiler, device, dtype)
    print_memory_stats("model")
    optimizer = torch.optim.AdamW(model.parameters())
    print_memory_stats("model and data")

    for i in range(args.n_iter):
        optimizer.zero_grad()
        print(f" Iter: {i + 1}")
        x, cos, sin = get_batch(config, device, dtype, args)
        out = model(x, cos, sin)
        print_memory_stats(" model, data, and forward results")
        loss = out.mean()
        loss.backward()
        print_memory_stats(" model, data, forward, and backward results")
        optimizer.step()
        print_memory_stats(" model, data, forward, backward, and optimizer results")

    if args.compiler == "thunder":
        compile_data = thunder.compile_data(model)
        print(f"Used executors: {[e.name for e in compile_data.executors_list]}")
    if args.compiler == "thunder" and args.dump_traces:
        from thunder.examine.memory_caculation import get_alloc_memory

        extrace = thunder.last_traces(model)[-1]
        file_name = f"trace_block_of_{model_name}_{args.compiler}_{args.n_layer}.py"
        if not config.parallel_residual:
            file_name = f"trace_block_of_{model_name}_{args.compiler}_{args.n_layer}_no_parallel_residual.py"
        with open(file_name, "w") as f:
            f.write(f"{extrace}\n")
            f.write(f"{thunder.last_backward_traces(model)[-1]}\n")

    del x, model, optimizer
    gc.collect()
    print_memory_stats("After del'ing data, model, and optimizer")

    file_name = f"block_of_{model_name}_{args.compiler}_{args.n_layer}.pickle"
    if args.mem_snapshot:
        torch.cuda.memory._dump_snapshot(file_name)
        print(f"Saving snapshot into {file_name}...")
        torch.cuda.memory._record_memory_history(enabled=None)


if __name__ == "__main__":
    main()
```
Forward execution trace with `parallel_residual=True`:

```python
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def augmented_forward_fn(x, cos, sin, t_attn_attn_bias, t_attn_attn_weight, t_attn_proj_bias, t_attn_proj_weight, t_mlp_fc_bias, t_mlp_fc_weight, t_mlp_proj_bias, t_mlp_proj_weight, t_norm_1_bias, t_norm_1_weight, t_norm_2_bias, t_norm_2_weight):
# x: "cuda:0 bf16[1, 16384, 2560]"
# cos: "cuda:0 bf16[16384, 20]"
# sin: "cuda:0 bf16[16384, 20]"
# t_attn_attn_bias: "cuda:0 bf16[7680]"
# t_attn_attn_weight: "cuda:0 bf16[7680, 2560]"
# t_attn_proj_bias: "cuda:0 bf16[2560]"
# t_attn_proj_weight: "cuda:0 bf16[2560, 2560]"
# t_mlp_fc_bias: "cuda:0 bf16[10240]"
# t_mlp_fc_weight: "cuda:0 bf16[10240, 2560]"
# t_mlp_proj_bias: "cuda:0 bf16[2560]"
# t_mlp_proj_weight: "cuda:0 bf16[2560, 10240]"
# t_norm_1_bias: "cuda:0 bf16[2560]"
# t_norm_1_weight: "cuda:0 bf16[2560]"
# t_norm_2_bias: "cuda:0 bf16[2560]"
# t_norm_2_weight: "cuda:0 bf16[2560]"
[t4, t8, t20, t118] = nvFusion0(x, t_norm_1_weight, t_norm_1_bias, t_norm_2_weight, t_norm_2_bias)
# t0 = prims.convert_element_type(x, dtypes.float32) # t0: "cuda:0 f32[1, 16384, 2560]"
# (t3, t4) = prims.var_mean(t0, (2,), correction=0)
# t5 = prims.broadcast_in_dim(t3, [1, 16384, 1], [0, 1]) # t5: "cuda:0 f32[1, 16384, 1]"
# t6 = prims.broadcast_in_dim(t4, [1, 16384, 1], [0, 1]) # t6: "cuda:0 f32[1, 16384, 1]"
# t7 = prims.add(t5, 1e-05) # t7: "cuda:0 f32[1, 16384, 1]"
# t9 = prims.broadcast_in_dim(t6, (1, 16384, 2560), (0, 1, 2)) # t9: "cuda:0 f32[1, 16384, 2560]"
# t8 = prims.rsqrt(t7) # t8: "cuda:0 f32[1, 16384, 1]"
# t11 = prims.sub(t0, t9) # t11: "cuda:0 f32[1, 16384, 2560]"
# t12 = prims.broadcast_in_dim(t8, (1, 16384, 2560), (0, 1, 2)) # t12: "cuda:0 f32[1, 16384, 2560]"
# t13 = prims.mul(t11, t12) # t13: "cuda:0 f32[1, 16384, 2560]"
# t14 = prims.broadcast_in_dim(t_norm_1_weight, (1, 16384, 2560), (2,)) # t14: "cuda:0 bf16[1, 16384, 2560]"
# t15 = prims.convert_element_type(t14, dtypes.float32) # t15: "cuda:0 f32[1, 16384, 2560]"
# t16 = prims.mul(t13, t15) # t16: "cuda:0 f32[1, 16384, 2560]"
# t17 = prims.broadcast_in_dim(t_norm_1_bias, (1, 16384, 2560), (2,)) # t17: "cuda:0 bf16[1, 16384, 2560]"
# t18 = prims.convert_element_type(t17, dtypes.float32) # t18: "cuda:0 f32[1, 16384, 2560]"
# t19 = prims.add(t16, t18) # t19: "cuda:0 f32[1, 16384, 2560]"
# t20 = prims.convert_element_type(t19, dtypes.bfloat16) # t20: "cuda:0 bf16[1, 16384, 2560]"
# t112 = prims.broadcast_in_dim(t_norm_2_weight, (1, 16384, 2560), (2,)) # t112: "cuda:0 bf16[1, 16384, 2560]"
# t113 = prims.convert_element_type(t112, dtypes.float32) # t113: "cuda:0 f32[1, 16384, 2560]"
# t114 = prims.mul(t13, t113) # t114: "cuda:0 f32[1, 16384, 2560]"
# t115 = prims.broadcast_in_dim(t_norm_2_bias, (1, 16384, 2560), (2,)) # t115: "cuda:0 bf16[1, 16384, 2560]"
# t116 = prims.convert_element_type(t115, dtypes.float32) # t116: "cuda:0 f32[1, 16384, 2560]"
# t117 = prims.add(t114, t116) # t117: "cuda:0 f32[1, 16384, 2560]"
# t118 = prims.convert_element_type(t117, dtypes.bfloat16) # t118: "cuda:0 bf16[1, 16384, 2560]"
t21 = torch.nn.functional.linear(t20, t_attn_attn_weight, t_attn_attn_bias) # t21: "cuda:0 bf16[1, 16384, 7680]"
# t21 = ltorch.linear(t20, t_attn_attn_weight, t_attn_attn_bias) # t21: "cuda:0 bf16[1, 16384, 7680]"
# t21 = prims.linear(t20, t_attn_attn_weight, t_attn_attn_bias) # t21: "cuda:0 bf16[1, 16384, 7680]"
t119 = torch.nn.functional.linear(t118, t_mlp_fc_weight, t_mlp_fc_bias) # t119: "cuda:0 bf16[1, 16384, 10240]"
# t119 = ltorch.linear(t118, t_mlp_fc_weight, t_mlp_fc_bias) # t119: "cuda:0 bf16[1, 16384, 10240]"
# t119 = prims.linear(t118, t_mlp_fc_weight, t_mlp_fc_bias) # t119: "cuda:0 bf16[1, 16384, 10240]"
[t39, t84, t87, t121, t130, t135] = TorchCompile0(t21, cos, sin, t119)
# t22 = prims.reshape(t21, (1, 16384, 32, 3, 80)) # t22: "cuda:0 bf16[1, 16384, 32, 3, 80]"
# t23 = prims.transpose(t22, (0, 2, 3, 1, 4)) # t23: "cuda:0 bf16[1, 32, 3, 16384, 80]"
# (t24, t25, t26) = ltorch.split(t23, (1, 1, 1), 2)
# t24 = prims.slice_prim(t23, [0, 0, 0, 0, 0], [1, 32, 1, 16384, 80], [1, 1, 1, 1, 1]) # t24: "cuda:0 bf16[1, 32, 1, 16384, 80]"
# t25 = prims.slice_prim(t23, [0, 0, 1, 0, 0], [1, 32, 2, 16384, 80], [1, 1, 1, 1, 1]) # t25: "cuda:0 bf16[1, 32, 1, 16384, 80]"
# t26 = prims.slice_prim(t23, [0, 0, 2, 0, 0], [1, 32, 3, 16384, 80], [1, 1, 1, 1, 1]) # t26: "cuda:0 bf16[1, 32, 1, 16384, 80]"
# t27 = prims.reshape(t24, (1, 32, 16384, 80)) # t27: "cuda:0 bf16[1, 32, 16384, 80]"
# t33 = prims.reshape(t25, (1, 32, 16384, 80)) # t33: "cuda:0 bf16[1, 32, 16384, 80]"
# t39 = prims.reshape(t26, (1, 32, 16384, 80)) # t39: "cuda:0 bf16[1, 32, 16384, 80]"
# t40 = prims.slice_prim(t27, [0, 0, 0, 0], [1, 32, 16384, 20], [1, 1, 1, 1]) # t40: "cuda:0 bf16[1, 32, 16384, 20]"
# t41 = prims.slice_prim(t40, [0, 0, 0, 0], [1, 32, 16384, 10], [1, 1, 1, 1]) # t41: "cuda:0 bf16[1, 32, 16384, 10]"
# t42 = prims.slice_prim(t40, [0, 0, 0, 10], [1, 32, 16384, 20], [1, 1, 1, 1]) # t42: "cuda:0 bf16[1, 32, 16384, 10]"
# t43 = prims.convert_element_type(t42, dtypes.float32) # t43: "cuda:0 f32[1, 32, 16384, 10]"
# t44 = prims.neg(t43) # t44: "cuda:0 f32[1, 32, 16384, 10]"
# t45 = prims.convert_element_type(t44, dtypes.bfloat16) # t45: "cuda:0 bf16[1, 32, 16384, 10]"
# t46 = prims.cat([t45, t41], -1) # t46: "cuda:0 bf16[1, 32, 16384, 20]"
# t47 = prims.broadcast_in_dim(cos, (1, 32, 16384, 20), (2, 3)) # t47: "cuda:0 bf16[1, 32, 16384, 20]"
# t48 = prims.convert_element_type(t40, dtypes.float32) # t48: "cuda:0 f32[1, 32, 16384, 20]"
# t49 = prims.convert_element_type(t47, dtypes.float32) # t49: "cuda:0 f32[1, 32, 16384, 20]"
# t50 = ltorch.mul(t48, t49) # t50: "cuda:0 f32[1, 32, 16384, 20]"
# t50 = prims.mul(t48, t49) # t50: "cuda:0 f32[1, 32, 16384, 20]"
# t51 = prims.convert_element_type(t50, dtypes.bfloat16) # t51: "cuda:0 bf16[1, 32, 16384, 20]"
# t52 = prims.broadcast_in_dim(sin, (1, 32, 16384, 20), (2, 3)) # t52: "cuda:0 bf16[1, 32, 16384, 20]"
# t53 = prims.convert_element_type(t46, dtypes.float32) # t53: "cuda:0 f32[1, 32, 16384, 20]"
# t54 = prims.convert_element_type(t52, dtypes.float32) # t54: "cuda:0 f32[1, 32, 16384, 20]"
# t55 = ltorch.mul(t53, t54) # t55: "cuda:0 f32[1, 32, 16384, 20]"
# t55 = prims.mul(t53, t54) # t55: "cuda:0 f32[1, 32, 16384, 20]"
# t56 = prims.convert_element_type(t55, dtypes.bfloat16) # t56: "cuda:0 bf16[1, 32, 16384, 20]"
# t59 = ltorch.add(t50, t55, alpha=None) # t59: "cuda:0 f32[1, 32, 16384, 20]"
# t59 = prims.add(t50, t55) # t59: "cuda:0 f32[1, 32, 16384, 20]"
# t60 = prims.convert_element_type(t59, dtypes.bfloat16) # t60: "cuda:0 bf16[1, 32, 16384, 20]"
# t61 = prims.slice_prim(t33, [0, 0, 0, 0], [1, 32, 16384, 20], [1, 1, 1, 1]) # t61: "cuda:0 bf16[1, 32, 16384, 20]"
# t62 = prims.slice_prim(t61, [0, 0, 0, 0], [1, 32, 16384, 10], [1, 1, 1, 1]) # t62: "cuda:0 bf16[1, 32, 16384, 10]"
# t63 = prims.slice_prim(t61, [0, 0, 0, 10], [1, 32, 16384, 20], [1, 1, 1, 1]) # t63: "cuda:0 bf16[1, 32, 16384, 10]"
# t64 = prims.convert_element_type(t63, dtypes.float32) # t64: "cuda:0 f32[1, 32, 16384, 10]"
# t65 = prims.neg(t64) # t65: "cuda:0 f32[1, 32, 16384, 10]"
# t66 = prims.convert_element_type(t65, dtypes.bfloat16) # t66: "cuda:0 bf16[1, 32, 16384, 10]"
# t68 = prims.cat([t66, t62], -1) # t68: "cuda:0 bf16[1, 32, 16384, 20]"
# t70 = prims.convert_element_type(t61, dtypes.float32) # t70: "cuda:0 f32[1, 32, 16384, 20]"
# t72 = ltorch.mul(t70, t49) # t72: "cuda:0 f32[1, 32, 16384, 20]"
# t72 = prims.mul(t70, t49) # t72: "cuda:0 f32[1, 32, 16384, 20]"
# t73 = prims.convert_element_type(t72, dtypes.bfloat16) # t73: "cuda:0 bf16[1, 32, 16384, 20]"
# t75 = prims.convert_element_type(t68, dtypes.float32) # t75: "cuda:0 f32[1, 32, 16384, 20]"
# t77 = ltorch.mul(t75, t54) # t77: "cuda:0 f32[1, 32, 16384, 20]"
# t77 = prims.mul(t75, t54) # t77: "cuda:0 f32[1, 32, 16384, 20]"
# t78 = prims.convert_element_type(t77, dtypes.bfloat16) # t78: "cuda:0 bf16[1, 32, 16384, 20]"
# t81 = ltorch.add(t72, t77, alpha=None) # t81: "cuda:0 f32[1, 32, 16384, 20]"
# t81 = prims.add(t72, t77) # t81: "cuda:0 f32[1, 32, 16384, 20]"
# t82 = prims.convert_element_type(t81, dtypes.bfloat16) # t82: "cuda:0 bf16[1, 32, 16384, 20]"
# t83 = prims.slice_prim(t27, [0, 0, 0, 20], [1, 32, 16384, 80], [1, 1, 1, 1]) # t83: "cuda:0 bf16[1, 32, 16384, 60]"
# t84 = prims.cat([t60, t83], -1) # t84: "cuda:0 bf16[1, 32, 16384, 80]"
# t85 = prims.slice_prim(t33, [0, 0, 0, 20], [1, 32, 16384, 80], [1, 1, 1, 1]) # t85: "cuda:0 bf16[1, 32, 16384, 60]"
# t87 = prims.cat([t82, t85], -1) # t87: "cuda:0 bf16[1, 32, 16384, 80]"
# t120 = prims.convert_element_type(t119, dtypes.float32) # t120: "cuda:0 f32[1, 16384, 10240]"
# t121 = ltorch.true_divide(t120, 1.4142135623730951) # t121: "cuda:0 f32[1, 16384, 10240]"
# t121 = prims.div(t120, 1.4142135623730951) # t121: "cuda:0 f32[1, 16384, 10240]"
# t122 = prims.convert_element_type(t121, dtypes.bfloat16) # t122: "cuda:0 bf16[1, 16384, 10240]"
# t124 = prims.erf(t121) # t124: "cuda:0 f32[1, 16384, 10240]"
# t125 = prims.convert_element_type(t124, dtypes.bfloat16) # t125: "cuda:0 bf16[1, 16384, 10240]"
# t127 = ltorch.mul(0.5, t124) # t127: "cuda:0 f32[1, 16384, 10240]"
# t127 = prims.mul(0.5, t124) # t127: "cuda:0 f32[1, 16384, 10240]"
# t128 = prims.convert_element_type(t127, dtypes.bfloat16) # t128: "cuda:0 bf16[1, 16384, 10240]"
# t130 = ltorch.add(0.5, t127, alpha=None) # t130: "cuda:0 f32[1, 16384, 10240]"
# t130 = prims.add(0.5, t127) # t130: "cuda:0 f32[1, 16384, 10240]"
# t131 = prims.convert_element_type(t130, dtypes.bfloat16) # t131: "cuda:0 bf16[1, 16384, 10240]"
# t134 = ltorch.mul(t120, t130) # t134: "cuda:0 f32[1, 16384, 10240]"
# t134 = prims.mul(t120, t130) # t134: "cuda:0 f32[1, 16384, 10240]"
# t135 = prims.convert_element_type(t134, dtypes.bfloat16) # t135: "cuda:0 bf16[1, 16384, 10240]"
del t21
(t88, t89, t90, t91) = cudnn_sdpa_fwd(t84, t87, t39, None, 0.0, True, scale=0.11180339887498948)
t136 = torch.nn.functional.linear(t135, t_mlp_proj_weight, t_mlp_proj_bias) # t136: "cuda:0 bf16[1, 16384, 2560]"
# t136 = ltorch.linear(t135, t_mlp_proj_weight, t_mlp_proj_bias) # t136: "cuda:0 bf16[1, 16384, 2560]"
# t136 = prims.linear(t135, t_mlp_proj_weight, t_mlp_proj_bias) # t136: "cuda:0 bf16[1, 16384, 2560]"
[t93] = nvFusion1(t88)
# t92 = prims.transpose(t88, (0, 2, 1, 3)) # t92: "cuda:0 bf16[1, 16384, 32, 80]"
# t93 = prims.reshape(t92, (1, 16384, 2560)) # t93: "cuda:0 bf16[1, 16384, 2560]"
t94 = torch.nn.functional.linear(t93, t_attn_proj_weight, t_attn_proj_bias) # t94: "cuda:0 bf16[1, 16384, 2560]"
# t94 = ltorch.linear(t93, t_attn_proj_weight, t_attn_proj_bias) # t94: "cuda:0 bf16[1, 16384, 2560]"
# t94 = prims.linear(t93, t_attn_proj_weight, t_attn_proj_bias) # t94: "cuda:0 bf16[1, 16384, 2560]"
[t144] = nvFusion2(t136, t94, x)
# t137 = prims.convert_element_type(t136, dtypes.float32) # t137: "cuda:0 f32[1, 16384, 2560]"
# t138 = prims.convert_element_type(t94, dtypes.float32) # t138: "cuda:0 f32[1, 16384, 2560]"
# t139 = prims.add(t137, t138) # t139: "cuda:0 f32[1, 16384, 2560]"
# t142 = prims.convert_element_type(x, dtypes.float32) # t142: "cuda:0 f32[1, 16384, 2560]"
# t143 = prims.add(t139, t142) # t143: "cuda:0 f32[1, 16384, 2560]"
# t144 = prims.convert_element_type(t143, dtypes.bfloat16) # t144: "cuda:0 bf16[1, 16384, 2560]"
del t136, t94
return {'output': t144, 'flat_args': [x, cos, sin, t_attn_attn_bias, t_attn_attn_weight, t_attn_proj_bias, t_attn_proj_weight, t_mlp_fc_bias, t_mlp_fc_weight, t_mlp_proj_bias, t_mlp_proj_weight, t_norm_1_bias, t_norm_1_weight, t_norm_2_bias, t_norm_2_weight], 'flat_output': (t144,)}, ((cos, sin, t118, t119, t121, t130, t135, t20, t39, t4, t8, t84, t87, t88, t89, t90, t91, t93, t_attn_attn_weight, t_attn_proj_weight, t_mlp_fc_weight, t_mlp_proj_weight, t_norm_1_weight, t_norm_2_weight, x), ())
```

Forward execution trace with `parallel_residual=False` (i.e. run with `--disable-parallel-residual`):

```python
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def augmented_forward_fn(x, cos, sin, t_attn_attn_bias, t_attn_attn_weight, t_attn_proj_bias, t_attn_proj_weight, t_mlp_fc_bias, t_mlp_fc_weight, t_mlp_proj_bias, t_mlp_proj_weight, t_norm_1_bias, t_norm_1_weight, t_norm_2_bias, t_norm_2_weight):
# x: "cuda:0 bf16[1, 16384, 2560]"
# cos: "cuda:0 bf16[16384, 20]"
# sin: "cuda:0 bf16[16384, 20]"
# t_attn_attn_bias: "cuda:0 bf16[7680]"
# t_attn_attn_weight: "cuda:0 bf16[7680, 2560]"
# t_attn_proj_bias: "cuda:0 bf16[2560]"
# t_attn_proj_weight: "cuda:0 bf16[2560, 2560]"
# t_mlp_fc_bias: "cuda:0 bf16[10240]"
# t_mlp_fc_weight: "cuda:0 bf16[10240, 2560]"
# t_mlp_proj_bias: "cuda:0 bf16[2560]"
# t_mlp_proj_weight: "cuda:0 bf16[2560, 10240]"
# t_norm_1_bias: "cuda:0 bf16[2560]"
# t_norm_1_weight: "cuda:0 bf16[2560]"
# t_norm_2_bias: "cuda:0 bf16[2560]"
# t_norm_2_weight: "cuda:0 bf16[2560]"
[t4, t8, t20] = nvFusion0(x, t_norm_1_weight, t_norm_1_bias)
# t0 = prims.convert_element_type(x, dtypes.float32) # t0: "cuda:0 f32[1, 16384, 2560]"
# (t3, t4) = prims.var_mean(t0, (2,), correction=0)
# t5 = prims.broadcast_in_dim(t3, [1, 16384, 1], [0, 1]) # t5: "cuda:0 f32[1, 16384, 1]"
# t6 = prims.broadcast_in_dim(t4, [1, 16384, 1], [0, 1]) # t6: "cuda:0 f32[1, 16384, 1]"
# t7 = prims.add(t5, 1e-05) # t7: "cuda:0 f32[1, 16384, 1]"
# t9 = prims.broadcast_in_dim(t6, (1, 16384, 2560), (0, 1, 2)) # t9: "cuda:0 f32[1, 16384, 2560]"
# t8 = prims.rsqrt(t7) # t8: "cuda:0 f32[1, 16384, 1]"
# t11 = prims.sub(t0, t9) # t11: "cuda:0 f32[1, 16384, 2560]"
# t12 = prims.broadcast_in_dim(t8, (1, 16384, 2560), (0, 1, 2)) # t12: "cuda:0 f32[1, 16384, 2560]"
# t13 = prims.mul(t11, t12) # t13: "cuda:0 f32[1, 16384, 2560]"
# t14 = prims.broadcast_in_dim(t_norm_1_weight, (1, 16384, 2560), (2,)) # t14: "cuda:0 bf16[1, 16384, 2560]"
# t15 = prims.convert_element_type(t14, dtypes.float32) # t15: "cuda:0 f32[1, 16384, 2560]"
# t16 = prims.mul(t13, t15) # t16: "cuda:0 f32[1, 16384, 2560]"
# t17 = prims.broadcast_in_dim(t_norm_1_bias, (1, 16384, 2560), (2,)) # t17: "cuda:0 bf16[1, 16384, 2560]"
# t18 = prims.convert_element_type(t17, dtypes.float32) # t18: "cuda:0 f32[1, 16384, 2560]"
# t19 = prims.add(t16, t18) # t19: "cuda:0 f32[1, 16384, 2560]"
# t20 = prims.convert_element_type(t19, dtypes.bfloat16) # t20: "cuda:0 bf16[1, 16384, 2560]"
t21 = torch.nn.functional.linear(t20, t_attn_attn_weight, t_attn_attn_bias) # t21: "cuda:0 bf16[1, 16384, 7680]"
# t21 = ltorch.linear(t20, t_attn_attn_weight, t_attn_attn_bias) # t21: "cuda:0 bf16[1, 16384, 7680]"
# t21 = prims.linear(t20, t_attn_attn_weight, t_attn_attn_bias) # t21: "cuda:0 bf16[1, 16384, 7680]"
[t39, t84, t87] = TorchCompile0(t21, cos, sin)
# t22 = prims.reshape(t21, (1, 16384, 32, 3, 80)) # t22: "cuda:0 bf16[1, 16384, 32, 3, 80]"
# t23 = prims.transpose(t22, (0, 2, 3, 1, 4)) # t23: "cuda:0 bf16[1, 32, 3, 16384, 80]"
# (t24, t25, t26) = ltorch.split(t23, (1, 1, 1), 2)
# t24 = prims.slice_prim(t23, [0, 0, 0, 0, 0], [1, 32, 1, 16384, 80], [1, 1, 1, 1, 1]) # t24: "cuda:0 bf16[1, 32, 1, 16384, 80]"
# t25 = prims.slice_prim(t23, [0, 0, 1, 0, 0], [1, 32, 2, 16384, 80], [1, 1, 1, 1, 1]) # t25: "cuda:0 bf16[1, 32, 1, 16384, 80]"
# t26 = prims.slice_prim(t23, [0, 0, 2, 0, 0], [1, 32, 3, 16384, 80], [1, 1, 1, 1, 1]) # t26: "cuda:0 bf16[1, 32, 1, 16384, 80]"
# t27 = prims.reshape(t24, (1, 32, 16384, 80)) # t27: "cuda:0 bf16[1, 32, 16384, 80]"
# t33 = prims.reshape(t25, (1, 32, 16384, 80)) # t33: "cuda:0 bf16[1, 32, 16384, 80]"
# t39 = prims.reshape(t26, (1, 32, 16384, 80)) # t39: "cuda:0 bf16[1, 32, 16384, 80]"
# t40 = prims.slice_prim(t27, [0, 0, 0, 0], [1, 32, 16384, 20], [1, 1, 1, 1]) # t40: "cuda:0 bf16[1, 32, 16384, 20]"
# t41 = prims.slice_prim(t40, [0, 0, 0, 0], [1, 32, 16384, 10], [1, 1, 1, 1]) # t41: "cuda:0 bf16[1, 32, 16384, 10]"
# t42 = prims.slice_prim(t40, [0, 0, 0, 10], [1, 32, 16384, 20], [1, 1, 1, 1]) # t42: "cuda:0 bf16[1, 32, 16384, 10]"
# t43 = prims.convert_element_type(t42, dtypes.float32) # t43: "cuda:0 f32[1, 32, 16384, 10]"
# t44 = prims.neg(t43) # t44: "cuda:0 f32[1, 32, 16384, 10]"
# t45 = prims.convert_element_type(t44, dtypes.bfloat16) # t45: "cuda:0 bf16[1, 32, 16384, 10]"
# t46 = prims.cat([t45, t41], -1) # t46: "cuda:0 bf16[1, 32, 16384, 20]"
# t47 = prims.broadcast_in_dim(cos, (1, 32, 16384, 20), (2, 3)) # t47: "cuda:0 bf16[1, 32, 16384, 20]"
# t48 = prims.convert_element_type(t40, dtypes.float32) # t48: "cuda:0 f32[1, 32, 16384, 20]"
# t49 = prims.convert_element_type(t47, dtypes.float32) # t49: "cuda:0 f32[1, 32, 16384, 20]"
# t50 = ltorch.mul(t48, t49) # t50: "cuda:0 f32[1, 32, 16384, 20]"
# t50 = prims.mul(t48, t49) # t50: "cuda:0 f32[1, 32, 16384, 20]"
# t51 = prims.convert_element_type(t50, dtypes.bfloat16) # t51: "cuda:0 bf16[1, 32, 16384, 20]"
# t52 = prims.broadcast_in_dim(sin, (1, 32, 16384, 20), (2, 3)) # t52: "cuda:0 bf16[1, 32, 16384, 20]"
# t53 = prims.convert_element_type(t46, dtypes.float32) # t53: "cuda:0 f32[1, 32, 16384, 20]"
# t54 = prims.convert_element_type(t52, dtypes.float32) # t54: "cuda:0 f32[1, 32, 16384, 20]"
# t55 = ltorch.mul(t53, t54) # t55: "cuda:0 f32[1, 32, 16384, 20]"
# t55 = prims.mul(t53, t54) # t55: "cuda:0 f32[1, 32, 16384, 20]"
# t56 = prims.convert_element_type(t55, dtypes.bfloat16) # t56: "cuda:0 bf16[1, 32, 16384, 20]"
# t59 = ltorch.add(t50, t55, alpha=None) # t59: "cuda:0 f32[1, 32, 16384, 20]"
# t59 = prims.add(t50, t55) # t59: "cuda:0 f32[1, 32, 16384, 20]"
# t60 = prims.convert_element_type(t59, dtypes.bfloat16) # t60: "cuda:0 bf16[1, 32, 16384, 20]"
# t61 = prims.slice_prim(t33, [0, 0, 0, 0], [1, 32, 16384, 20], [1, 1, 1, 1]) # t61: "cuda:0 bf16[1, 32, 16384, 20]"
# t62 = prims.slice_prim(t61, [0, 0, 0, 0], [1, 32, 16384, 10], [1, 1, 1, 1]) # t62: "cuda:0 bf16[1, 32, 16384, 10]"
# t63 = prims.slice_prim(t61, [0, 0, 0, 10], [1, 32, 16384, 20], [1, 1, 1, 1]) # t63: "cuda:0 bf16[1, 32, 16384, 10]"
# t64 = prims.convert_element_type(t63, dtypes.float32) # t64: "cuda:0 f32[1, 32, 16384, 10]"
# t65 = prims.neg(t64) # t65: "cuda:0 f32[1, 32, 16384, 10]"
# t66 = prims.convert_element_type(t65, dtypes.bfloat16) # t66: "cuda:0 bf16[1, 32, 16384, 10]"
# t68 = prims.cat([t66, t62], -1) # t68: "cuda:0 bf16[1, 32, 16384, 20]"
# t70 = prims.convert_element_type(t61, dtypes.float32) # t70: "cuda:0 f32[1, 32, 16384, 20]"
# t72 = ltorch.mul(t70, t49) # t72: "cuda:0 f32[1, 32, 16384, 20]"
# t72 = prims.mul(t70, t49) # t72: "cuda:0 f32[1, 32, 16384, 20]"
# t73 = prims.convert_element_type(t72, dtypes.bfloat16) # t73: "cuda:0 bf16[1, 32, 16384, 20]"
# t75 = prims.convert_element_type(t68, dtypes.float32) # t75: "cuda:0 f32[1, 32, 16384, 20]"
# t77 = ltorch.mul(t75, t54) # t77: "cuda:0 f32[1, 32, 16384, 20]"
# t77 = prims.mul(t75, t54) # t77: "cuda:0 f32[1, 32, 16384, 20]"
# t78 = prims.convert_element_type(t77, dtypes.bfloat16) # t78: "cuda:0 bf16[1, 32, 16384, 20]"
# t81 = ltorch.add(t72, t77, alpha=None) # t81: "cuda:0 f32[1, 32, 16384, 20]"
# t81 = prims.add(t72, t77) # t81: "cuda:0 f32[1, 32, 16384, 20]"
# t82 = prims.convert_element_type(t81, dtypes.bfloat16) # t82: "cuda:0 bf16[1, 32, 16384, 20]"
# t83 = prims.slice_prim(t27, [0, 0, 0, 20], [1, 32, 16384, 80], [1, 1, 1, 1]) # t83: "cuda:0 bf16[1, 32, 16384, 60]"
# t84 = prims.cat([t60, t83], -1) # t84: "cuda:0 bf16[1, 32, 16384, 80]"
# t85 = prims.slice_prim(t33, [0, 0, 0, 20], [1, 32, 16384, 80], [1, 1, 1, 1]) # t85: "cuda:0 bf16[1, 32, 16384, 60]"
# t87 = prims.cat([t82, t85], -1) # t87: "cuda:0 bf16[1, 32, 16384, 80]"
del t21
(t88, t89, t90, t91) = cudnn_sdpa_fwd(t84, t87, t39, None, 0.0, True, scale=0.11180339887498948)
[t93] = nvFusion1(t88)
# t92 = prims.transpose(t88, (0, 2, 1, 3)) # t92: "cuda:0 bf16[1, 16384, 32, 80]"
# t93 = prims.reshape(t92, (1, 16384, 2560)) # t93: "cuda:0 bf16[1, 16384, 2560]"
t94 = torch.nn.functional.linear(t93, t_attn_proj_weight, t_attn_proj_bias) # t94: "cuda:0 bf16[1, 16384, 2560]"
# t94 = ltorch.linear(t93, t_attn_proj_weight, t_attn_proj_bias) # t94: "cuda:0 bf16[1, 16384, 2560]"
# t94 = prims.linear(t93, t_attn_proj_weight, t_attn_proj_bias) # t94: "cuda:0 bf16[1, 16384, 2560]"
[t98, t105, t110, t122] = nvFusion2(t94, x, t_norm_2_weight, t_norm_2_bias)
# t95 = prims.convert_element_type(t94, dtypes.float32) # t95: "cuda:0 f32[1, 16384, 2560]"
# t96 = prims.convert_element_type(x, dtypes.float32) # t96: "cuda:0 f32[1, 16384, 2560]"
# t97 = prims.add(t95, t96) # t97: "cuda:0 f32[1, 16384, 2560]"
# t98 = prims.convert_element_type(t97, dtypes.bfloat16) # t98: "cuda:0 bf16[1, 16384, 2560]"
# (t104, t105) = prims.var_mean(t97, (2,), correction=0)
# t106 = prims.broadcast_in_dim(t104, [1, 16384, 1], [0, 1]) # t106: "cuda:0 f32[1, 16384, 1]"
# t107 = prims.broadcast_in_dim(t105, [1, 16384, 1], [0, 1]) # t107: "cuda:0 f32[1, 16384, 1]"
# t109 = prims.add(t106, 1e-05) # t109: "cuda:0 f32[1, 16384, 1]"
# t111 = prims.broadcast_in_dim(t107, (1, 16384, 2560), (0, 1, 2)) # t111: "cuda:0 f32[1, 16384, 2560]"
# t110 = prims.rsqrt(t109) # t110: "cuda:0 f32[1, 16384, 1]"
# t113 = prims.sub(t97, t111) # t113: "cuda:0 f32[1, 16384, 2560]"
# t114 = prims.broadcast_in_dim(t110, (1, 16384, 2560), (0, 1, 2)) # t114: "cuda:0 f32[1, 16384, 2560]"
# t115 = prims.mul(t113, t114) # t115: "cuda:0 f32[1, 16384, 2560]"
# t116 = prims.broadcast_in_dim(t_norm_2_weight, (1, 16384, 2560), (2,)) # t116: "cuda:0 bf16[1, 16384, 2560]"
# t117 = prims.convert_element_type(t116, dtypes.float32) # t117: "cuda:0 f32[1, 16384, 2560]"
# t118 = prims.mul(t115, t117) # t118: "cuda:0 f32[1, 16384, 2560]"
# t119 = prims.broadcast_in_dim(t_norm_2_bias, (1, 16384, 2560), (2,)) # t119: "cuda:0 bf16[1, 16384, 2560]"
# t120 = prims.convert_element_type(t119, dtypes.float32) # t120: "cuda:0 f32[1, 16384, 2560]"
# t121 = prims.add(t118, t120) # t121: "cuda:0 f32[1, 16384, 2560]"
# t122 = prims.convert_element_type(t121, dtypes.bfloat16) # t122: "cuda:0 bf16[1, 16384, 2560]"
t123 = torch.nn.functional.linear(t122, t_mlp_fc_weight, t_mlp_fc_bias) # t123: "cuda:0 bf16[1, 16384, 10240]"
# t123 = ltorch.linear(t122, t_mlp_fc_weight, t_mlp_fc_bias) # t123: "cuda:0 bf16[1, 16384, 10240]"
# t123 = prims.linear(t122, t_mlp_fc_weight, t_mlp_fc_bias) # t123: "cuda:0 bf16[1, 16384, 10240]"
[t139] = nvFusion3(t123)
# t124 = prims.convert_element_type(t123, dtypes.float32) # t124: "cuda:0 f32[1, 16384, 10240]"
# t125 = prims.div(t124, 1.4142135623730951) # t125: "cuda:0 f32[1, 16384, 10240]"
# t128 = prims.erf(t125) # t128: "cuda:0 f32[1, 16384, 10240]"
# t131 = prims.mul(0.5, t128) # t131: "cuda:0 f32[1, 16384, 10240]"
# t134 = prims.add(0.5, t131) # t134: "cuda:0 f32[1, 16384, 10240]"
# t138 = prims.mul(t124, t134) # t138: "cuda:0 f32[1, 16384, 10240]"
# t139 = prims.convert_element_type(t138, dtypes.bfloat16) # t139: "cuda:0 bf16[1, 16384, 10240]"
t140 = torch.nn.functional.linear(t139, t_mlp_proj_weight, t_mlp_proj_bias) # t140: "cuda:0 bf16[1, 16384, 2560]"
# t140 = ltorch.linear(t139, t_mlp_proj_weight, t_mlp_proj_bias) # t140: "cuda:0 bf16[1, 16384, 2560]"
# t140 = prims.linear(t139, t_mlp_proj_weight, t_mlp_proj_bias) # t140: "cuda:0 bf16[1, 16384, 2560]"
[t144] = nvFusion4(t98, t140)
# t142 = prims.convert_element_type(t98, dtypes.float32) # t142: "cuda:0 f32[1, 16384, 2560]"
# t141 = prims.convert_element_type(t140, dtypes.float32) # t141: "cuda:0 f32[1, 16384, 2560]"
# t143 = prims.add(t141, t142) # t143: "cuda:0 f32[1, 16384, 2560]"
# t144 = prims.convert_element_type(t143, dtypes.bfloat16) # t144: "cuda:0 bf16[1, 16384, 2560]"
del t98, t140
return {'output': t144, 'flat_args': [x, cos, sin, t_attn_attn_bias, t_attn_attn_weight, t_attn_proj_bias, t_attn_proj_weight, t_mlp_fc_bias, t_mlp_fc_weight, t_mlp_proj_bias, t_mlp_proj_weight, t_norm_1_bias, t_norm_1_weight, t_norm_2_bias, t_norm_2_weight], 'flat_output': (t144,)}, ((cos, sin, t105, t110, t122, t123, t139, t20, t39, t4, t8, t84, t87, t88, t89, t90, t91, t93, t94, t_attn_attn_weight, t_attn_proj_weight, t_mlp_fc_weight, t_mlp_proj_weight, t_norm_1_weight, t_norm_2_weight, x), ())
```

The trace with `parallel_residual=True` saves these intermediate tensors for backward:

```
(cos, sin, t118, t119, t121, t130, t135, t20, t39, t4, t8, t84, t87, t88, t89, t90, t91, t93, t_attn_attn_weight, t_attn_proj_weight, t_mlp_fc_weight, t_mlp_proj_weight, t_norm_1_weight, t_norm_2_weight, x)
```

while the one with `parallel_residual=False` saves:

```
(cos, sin, t105, t110, t122, t123, t139, t20, t39, t4, t8, t84, t87, t88, t89, t90, t91, t93, t94, t_attn_attn_weight, t_attn_proj_weight, t_mlp_fc_weight, t_mlp_proj_weight, t_norm_1_weight, t_norm_2_weight, x)
```

The differences are
(
`t118: "cuda:0 bf16[1, 16384, 2560]"`,
`t119: "cuda:0 bf16[1, 16384, 10240]"`,
`t121: "cuda:0 f32[1, 16384, 10240]"`,
`t130: "cuda:0 f32[1, 16384, 10240]"`,
`t135: "cuda:0 bf16[1, 16384, 10240]"`,
)
vs
(
`t105: "cuda:0 bf16[1, 16384]"`,
`t110: "cuda:0 f32[1, 16384, 1]"`,
`t122: "cuda:0 bf16[1, 16384, 2560]"`,
`t123: "cuda:0 bf16[1, 16384, 10240]"`,
`t139: "cuda:0 bf16[1, 16384, 10240]"`,
)
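A rough size estimate for the two groups listed above, computed from the shapes and dtypes in the traces (batch size 1, sequence length 16384; `nbytes` is just a local helper for this estimate):

```python
# Back-of-the-envelope sizes of the differing saved-for-backward tensors,
# using the shapes/dtypes listed above. bf16 = 2 bytes/element, f32 = 4.
def nbytes(shape, itemsize):
    n = 1
    for d in shape:
        n *= d
    return n * itemsize

parallel = [                 # parallel_residual=True
    ((1, 16384, 2560), 2),   # t118, bf16
    ((1, 16384, 10240), 2),  # t119, bf16
    ((1, 16384, 10240), 4),  # t121, f32
    ((1, 16384, 10240), 4),  # t130, f32
    ((1, 16384, 10240), 2),  # t135, bf16
]
sequential = [               # parallel_residual=False
    ((1, 16384), 2),         # t105, bf16
    ((1, 16384, 1), 4),      # t110, f32
    ((1, 16384, 2560), 2),   # t122, bf16
    ((1, 16384, 10240), 2),  # t123, bf16
    ((1, 16384, 10240), 2),  # t139, bf16
]
p = sum(nbytes(s, b) for s, b in parallel)
q = sum(nbytes(s, b) for s, b in sequential)
print(f"{p / 1e9:.2f} GB vs {q / 1e9:.2f} GB -> {(p - q) / 1e9:.2f} GB more per block")
# prints: 2.10 GB vs 0.76 GB -> 1.34 GB more per block
```

In addition, `t94` (`bf16[1, 16384, 2560]`, about 0.08 GB) is saved only by the `parallel_residual=False` trace, which narrows the gap slightly.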
https://gist.github.com/crcrpar/ce52789c933ca7013049c6eb1ba06366 has the AOT forward and backward graphs. The backward's arguments are as follows:
```python
def forward(
    self,
    primals_1: "bf16[2560][1]cuda:0",
    primals_3: "bf16[1, 16384, 2560][41943040, 2560, 1]cuda:0",
    primals_6: "bf16[16384, 20][20, 1]cuda:0",
    primals_7: "bf16[16384, 20][20, 1]cuda:0",
    primals_10: "bf16[2560][1]cuda:0",
    getitem_1: "f32[1, 16384, 1][16384, 1, 1]cuda:0",
    rsqrt: "f32[1, 16384, 1][16384, 1, 1]cuda:0",
    view: "bf16[16384, 2560][2560, 1]cuda:0",
    view_5: "bf16[1, 32, 16384, 80][125829120, 240, 7680, 1]cuda:0",
    cat_2: "bf16[1, 32, 16384, 80][41943040, 1310720, 80, 1]cuda:0",
    cat_3: "bf16[1, 32, 16384, 80][41943040, 1310720, 80, 1]cuda:0",
    getitem_5: "bf16[1, 32, 16384, 80][41943040, 1310720, 80, 1]cuda:0",
    getitem_6: "f32[1, 32, 16384][524288, 16384, 1]cuda:0",
    getitem_11: "i64[][]cuda:0",
    getitem_12: "i64[][]cuda:0",
    view_7: "bf16[16384, 2560][2560, 1]cuda:0",
    view_9: "bf16[16384, 2560][2560, 1]cuda:0",
    addmm_2: "bf16[16384, 10240][10240, 1]cuda:0",
    view_11: "bf16[16384, 10240][10240, 1]cuda:0",
    permute_6: "bf16[2560, 10240][10240, 1]cuda:0",
    permute_10: "bf16[10240, 2560][2560, 1]cuda:0",
    permute_14: "bf16[2560, 2560][2560, 1]cuda:0",
    permute_20: "bf16[7680, 2560][2560, 1]cuda:0",
    tangents_1: "bf16[1, 16384, 2560][41943040, 2560, 1]cuda:0",
):
    ...
```
`getitem_1` is the second output of `var_mean`. `getitem_5`, `getitem_6`, `getitem_11`, and `getitem_12` are outputs of `torch.ops.aten._scaled_dot_product_cudnn_attention.default`.
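Judging by the shapes, `addmm_2` and `view_11` appear to be the only `[16384, 10240]`-sized activations kept for backward here, while the Thunder trace with `parallel_residual=True` keeps four tensors of that size, two of them in fp32. A rough comparison of just those tensors (computed from the printed shapes, so only an estimate):

```python
# Approximate bytes of the [*, 10240]-shaped activations saved for backward.
mlp_elems = 16384 * 10240

torch_compile = 2 * mlp_elems * 2                # addmm_2, view_11 (both bf16)
thunder_parallel = (2 + 4 + 4 + 2) * mlp_elems   # t119 (bf16), t121 (f32), t130 (f32), t135 (bf16)

print(f"torch.compile: {torch_compile / 1e9:.2f} GB, thunder (parallel_residual=True): {thunder_parallel / 1e9:.2f} GB")
# prints: torch.compile: 0.67 GB, thunder (parallel_residual=True): 2.01 GB
```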
## 🐛 Bug

When input sequences get longer, Thunder seems to use more memory than eager and torch.compile.
Take litgpt's `stablecode-completion-alpha-3b`, whose sequence length (`Config.block_size`) is 16384, as an example. As the following table shows, Thunder's memory consumption is more sensitive to sequence length.

## To Reproduce

Apply a diff like this and run commands like

## Code sample

## Expected behavior

## Environment

pjnl-20240919