Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Thunder + Inductor gives OOM for stablecode-completion-alpha-3b model from LitGPT #246

Open mpatel31415 opened 4 months ago

mpatel31415 commented 4 months ago

🐛 Bug

With the newest version of the Docker image (tested on 2024-04-22), training with thunder.jit plus the additional Inductor executor gives an OOM error.

To Reproduce

Before testing each compilation method I restarted the container:

mkdir -p output
docker run --pull=always --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864  -v $PWD/output:/output -it INTERNAL_ADDRESS:5005/dl/pytorch/update-scripts:pjnl-20240422

Thunder inductor

Eager

Thunder

Inductor

Expected behavior

If we can run the model in eager mode and with Thunder + the default PyTorch executor, we should be able to run it with Thunder + Inductor as well.

Environment

As in the Docker image. These results come from a single H100.

nvidia-smi output: image

mruberry commented 4 months ago

triage review:

mpatel31415 commented 4 months ago

When running:

python /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py \
--model_name stablecode-completion-alpha-3b \
--compile  "thunder_inductor"

I still get an OOM error for tag pjnl-20240427.

Memory usage:

IvanYashchuk commented 4 months ago

The new command to reproduce the problem is

python /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py --model_name stablecode-completion-alpha-3b --compile  "thunder_inductor_cat"

because "inductor" was renamed to "inductor_cat" in https://github.com/Lightning-AI/lightning-thunder/pull/140.

IvanYashchuk commented 4 months ago

Thunder uses a lot more memory than both torch.compile and eager. That's something that should be fixed.

The forward trace has the following weirdness:

  [t137, t142, t146, t158, t228, t232, t244] = nvFusion2(t129, t152, t155, t238, t241, t4, t90)
    # t130 = prims.convert_element_type(t129, dtypes.float32)  # t130: "cuda:0 f32[1, 16384, 2560]"
    # t131 = prims.convert_element_type(t90, dtypes.float32)  # t131: "cuda:0 f32[1, 16384, 2560]"
    # t132 = prims.add(t130, t131)  # t132: "cuda:0 f32[1, 16384, 2560]"
    # t135 = prims.convert_element_type(t4, dtypes.float32)  # t135: "cuda:0 f32[1, 16384, 2560]"
    # t136 = prims.add(t132, t135)  # t136: "cuda:0 f32[1, 16384, 2560]"
    # t137 = prims.convert_element_type(t136, dtypes.bfloat16)  # t137: "cuda:0 bf16[1, 16384, 2560]"
    # (t141, t142) = prims.var_mean(t136, (2,), correction=0) ## <-- What's going on here?
    # (t227, t228) = prims.var_mean(t136, (2,), correction=0) ## <-- What's going on here?
    ...

There are two var_mean calls for the same input, with both t142 and t228 saved as outputs, and this pattern is repeated in every nvFusion region. It's the same both with and without the "inductor_cat" executor.

@riccardofelluga, could you please investigate what's wrong with Thunder on this model and fix the problem?

I suggest starting with PyTorch memory profiler to spot the differences in memory allocations between PyTorch eager and Thunder.
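For reference, a minimal sketch of capturing such a profile with PyTorch's CUDA memory history API (these are underscore-prefixed, semi-private functions available in recent PyTorch releases; `model` and `batch` are assumed to already exist, and the snapshot can be inspected at https://pytorch.org/memory_viz):

```python
import torch

# Record every CUDA allocation/free with stack traces.
torch.cuda.memory._record_memory_history(max_entries=100_000)

# Run the workload whose memory behavior we want to compare
# (eager vs. the thunder.jit-compiled model).
loss = model(batch).sum()
loss.backward()

# Dump a snapshot for the memory_viz tool, then stop recording.
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)
```

Comparing the eager and Thunder snapshots side by side should make the extra long-lived allocations stand out.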

The second var_mean in the nvFuser regions should have been eliminated by the CSE pass; it probably doesn't work because comparison and hashing of BoundSymbols with keyword arguments is not implemented, see https://github.com/Lightning-AI/lightning-thunder/issues/397. Where does the second var_mean come from at all?
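To illustrate the suspicion about keyword arguments, here is a minimal, hypothetical sketch of hash-based CSE (the `Call` class is a stand-in, not Thunder's BoundSymbol API): the duplicate var_mean can only collapse if equality and hashing also cover the keyword arguments (here, `correction=0`).

```python
from dataclasses import dataclass

# Hypothetical stand-in for a BoundSymbol: symbol name, positional args, kwargs.
@dataclass(frozen=True)  # frozen=True generates __eq__ and __hash__ over all fields
class Call:
    sym: str
    args: tuple
    kwargs: tuple = ()  # e.g. (("correction", 0),)

def cse(calls: list[Call]) -> list[Call]:
    # Keep only the first occurrence of each syntactically identical call.
    seen: set[Call] = set()
    kept = []
    for call in calls:
        if call not in seen:  # relies on __eq__/__hash__ including kwargs
            seen.add(call)
            kept.append(call)
    return kept

a = Call("var_mean", ("t136", (2,)), (("correction", 0),))
b = Call("var_mean", ("t136", (2,)), (("correction", 0),))
assert cse([a, b]) == [a]  # the duplicate collapses only when kwargs participate
```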

riccardofelluga commented 4 months ago

Quick update: we probably have multiple memory issues, but for this issue I think the focus should be on peak memory usage. For comparison, here is the memory profile for eager:

mem_prof_annotated_eager

and here the one for thunder.jit:

mem_prof_annotated_jit

They are both pretty similar, but there are a couple of notable differences. In the Thunder trace the data from log_softmax_backward is kept around until the end, whereas eager deallocates it almost as soon as possible. This might be an issue, but it does not contribute much to the allocation peak.

The second notable difference is the presence of the nvFuser outputs, which for obvious reasons are not present in eager.

One thing I noticed by looking at the traces is that when an nvFuser region sits next to a torch.compile one, the intermediates stick around for a long time, and they are all pretty big. This is accentuated when intermediates are passed from the forward to the backward pass, like t123 and t128, both of which are consumed by torch.compile in the backward pass.

Will update with further details.

riccardofelluga commented 3 months ago

This issue seems to be another facet of #256 and #446. After further investigation, it seems that this extra memory usage also comes from splitting torch.nn.functional.gelu between the TorchCompile and nvFuser executors. In particular, as of now, prims.erf and prims.div are fused by nvFuser, and the intermediate tensors passed from this fusion to the next TorchCompile fusion are pretty big, on the order of 600 MB for the 1-layer example I showed in the previous comment.
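For context on why a div/erf pair shows up at the executor boundary: the exact (erf-based) GELU decomposes into exactly these primitives. A small self-contained check (plain PyTorch, not Thunder code):

```python
import math
import torch

def gelu_exact(x: torch.Tensor) -> torch.Tensor:
    # Exact GELU: x * 0.5 * (1 + erf(x / sqrt(2)));
    # the division and erf are the prims.div / prims.erf seen in the trace.
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

x = torch.randn(8)
torch.testing.assert_close(gelu_exact(x), torch.nn.functional.gelu(x))
```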

A solution to lower the memory usage is to let TorchCompile fuse those operators by adding them in https://github.com/Lightning-AI/lightning-thunder/blob/82185e3a55d5b3f0bea8a7366d74a275dbe34acd/thunder/executors/torch_compile.py#L209-L222

With this change it is possible to lower the memory usage of the example from 9.548 GB to 9.129 GB. However, such a change is not enough to avoid the OOM in the reported issue.

@IvanYashchuk how do you feel about the compromise of adding those operators to the TorchCompile executor? It is true that it does not solve this issue, but there are some memory savings, even though at the expense of a little performance (and we would need to do more perf testing with the benchmarks to check that there are no other penalties).

Regarding the nvFuser issue highlighted in a previous comment, I am currently looking into it; it's tracked by #397. I am checking whether that might be the issue here and will update this comment as soon as I have more information.

Just to be clear, I am not ruling that out, but I think the intermediate tensors I found here are also partially to blame.

riccardofelluga commented 3 months ago

Update on the CSE issue: unfortunately it didn't help with memory usage :(

IvanYashchuk commented 3 months ago

@IvanYashchuk how do you feel about the compromise of adding those operators to the TorchCompile?

Another alternative is to do the opposite: restrict the number of supported operators in this executor, forcing TorchCompile to be used only for the forward pass and nvFuser for the backward pass (https://github.com/Lightning-AI/lightning-thunder/issues/256#issuecomment-2136096575). Both of these choices should be benchmarked.

riccardofelluga commented 2 months ago

Update on this issue: as of today, Thunder runs stablecode-completion-alpha-3b successfully with the compile option thunder_inductor. However, with thunder_inductor_cat it OOMs.

Stats from python thunder/benchmarks/benchmark_litgpt.py --model_name stablecode-completion-alpha-3b --compile thunder_inductor

Model name: stablecode-completion-alpha-3b
Seq Length: 16384
Micro BS: 1
Global BS: 1
Number of Layers: 32
Number of parameters: 2.77B
Distributed Mode: none
Compiler: thunder_inductor
Average iter time: 1293.71 ms
Memory used: 71.27 GB
Tokens/s: 12663.37
Tokens/s/GPU: 12663.37
TFLOP/s: 404.73

The average iteration time of 1293.71 ms was measured on an H100 80GB HBM3.
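For reference, the reported throughput is consistent with the sequence length and iteration time (a quick sanity check, assuming tokens/s = seq_len * global_batch_size / iter_time):

```python
seq_len = 16384
global_batch_size = 1
iter_time_s = 1293.71 / 1000  # average iteration time in seconds

tokens_per_s = seq_len * global_batch_size / iter_time_s
print(f"{tokens_per_s:.2f}")  # ~12664, close to the reported 12663.37 tokens/s
```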

IvanYashchuk commented 3 weeks ago

One config flag that differs from the Llama models and affects memory usage here is parallel_residual (it's False for Llama and True for StableCode). When it's True and the "inductor_cat" executor is used, one more tensor is saved for backward per transformer block. The size of this tensor is (batch_size, block_size, 4 * n_embd) and its dtype is fp32, so it's about 670 MB of extra memory per transformer block with batch_size=1. When the "inductor_cat" executor is not used, this intermediate is not saved to global memory but is recomputed due to the rematerialization pass.
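For reference, the ~670 MB figure follows directly from the tensor shape and dtype (using the benchmark's batch_size=1 and block_size=16384, and n_embd=2560 as seen in the trace shapes above):

```python
batch_size, block_size, n_embd = 1, 16384, 2560  # stablecode-completion-alpha-3b
fp32_bytes = 4

extra = batch_size * block_size * 4 * n_embd * fp32_bytes
print(f"{extra / 1e6:.0f} MB per transformer block")  # ~671 MB
```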

One fix resolves the problem by constraining the TorchCompile region to fuse only when two or more cat operations are present in the candidate fusion group:

diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py
index 204bc204..db268b7d 100644
--- a/thunder/executors/torch_compile.py
+++ b/thunder/executors/torch_compile.py
@@ -166,7 +166,7 @@ class TorchCompileExecutor(FusionExecutor):
                     continue

             # TODO: this could use `get_fuel()` like nvfuserex does
-            if self.required_ops is None or any(bsym.sym.id in self.required_ops for bsym in bsyms):
+            if self.required_ops is None or len(tuple(bsym for bsym in bsyms if bsym.sym.id in self.required_ops)) > 1:
                 region = Region(producers, consumers, bsyms)
                 fusion_bsym: BoundSymbol = self.fuse(region, fusion_counter)
                 fusion_counter += 1

There's a cat op that appears in the backward of split, and with the patch above nvFuser handles it.
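To see where that cat comes from, here is a small standalone illustration (plain PyTorch): the gradient of a split is the concatenation of the gradients of its outputs, so a cat op appears in the backward trace.

```python
import torch

x = torch.randn(2, 6, requires_grad=True)
a, b, c = torch.split(x, 2, dim=1)          # forward: one split into three chunks
(a.sum() + 2 * b.sum() + 3 * c.sum()).backward()
# The gradient w.r.t. x is the concatenation of the three output gradients
# along dim=1, which is where the cat in the backward pass comes from.
print(x.grad)
```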

Another fix is:

diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py
index 204bc204..8f41ee16 100644
--- a/thunder/executors/torch_compile.py
+++ b/thunder/executors/torch_compile.py
@@ -221,6 +221,8 @@ supported_ops = {
     prims.add.id,
     prims.broadcast_in_dim.id,
     prims.cat.id,
+    prims.div.id,
+    prims.erf.id,
     prims.convert_element_type.id,
     prims.full.id,
     prims.mul.id,

With this change, gelu is fused into the TorchCompile region.

The first solution results in slightly faster code than the second one on H100, and I checked that it has no perf impact on Llama 3 8B (which uses parallel_residual=False).

It's still not enough for a single-GPU run of this model, but at least the OOM now occurs in the loss function, and sharding may help:

  File "/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py", line 3479, in cross_entropy
    return torch._C._nn.cross_entropy_loss(
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.50 GiB. GPU 0 has a total capacity of 79.10 GiB of which 132.00 MiB is free. Including non-PyTorch memory, this process has 78.96 GiB memory in use. Of the allocated memory 76.84 GiB is allocated by PyTorch, and 1.44 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

After the change, Thunder still saves more tensors for backward compared to Inductor because Inductor doesn't save the result of gelu (input to linear layers) for backward; it's recomputed instead.

Setting nv_enable_bookend=False also helps with memory usage.
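A minimal usage sketch, assuming the option can be passed as a compile-option keyword argument to thunder.jit (hypothetical `model`; check the nvFuser executor docs for the exact way to set it):

```python
import thunder

# Disable nvFuser's bookending pass via a compile option.
jitted_model = thunder.jit(model, nv_enable_bookend=False)
```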

IvanYashchuk commented 3 weeks ago

In a microbenchmark, the first solution (restricting TorchCompile to regions with more than one cat op) results in 1.5x-2x regressions in the backward pass.

Click for results

Mean times (in μs) from the `test_litgpt_qkv_split_rope` microbenchmark (bs=1), runs 0002_a59b4ef and 0003_a59b4ef:

| Config | Pass | Mean 0002_a59b4ef (μs) | Mean 0003_a59b4ef (μs) | Ratio (0003/0002) |
| --- | --- | ---: | ---: | ---: |
| Llama-2-7b-hf | backward | 228.00 | 442.56 | 1.94 |
| Llama-2-7b-hf | forward | 141.34 | 140.29 | 0.99 |
| Llama-2-13b-hf | backward | 294.09 | 615.26 | 2.09 |
| Llama-2-13b-hf | forward | 140.64 | 140.94 | 1.00 |
| Llama-2-70b-hf | backward | 469.35 | 822.71 | 1.75 |
| Llama-2-70b-hf | forward | 145.00 | 144.14 | 0.99 |
| Llama-3-8B | backward | 449.12 | 712.59 | 1.59 |
| Llama-3-8B | forward | 144.78 | 147.08 | 1.02 |
| Llama-3-70B | backward | 944.84 | 1455.96 | 1.54 |
| Llama-3-70B | forward | 303.76 | 304.26 | 1.00 |
| Mistral-7B-v0.1 | backward | 249.58 | 368.83 | 1.48 |
| Mistral-7B-v0.1 | forward | 140.49 | 142.43 | 1.01 |
| phi-2 | backward | 155.58 | 263.50 | 1.69 |
| phi-2 | forward | 138.92 | 140.73 | 1.01 |

It's better to allow GELU to be fused into a TorchCompile region together with RoPE so as not to regress on the Llama models.

IvanYashchuk commented 3 weeks ago

https://github.com/Lightning-AI/lightning-thunder/pull/1019 makes it possible to run this model with FSDP. More memory savings are needed to run it on a single GPU.