Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

[lit-GPT] Thunder with torch.compile executor performs consistently worse than Thunder on all model sizes/batch sizes on Pythia models #256

Closed parthmannan closed 5 months ago

parthmannan commented 9 months ago

🐛 Bug

The performance of using the hybridized torch.compile executor w/ Thunder is worse than plain Thunder on Pythia models. This set of models differs from the LLaMA architecture in a few main ways:

  1. Uses LayerNorm instead of RMSNorm
  2. Uses GeLU instead of `SiLU(x) * x`
  3. Uses a parallel residual (i.e. the MLP block takes an input computed before the attention block, not after it; see the sketch after this list)
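
For reference, here's a minimal sketch of the third difference (not the lit-GPT implementation; the module names are made up), contrasting the parallel residual with the sequential LLaMA-style block:

```python
import torch.nn as nn

class ParallelResidualBlock(nn.Module):
    """Toy block illustrating the parallel residual used by Pythia/GPT-NeoX."""

    def __init__(self, attn: nn.Module, mlp: nn.Module, dim: int):
        super().__init__()
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.attn, self.mlp = attn, mlp

    def forward(self, x):
        # Sequential (LLaMA-style): x = x + attn(norm1(x)); x = x + mlp(norm2(x))
        # Parallel (Pythia):        both branches read the same residual input x
        return x + self.attn(self.norm1(x)) + self.mlp(self.norm2(x))
```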

Example performance on H100, single node, FP16, for Pythia-6.9B, MBS=1, GBS=8, FSDP ZeRO2 w/o bucketing:
Thunder iteration time = 232.74 ms
Thunder + torch.compile iteration time = 239.23 ms

cc @crcrpar @apaz-cli

parthmannan commented 9 months ago

Okay, so I don't completely understand this yet, but I'm going to put my thoughts out here.

The performance penalty largely comes from the backward pass, where the parallel computation of GeLU creates some interesting behavior. Without diving too deep into the derivatives, here's the dGeLU computation as I understand it; I'm separating it into the portions below.

dGeLU(x) = (erf(x/1.4) + 1)/2 * dy + x*PDF(X=x)*dy
         = A + (((2 * dy * x * f339)/1.77) * exp(-(x/1.4)^2)) / f337
         = A + (B * C) / f337
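
As a sanity check of this decomposition (a standalone sketch, not Thunder code; 1.4 and 1.77 above are just sqrt(2) and sqrt(pi) rounded, and f337/f339 are constants from the trace), the exact erf-based GeLU derivative can be compared against autograd:

```python
import math
import torch

def dgelu_manual(x, dy):
    # A: CDF term; the second term is the x * PDF(x) * dy piece (the B*C part)
    A = 0.5 * (torch.erf(x / math.sqrt(2.0)) + 1.0) * dy
    pdf = torch.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    return A + x * pdf * dy

x = torch.randn(1024, dtype=torch.float64, requires_grad=True)
dy = torch.randn_like(x)
y = torch.nn.functional.gelu(x)  # exact (erf) GeLU, not the tanh approximation
(grad,) = torch.autograd.grad(y, x, grad_outputs=dy)
print(torch.allclose(grad, dgelu_manual(x.detach(), dy)))  # True
```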

Now the torch.compile hybridized executor trace looks very interesting here (trace attached for a single layer of the network):

  1. It computes C separately (nvFusion0 in trace)
  2. It computes B separately (nvFusion2 in trace)
  3. It computes B*C separately and that is taken over by TorchCompile0
  4. It brings A from the forward pass and computes A + (B * C) / f337 (nvFusion3)

All of this is a single fusion block when we are not using the torch.compile executor. Perhaps this happens because TorchCompile0 takes away the B*C computation before nvFuser can form blocks around it? Both nvFusion2 and nvFusion3 are doing some other computation as well, so the performance penalty may not be as high (idk yet), but there are definitely some extra memory transfers happening due to passing these separate computations around, and that creates worse performance.

thunder_inductor_BWD_trace_1_layer.txt
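
For anyone wanting to reproduce this kind of breakdown, here's a rough sketch of how to dump the fusion regions for a toy module (API names like `thunder.jit` and `thunder.last_backward_traces` as I understand them; the exact executor setup for `thunder_inductor` may differ):

```python
import torch
import thunder

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU()).cuda()
jmodel = thunder.jit(model)  # add the torch.compile executor here to mirror thunder_inductor

x = torch.randn(8, 64, device="cuda", requires_grad=True)
jmodel(x).sum().backward()

# The executed backward trace lists each fusion region (nvFusion0, TorchCompile0, ...)
# together with the prims it claimed, which is how the A/B/C split above was read off.
print(thunder.last_backward_traces(jmodel)[-1])
```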

parthmannan commented 9 months ago

@IvanYashchuk @mruberry @tfogal - Maybe it is too early in the analysis to get your thoughts, but this felt a little crucial to torch.compile executor perf on other models.

IvanYashchuk commented 9 months ago

Perhaps, this happens because TorchCompile0 takes away the B*C computation before nvFuser can form blocks around the computation?

Yes, that's precisely what's happening. The partitioner is too aggressive: for the TorchCompile region the rule is "find a cat operation, then expand the fusion group to all supported operations (reshapes, slices, add, mul) the dataflow can reach", and this destroys the fusion opportunity for nvFuser because there's no communication between the two. If we changed the order of executors to place the nvFuser executor before the torch.compile executor, then I think we would see a single fusion block for dGeLU with nvFuser and just the cat would be sent to the TorchCompile region.
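
A rough sketch of what that reordering would look like (the nvFuser executor import path is an assumption; `torch_compile_cat_ex` is the name used in `thunder/executors/torch_compile.py`):

```python
import torch
import thunder
from thunder.executors.torch_compile import torch_compile_cat_ex
from thunder.executors.nvfuserex import nvfuserex  # assumed import path

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU()).cuda()

# Executors earlier in the list get the first chance to claim operations, so
# listing nvFuser first should keep dGeLU in a single nvFusion region and leave
# only the cat to the TorchCompile executor.
jmodel = thunder.jit(model, executors=(nvfuserex, torch_compile_cat_ex))
```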

IvanYashchuk commented 9 months ago

There are bugs in the TorchCompile partitioner logic: judging from the trace txt file, t1008 should not have been put into TorchCompile0; it's not used there. If we fix this, there should be a single nvFuser region created.

parthmannan commented 9 months ago

This is probably not a super-high-priority item, but I assume people will try the executor on different models and could potentially see worse performance. Do you already have an idea of how to fix the TorchCompile partitioner so it doesn't pick up nodes that aren't directly part of its neighborhood computation graph?

parthmannan commented 8 months ago

@IvanYashchuk Is this issue okay to move to the new open source repo?

IvanYashchuk commented 8 months ago

Yes, I think it's okay to move it to the new open source repo, and it's important not to forget about this problem. I don't have concrete ideas on how to fix the partitioner.

IvanYashchuk commented 6 months ago

Checking the performance today I see the following:
Thunder+TorchCompileCat: 247.54 ms
Thunder: 251.69 ms

torchrun --nproc_per_node=8 thunder/benchmarks/benchmark_litgpt.py --compile=thunder_inductor_cat --distributed_mode=fsdp --nsys_enabled=False --micro_batch_size=1 --global_batch_size=8 --model_name=pythia-6.9b
torchrun --nproc_per_node=8 thunder/benchmarks/benchmark_litgpt.py --compile=thunder --distributed_mode=fsdp --nsys_enabled=False --micro_batch_size=1 --global_batch_size=8 --model_name=pythia-6.9b

The numbers are worse than in February, but TorchCompileCat brings some value here; that could also be partly due to regressions in FSDP. The major difference from February, I think, is that nvFuser now fuses cat since https://github.com/Lightning-AI/lightning-thunder/pull/35.

Here are the numbers for single-GPU execution; TorchCompileCat still improves performance a bit:
Thunder+TorchCompileCat: 208.25 ms
Thunder: 211.72 ms

python thunder/benchmarks/benchmark_litgpt.py --compile=thunder_inductor_cat --micro_batch_size=1 --model_name=pythia-6.9b
python thunder/benchmarks/benchmark_litgpt.py --compile=thunder --micro_batch_size=1 --model_name=pythia-6.9b

TorchCompileCat is a hack for executing RoPE fusions; it's by accident that this executor also claims the backward of torch.split (which is a cat op) and consequently breaks nvFuser fusions. TorchCompileCat should be further constrained so it applies just to RoPE, and we should reevaluate its performance on a wider range of models and microbatch sizes.

Let's try constraining it a bit:

diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py
index 1e18c42e..3f20136c 100644
--- a/thunder/executors/torch_compile.py
+++ b/thunder/executors/torch_compile.py
@@ -198,10 +198,10 @@ from thunder.executors.torchex import ex as pytorch_ex
 # since they would be competing over fusion opportunities. The advantage over simply doing `torch.compile` is that you
 # still get all of Thunder's advantages, like enabling custom executors (e.g. with custom triton kernels) before it.
 required_ops = {
-    "torch.cat",
+    #"torch.cat",
     prims.cat.id,
-    prims.pad.id,
-    prims.slice_prim.id,
+    #prims.pad.id,
+    #prims.slice_prim.id,
 }
 torch_compile_cat_ex = TorchCompileExecutor(name="torchcompile_cat", required_ops=required_ops)
 register_executor(torch_compile_cat_ex)

Here's how the forward trace is affected:

  1. phi_a is not saved for backward anymore https://github.com/Lightning-AI/lightning-thunder/blob/3034ef9efad3db43f9e94e3e9013ce853c0b3680/thunder/torch/__init__.py#L1426

Here's how the backward trace is affected:

  1. There's no TorchCompile fusion anymore; all backward fusions use nvFuser
  2. nvFusion1 was previously interrupted and its output was consumed only by the interrupting TorchCompile region; now it returns one of the grads and an intermediate consumed by a matmul
  3. nvFusion2 now merges the work done by the TorchCompile0 and nvFusion2 regions before the change

Let's check how this change impacts our current microbenchmark (which should be revised and updated!):

pytest thunder/benchmarks/targets.py -k "test_llama2_qkv_split_rope_7b_train[thunder+nvfuser+torch.compile]"

The results are worse with this change on the Llama 2 7B arch: 8.6 ms vs 19.5 ms, and for Pythia 6.9B it's 4.4 ms vs 9.8 ms. Running Pythia 6.9B on a single GPU with or without this change doesn't have an impact on perf, though.

parthmannan commented 6 months ago

That's quite interesting.

The major difference, I think, from Feb is that nvFuser now fuses cat since https://github.com/Lightning-AI/lightning-thunder/pull/35.

Do we know if nvFuser's cat support is just functionally good or performant as well? I believe the TorchCompileCat executor was added specifically for concats, so using it without cat is probably expected to be worse if the nvFuser support isn't as performant.

I can look into that. But even if thunder_inductor_cat improves perf a little bit, we may be leaving perf on the table: it improves RoPE perf but degrades other aspects of the trace, so there should still be similar room for improvement.

parthmannan commented 6 months ago

I also just dived into the trace again to see if the behavior is the same, but it has changed, because the Thunder trace itself has changed a little bit. Rolling back to a previous comment (see above for the calculation):

dGeLU(x) = A + (B * C) / f337

Earlier, nvFuser computed B and C separately, TorchCompile computed B*C, and then nvFuser computed the final output, leading to 4 different fusion regions for this computation.

Now, nvFuser computes B, C, and B*C in a single region, and TorchCompile only does A + B*C and produces the final output. Somehow, with the change in the Thunder trace, we don't see the same breaks in the nvFuser regions; there are only 2 fusion regions now. This is why the performance on Pythia looks decent now. This is quite interesting, and I'm thinking I have to re-analyze other networks like Phi and Dolly to confirm whether the TorchCompile cat executor is still a perf issue there.

IvanYashchuk commented 6 months ago

Do we know if nvFuser enabling cat is just functionally good or performant as well?

I don't know it and I've created https://github.com/Lightning-AI/lightning-thunder/pull/479 to simplify answering this question.

IvanYashchuk commented 5 months ago

@kiya00, could you please help identify the best executor options (nvFuser/Inductor/Apex) separately for the forward and backward of this region across the range of all model configurations from LitGPT? The relevant benchmark in question is this one: https://github.com/Lightning-AI/lightning-thunder/blob/d1b016a58a48e5c6282622de488be8c9135dd821/thunder/benchmarks/targets.py#L535

pytest thunder/benchmarks/targets.py -k "test_litgpt_qkv_split_rope" --benchmark-group-by='param:config,param:bs,param:compute_type'

There's also an environment variable to launch more benchmarks: https://github.com/Lightning-AI/lightning-thunder/blob/d1b016a58a48e5c6282622de488be8c9135dd821/thunder/benchmarks/targets.py#L54

kiya00 commented 5 months ago

Hi @IvanYashchuk, here are some microbenchmark results for all LitGPT configurations, separately for forward and backward, with different executor options (1404 items). In almost all cases torch.compile seems better. I'm not very clear about the background of this issue; should this benchmark have better perf using thunder+nvfuser+torch.compile, at least on Pythia? I printed out the forward/backward traces of test_litgpt_qkv_split_rope[pythia-1.4b-backward-bs1-thunder+nvfuser+torch.compile]; they both have only one TorchCompile0 region.

trace of test_litgpt_qkv_split_rope[pythia-1.4b-backward-bs1-thunder+nvfuser+torch.compile] ``` # Constructed by Delete Last Used (took 0 milliseconds) import torch from thunder.executors.torchex import no_autocast @torch.no_grad() @no_autocast def augmented_forward_fn(qkv, cos, sin): # qkv: "cuda:0 bf16[1, 2048, 6144]" # cos: "cuda:0 bf16[2048, 32]" # sin: "cuda:0 bf16[2048, 32]" [t17, t62, t65] = TorchCompile0(cos, qkv, sin) # t0 = prims.reshape(qkv, (1, 2048, 16, 3, 128)) # t0: "cuda:0 bf16[1, 2048, 16, 3, 128]" # t1 = prims.transpose(t0, (0, 2, 3, 1, 4)) # t1: "cuda:0 bf16[1, 16, 3, 2048, 128]" # (t2, t3, t4) = ltorch.split(t1, (1, 1, 1), 2) # t2 = prims.slice_prim(t1, [0, 0, 0, 0, 0], [1, 16, 1, 2048, 128], [1, 1, 1, 1, 1]) # t2: "cuda:0 bf16[1, 16, 1, 2048, 128]" # t3 = prims.slice_prim(t1, [0, 0, 1, 0, 0], [1, 16, 2, 2048, 128], [1, 1, 1, 1, 1]) # t3: "cuda:0 bf16[1, 16, 1, 2048, 128]" # t4 = prims.slice_prim(t1, [0, 0, 2, 0, 0], [1, 16, 3, 2048, 128], [1, 1, 1, 1, 1]) # t4: "cuda:0 bf16[1, 16, 1, 2048, 128]" # t5 = prims.reshape(t2, (1, 16, 2048, 128)) # t5: "cuda:0 bf16[1, 16, 2048, 128]" # t11 = prims.reshape(t3, (1, 16, 2048, 128)) # t11: "cuda:0 bf16[1, 16, 2048, 128]" # t17 = prims.reshape(t4, (1, 16, 2048, 128)) # t17: "cuda:0 bf16[1, 16, 2048, 128]" # t18 = prims.slice_prim(t5, [0, 0, 0, 0], [1, 16, 2048, 32], [1, 1, 1, 1]) # t18: "cuda:0 bf16[1, 16, 2048, 32]" # t19 = prims.slice_prim(t18, [0, 0, 0, 0], [1, 16, 2048, 16], [1, 1, 1, 1]) # t19: "cuda:0 bf16[1, 16, 2048, 16]" # t20 = prims.slice_prim(t18, [0, 0, 0, 16], [1, 16, 2048, 32], [1, 1, 1, 1]) # t20: "cuda:0 bf16[1, 16, 2048, 16]" # t21 = prims.convert_element_type(t20, dtypes.float32) # t21: "cuda:0 f32[1, 16, 2048, 16]" # t22 = prims.neg(t21) # t22: "cuda:0 f32[1, 16, 2048, 16]" # t23 = prims.convert_element_type(t22, dtypes.bfloat16) # t23: "cuda:0 bf16[1, 16, 2048, 16]" # t24 = prims.cat((t23, t19), -1) # t24: "cuda:0 bf16[1, 16, 2048, 32]" # t25 = prims.broadcast_in_dim(cos, (1, 16, 2048, 32), (2, 3)) # t25: "cuda:0 bf16[1, 16, 2048, 32]" # t26 = prims.convert_element_type(t18, dtypes.float32) # t26: "cuda:0 f32[1, 16, 2048, 32]" # t27 = prims.convert_element_type(t25, dtypes.float32) # t27: "cuda:0 f32[1, 16, 2048, 32]" # t28 = ltorch.mul(t26, t27) # t28: "cuda:0 f32[1, 16, 2048, 32]" # t28 = prims.mul(t26, t27) # t28: "cuda:0 f32[1, 16, 2048, 32]" # t29 = prims.convert_element_type(t28, dtypes.bfloat16) # t29: "cuda:0 bf16[1, 16, 2048, 32]" # t30 = prims.broadcast_in_dim(sin, (1, 16, 2048, 32), (2, 3)) # t30: "cuda:0 bf16[1, 16, 2048, 32]" # t31 = prims.convert_element_type(t24, dtypes.float32) # t31: "cuda:0 f32[1, 16, 2048, 32]" # t32 = prims.convert_element_type(t30, dtypes.float32) # t32: "cuda:0 f32[1, 16, 2048, 32]" # t33 = ltorch.mul(t31, t32) # t33: "cuda:0 f32[1, 16, 2048, 32]" # t33 = prims.mul(t31, t32) # t33: "cuda:0 f32[1, 16, 2048, 32]" # t34 = prims.convert_element_type(t33, dtypes.bfloat16) # t34: "cuda:0 bf16[1, 16, 2048, 32]" # t37 = ltorch.add(t28, t33, alpha=None) # t37: "cuda:0 f32[1, 16, 2048, 32]" # t37 = prims.add(t28, t33) # t37: "cuda:0 f32[1, 16, 2048, 32]" # t38 = prims.convert_element_type(t37, dtypes.bfloat16) # t38: "cuda:0 bf16[1, 16, 2048, 32]" # t39 = prims.slice_prim(t11, [0, 0, 0, 0], [1, 16, 2048, 32], [1, 1, 1, 1]) # t39: "cuda:0 bf16[1, 16, 2048, 32]" # t40 = prims.slice_prim(t39, [0, 0, 0, 0], [1, 16, 2048, 16], [1, 1, 1, 1]) # t40: "cuda:0 bf16[1, 16, 2048, 16]" # t41 = prims.slice_prim(t39, [0, 0, 0, 16], [1, 16, 2048, 32], [1, 1, 1, 1]) # t41: "cuda:0 bf16[1, 16, 2048, 
16]" # t42 = prims.convert_element_type(t41, dtypes.float32) # t42: "cuda:0 f32[1, 16, 2048, 16]" # t43 = prims.neg(t42) # t43: "cuda:0 f32[1, 16, 2048, 16]" # t44 = prims.convert_element_type(t43, dtypes.bfloat16) # t44: "cuda:0 bf16[1, 16, 2048, 16]" # t46 = prims.cat((t44, t40), -1) # t46: "cuda:0 bf16[1, 16, 2048, 32]" # t48 = prims.convert_element_type(t39, dtypes.float32) # t48: "cuda:0 f32[1, 16, 2048, 32]" # t50 = ltorch.mul(t48, t27) # t50: "cuda:0 f32[1, 16, 2048, 32]" # t50 = prims.mul(t48, t27) # t50: "cuda:0 f32[1, 16, 2048, 32]" # t51 = prims.convert_element_type(t50, dtypes.bfloat16) # t51: "cuda:0 bf16[1, 16, 2048, 32]" # t53 = prims.convert_element_type(t46, dtypes.float32) # t53: "cuda:0 f32[1, 16, 2048, 32]" # t55 = ltorch.mul(t53, t32) # t55: "cuda:0 f32[1, 16, 2048, 32]" # t55 = prims.mul(t53, t32) # t55: "cuda:0 f32[1, 16, 2048, 32]" # t56 = prims.convert_element_type(t55, dtypes.bfloat16) # t56: "cuda:0 bf16[1, 16, 2048, 32]" # t59 = ltorch.add(t50, t55, alpha=None) # t59: "cuda:0 f32[1, 16, 2048, 32]" # t59 = prims.add(t50, t55) # t59: "cuda:0 f32[1, 16, 2048, 32]" # t60 = prims.convert_element_type(t59, dtypes.bfloat16) # t60: "cuda:0 bf16[1, 16, 2048, 32]" # t61 = prims.slice_prim(t5, [0, 0, 0, 32], [1, 16, 2048, 128], [1, 1, 1, 1]) # t61: "cuda:0 bf16[1, 16, 2048, 96]" # t62 = prims.cat((t38, t61), -1) # t62: "cuda:0 bf16[1, 16, 2048, 128]" # t63 = prims.slice_prim(t11, [0, 0, 0, 32], [1, 16, 2048, 128], [1, 1, 1, 1]) # t63: "cuda:0 bf16[1, 16, 2048, 96]" # t65 = prims.cat((t60, t63), -1) # t65: "cuda:0 bf16[1, 16, 2048, 128]" return {'output': (t62, t65, t17), 'flat_args': [qkv, cos, sin], 'flat_output': (t62, t65, t17)}, ((cos, sin), (2,)) # Constructed by Delete Last Used (took 0 milliseconds) import torch from thunder.executors.torchex import no_autocast @torch.no_grad() @no_autocast def backward_fn(saved_for_backward, cotangents): # saved_for_backward: "Collection" # cotangents: "Collection" C0, C1, = saved_for_backward clear_mutable_collection(saved_for_backward) del saved_for_backward t6, t7, t8, = cotangents clear_mutable_collection(cotangents) del cotangents cos, sin, = C0 clear_mutable_collection(C0) del C0 i15, = C1 clear_mutable_collection(C1) del C1 [t340] = TorchCompile0(cos, i15, sin, t6, t7, t8) # t25 = prims.broadcast_in_dim(cos, (1, 16, 2048, 32), (2, 3)) # t25: "cuda:0 bf16[1, 16, 2048, 32]" # t27 = prims.convert_element_type(t25, dtypes.float32) # t27: "cuda:0 f32[1, 16, 2048, 32]" # t30 = prims.broadcast_in_dim(sin, (1, 16, 2048, 32), (2, 3)) # t30: "cuda:0 bf16[1, 16, 2048, 32]" # t32 = prims.convert_element_type(t30, dtypes.float32) # t32: "cuda:0 f32[1, 16, 2048, 32]" # t214 = prims.slice_prim(t7, [0, 0, 0, 0], [1, 16, 2048, 32], [1, 1, 1, 1]) # t214: "cuda:0 bf16[1, 16, 2048, 32]" # t215 = prims.slice_prim(t7, [0, 0, 0, 32], [1, 16, 2048, 128], [1, 1, 1, 1]) # t215: "cuda:0 bf16[1, 16, 2048, 96]" # t216 = prims.pad(t215, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t216: "cuda:0 bf16[1, 16, 2048, 128]" # t218 = prims.slice_prim(t6, [0, 0, 0, 0], [1, 16, 2048, 32], [1, 1, 1, 1]) # t218: "cuda:0 bf16[1, 16, 2048, 32]" # t219 = prims.slice_prim(t6, [0, 0, 0, 32], [1, 16, 2048, 128], [1, 1, 1, 1]) # t219: "cuda:0 bf16[1, 16, 2048, 96]" # t220 = prims.pad(t219, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t220: "cuda:0 bf16[1, 16, 2048, 128]" # t221 = prims.convert_element_type(t214, dtypes.float32) # t221: "cuda:0 f32[1, 16, 2048, 32]" # t225 = ltorch.mul(t32, t221) # t225: "cuda:0 f32[1, 16, 2048, 32]" # t225 = 
prims.mul(t32, t221) # t225: "cuda:0 f32[1, 16, 2048, 32]" # t228 = prims.convert_element_type(t225, dtypes.bfloat16) # t228: "cuda:0 bf16[1, 16, 2048, 32]" # t233 = ltorch.mul(t27, t221) # t233: "cuda:0 f32[1, 16, 2048, 32]" # t233 = prims.mul(t27, t221) # t233: "cuda:0 f32[1, 16, 2048, 32]" # t236 = prims.convert_element_type(t233, dtypes.bfloat16) # t236: "cuda:0 bf16[1, 16, 2048, 32]" # t241 = prims.slice_prim(t228, [0, 0, 0, 0], [1, 16, 2048, 16], [1, 1, 1, 1]) # t241: "cuda:0 bf16[1, 16, 2048, 16]" # t242 = prims.slice_prim(t228, [0, 0, 0, 16], [1, 16, 2048, 32], [1, 1, 1, 1]) # t242: "cuda:0 bf16[1, 16, 2048, 16]" # t243 = prims.convert_element_type(t241, dtypes.float32) # t243: "cuda:0 f32[1, 16, 2048, 16]" # t244 = ltorch.neg(t243) # t244: "cuda:0 f32[1, 16, 2048, 16]" # t244 = prims.neg(t243) # t244: "cuda:0 f32[1, 16, 2048, 16]" # t245 = prims.convert_element_type(t244, dtypes.bfloat16) # t245: "cuda:0 bf16[1, 16, 2048, 16]" # t246 = prims.pad(t245, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (16, 0, 0))) # t246: "cuda:0 bf16[1, 16, 2048, 32]" # t248 = prims.convert_element_type(t246, dtypes.float32) # t248: "cuda:0 f32[1, 16, 2048, 32]" # t249 = prims.add(t233, t248) # t249: "cuda:0 f32[1, 16, 2048, 32]" # t250 = prims.convert_element_type(t249, dtypes.bfloat16) # t250: "cuda:0 bf16[1, 16, 2048, 32]" # t251 = prims.pad(t242, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 16, 0))) # t251: "cuda:0 bf16[1, 16, 2048, 32]" # t253 = prims.convert_element_type(t251, dtypes.float32) # t253: "cuda:0 f32[1, 16, 2048, 32]" # t254 = prims.add(t249, t253) # t254: "cuda:0 f32[1, 16, 2048, 32]" # t255 = prims.convert_element_type(t254, dtypes.bfloat16) # t255: "cuda:0 bf16[1, 16, 2048, 32]" # t256 = prims.pad(t255, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 96, 0))) # t256: "cuda:0 bf16[1, 16, 2048, 128]" # t257 = prims.convert_element_type(t216, dtypes.float32) # t257: "cuda:0 f32[1, 16, 2048, 128]" # t258 = prims.convert_element_type(t256, dtypes.float32) # t258: "cuda:0 f32[1, 16, 2048, 128]" # t259 = prims.add(t257, t258) # t259: "cuda:0 f32[1, 16, 2048, 128]" # t260 = prims.convert_element_type(t259, dtypes.bfloat16) # t260: "cuda:0 bf16[1, 16, 2048, 128]" # t261 = prims.convert_element_type(t218, dtypes.float32) # t261: "cuda:0 f32[1, 16, 2048, 32]" # t265 = ltorch.mul(t32, t261) # t265: "cuda:0 f32[1, 16, 2048, 32]" # t265 = prims.mul(t32, t261) # t265: "cuda:0 f32[1, 16, 2048, 32]" # t268 = prims.convert_element_type(t265, dtypes.bfloat16) # t268: "cuda:0 bf16[1, 16, 2048, 32]" # t277 = ltorch.mul(t27, t261) # t277: "cuda:0 f32[1, 16, 2048, 32]" # t277 = prims.mul(t27, t261) # t277: "cuda:0 f32[1, 16, 2048, 32]" # t280 = prims.convert_element_type(t277, dtypes.bfloat16) # t280: "cuda:0 bf16[1, 16, 2048, 32]" # t289 = prims.slice_prim(t268, [0, 0, 0, 0], [1, 16, 2048, 16], [1, 1, 1, 1]) # t289: "cuda:0 bf16[1, 16, 2048, 16]" # t290 = prims.slice_prim(t268, [0, 0, 0, 16], [1, 16, 2048, 32], [1, 1, 1, 1]) # t290: "cuda:0 bf16[1, 16, 2048, 16]" # t291 = prims.convert_element_type(t289, dtypes.float32) # t291: "cuda:0 f32[1, 16, 2048, 16]" # t292 = ltorch.neg(t291) # t292: "cuda:0 f32[1, 16, 2048, 16]" # t292 = prims.neg(t291) # t292: "cuda:0 f32[1, 16, 2048, 16]" # t293 = prims.convert_element_type(t292, dtypes.bfloat16) # t293: "cuda:0 bf16[1, 16, 2048, 16]" # t294 = prims.pad(t293, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (16, 0, 0))) # t294: "cuda:0 bf16[1, 16, 2048, 32]" # t296 = prims.convert_element_type(t294, dtypes.float32) # t296: "cuda:0 f32[1, 16, 2048, 32]" # t297 = prims.add(t277, t296) 
# t297: "cuda:0 f32[1, 16, 2048, 32]" # t298 = prims.convert_element_type(t297, dtypes.bfloat16) # t298: "cuda:0 bf16[1, 16, 2048, 32]" # t299 = prims.pad(t290, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 16, 0))) # t299: "cuda:0 bf16[1, 16, 2048, 32]" # t301 = prims.convert_element_type(t299, dtypes.float32) # t301: "cuda:0 f32[1, 16, 2048, 32]" # t302 = prims.add(t297, t301) # t302: "cuda:0 f32[1, 16, 2048, 32]" # t303 = prims.convert_element_type(t302, dtypes.bfloat16) # t303: "cuda:0 bf16[1, 16, 2048, 32]" # t304 = prims.pad(t303, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 96, 0))) # t304: "cuda:0 bf16[1, 16, 2048, 128]" # t305 = prims.convert_element_type(t220, dtypes.float32) # t305: "cuda:0 f32[1, 16, 2048, 128]" # t306 = prims.convert_element_type(t304, dtypes.float32) # t306: "cuda:0 f32[1, 16, 2048, 128]" # t307 = prims.add(t305, t306) # t307: "cuda:0 f32[1, 16, 2048, 128]" # t308 = prims.convert_element_type(t307, dtypes.bfloat16) # t308: "cuda:0 bf16[1, 16, 2048, 128]" # t313 = prims.reshape(t8, (1, 16, 1, 2048, 128)) # t313: "cuda:0 bf16[1, 16, 1, 2048, 128]" # t318 = prims.reshape(t260, (1, 16, 1, 2048, 128)) # t318: "cuda:0 bf16[1, 16, 1, 2048, 128]" # t323 = prims.reshape(t308, (1, 16, 1, 2048, 128)) # t323: "cuda:0 bf16[1, 16, 1, 2048, 128]" # t328 = ltorch.cat((t323, t318, t313), i15) # t328: "cuda:0 bf16[1, 16, 3, 2048, 128]" # t328 = prims.cat((t323, t318, t313), i15) # t328: "cuda:0 bf16[1, 16, 3, 2048, 128]" # t334 = prims.transpose(t328, (0, 3, 1, 2, 4)) # t334: "cuda:0 bf16[1, 2048, 16, 3, 128]" # t340 = prims.reshape(t334, (1, 2048, 6144)) # t340: "cuda:0 bf16[1, 2048, 6144]" del cos, i15, sin, t6, t7, t8 return (t340, None, None) ```

litgpt_screen.log

(the trace is on thunder 69e80f0a094376576a39306f62b9c510138e41fa; the perf log is a few days old, on thunder d1d581c401fb201d2f181c66bdc4281cf616c935)

IvanYashchuk commented 5 months ago

I'm not very clear about the background of this issue, should this benchmark have better perf using thunder+nvfuser+torch.compile at least on Pythia?

Thunder (and thunder+nvfuser+torch.compile) should have better performance in all cases. The purpose of these benchmarks is to evaluate the current situation and identify what needs to be done to improve performance for the worst-performing cases. Besides the logs, it would be useful to have a script that analyzes the JSON results from pytest-benchmark and creates a summary:

Include a summary of any other important information that is in the JSON files. A table could be useful, something like:

| Metric | Batch Size 1 | Batch Size 2 |
| --- | --- | --- |
| Top Executor | Executor A | Executor B |
| Percentage of Configs Best for Executor | | |
| - Executor A | 60% | 50% |
| - Executor B | 30% | 40% |
| - Executor C | 10% | 10% |
| Gap Between Top Executor and Thunder | | |
| - Max Gap | 15 ms | 20 ms |
| - Min Gap | 1 ms | 2 ms |
| - Mean Gap | 8 ms | 10 ms |
| - Median Gap | 7 ms | 9 ms |
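
A minimal sketch of such a script (assuming pytest-benchmark's default JSON layout from `--benchmark-json=results.json`; the `params` keys follow the `--benchmark-group-by` names above, and the `executor` key name is a guess):

```python
import json
from collections import defaultdict

with open("results.json") as f:
    data = json.load(f)

# (config, bs, compute_type) -> {executor: mean time in seconds}
groups = defaultdict(dict)
for bench in data["benchmarks"]:
    p = bench.get("params") or {}
    key = (p.get("config"), p.get("bs"), p.get("compute_type"))
    groups[key][p.get("executor", bench["name"])] = bench["stats"]["mean"]

wins = defaultdict(int)
gaps = []
for results in groups.values():
    best = min(results, key=results.get)
    wins[best] += 1
    if "thunder" in results:
        gaps.append(results["thunder"] - results[best])

total = len(groups)
for executor, count in sorted(wins.items(), key=lambda kv: -kv[1]):
    print(f"{executor}: best in {100 * count / total:.0f}% of configs")
if gaps:
    print(f"gap vs thunder: max={max(gaps):.2e}s mean={sum(gaps) / len(gaps):.2e}s "
          f"median={sorted(gaps)[len(gaps) // 2]:.2e}s")
```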

If the batch size is increased, how do the results change? Maybe the overheads are too large for Thunder? How does the patch from https://github.com/Lightning-AI/lightning-thunder/issues/256#issuecomment-2136096575 affect the numbers for the thunder+nvfuser+torch.compile executor?

Maybe it's possible to get pure CUDA kernel times with a timer from nvFuser https://github.com/NVIDIA/Fuser/blob/18750278f9f20a817808dc1c63c0fb6962d37c9c/benchmarks/python/core.py#L209-L229.
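
A simpler alternative sketch (not the nvFuser timer linked above): bracket the benchmarked callable with CUDA events, which approximates GPU time when the stream stays busy, though it still includes kernel-launch gaps:

```python
import torch

def cuda_time_ms(fn, *args, warmup=10, iters=100):
    for _ in range(warmup):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per iteration
```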

kiya00 commented 5 months ago

issue256_analysis.xlsx

Here is some initial information for reference (based on the 0614 container). Looking at one test case, test_litgpt_qkv_split_rope[pythia-1.4b-backward-bs1-thunder+nvfuser+torch.compile], the trace has one single TorchCompile0 region, but the mean time of thunder+nvfuser+torch.compile (148.1826) is much worse than torch.compile (108.1542).

kiya00 commented 5 months ago

This is quite interesting and I am thinking I have to re-analyze other networks like Phi and Dolly and confirm whether the TorchCompile cat executor is still a perf issue there.

Hi @parthmannan, checking the performance today for the original problem with these commands (on H100, container 0626):

torchrun --nproc_per_node=8 thunder/benchmarks/benchmark_litgpt.py --compile=thunder_inductor_cat --distributed_mode=fsdp --nsys_enabled=False --micro_batch_size=1 --global_batch_size=8 --model_name=pythia-6.9b
torchrun --nproc_per_node=8 thunder/benchmarks/benchmark_litgpt.py --compile=thunder --distributed_mode=fsdp --nsys_enabled=False --micro_batch_size=1 --global_batch_size=8 --model_name=pythia-6.9b
python thunder/benchmarks/benchmark_litgpt.py --compile=thunder_inductor_cat --micro_batch_size=1 --model_name=pythia-6.9b
python thunder/benchmarks/benchmark_litgpt.py --compile=thunder --micro_batch_size=1 --model_name=pythia-6.9b

| pythia-6.9b | zero2 | single GPU |
| --- | --- | --- |
| thunder_inductor_cat | 234.84 ms | 208.45 ms |
| thunder | 233.02 ms | 210.97 ms |

| phi-2 | zero2 | single GPU |
| --- | --- | --- |
| thunder_inductor_cat | 114.54 ms | 104.35 ms |
| thunder | 113.44 ms | 105.14 ms |

| dolly-v2-7b | zero2 | single GPU |
| --- | --- | --- |
| thunder_inductor_cat | 233.81 ms | 208.45 ms |
| thunder | 233.26 ms | 211.03 ms |

The performance of the thunder_inductor_cat executor seems decent now. I'll dig into why the RoPE microbenchmark is not what we expected, but for the original problem in this issue, do we want more analysis on it?

cc: @IvanYashchuk

IvanYashchuk commented 5 months ago

It's great that "thunder_inductor_cat" is now better for Phi-2 and Dolly, thank you for rerunning the benchmarks! I'm inclined to close this particular issue and open a new one specifically for RoPE microbenchmark performance. We need to better understand how impactful improvements to that microbenchmark are for full-network runs.

parthmannan commented 5 months ago

Given this data, we can definitely close this issue. Thanks so much @kiya00 for re-running this and for the analysis. Just curious, do we know what changed in the partitioning logic such that the performance issues are gone now?

IvanYashchuk commented 4 months ago

No, I don't know what was changed there.