Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Thunder's torch.compile executor may perform differently from torch.compile #711

Open kiya00 opened 2 months ago

kiya00 commented 2 months ago

Note: this issue is being opened more to document that this difference exists than to request a fix.

When analyzing the microbenchmark performance of RoPE, Thunder's torch.compile executor (i.e., a trace with a single TorchCompile0 fusion) sometimes performs worse than plain torch.compile.

Take tiny-llama-1.1b as an example:

pytest thunder/benchmarks/targets.py -k "tiny-llama-1.1b-forward-bs2-thunder+nvfuser+torch.compile] or tiny-llama-1.1b-forward-bs2-torch.compile]"

------------------------------------------------------------------------------------------------------------------------ benchmark: 2 tests -----------------------------------------------------------------------------------------------------------------------
Name (time in us)                                                                              Min                 Max                Mean             StdDev              Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_litgpt_qkv_split_rope[tiny-llama-1.1b-forward-bs2-torch.compile]                      43.4808 (1.0)       70.7150 (1.0)       45.8740 (1.0)       3.9031 (1.0)       44.7072 (1.0)      1.1832 (1.0)       144;203       21.7988 (1.0)        2300          10
test_litgpt_qkv_split_rope[tiny-llama-1.1b-forward-bs2-thunder+nvfuser+torch.compile]     222.3384 (5.11)     345.1165 (4.88)     229.5844 (5.00)     18.0420 (4.62)     224.9852 (5.03)     1.9018 (1.61)      126;210        4.3557 (0.20)       2253           2
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
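
For context, here is a minimal standalone sketch of the two compilation paths being compared. It is not the exact benchmark in thunder/benchmarks/targets.py: the RoPE stand-in below is simplified from litgpt (shapes follow the trace below), the executor name is an assumption, and it requires a CUDA device with bf16 support.

```python
import torch
import thunder

def qkv_split_rope(qkv, cos, sin):
    # Simplified litgpt-style fused-qkv split + RoPE; shapes match the trace below.
    qkv = qkv.view(2, 2048, 4, 10, 64).permute(0, 2, 3, 1, 4)  # (B, kv groups, q+k+v heads, T, head size)
    q, k, v = qkv.split((8, 1, 1), dim=2)
    q = q.reshape(2, 32, 2048, 64)
    k = k.expand(2, 4, 8, 2048, 64).reshape(2, 32, 2048, 64)
    v = v.expand(2, 4, 8, 2048, 64).reshape(2, 32, 2048, 64)

    def rope(x):
        # Rotate-half formulation: negate the second half, swap, then combine.
        x1, x2 = x[..., :32], x[..., 32:]
        rotated = torch.cat((-x2, x1), dim=-1)
        return x * cos + rotated * sin

    return rope(q), rope(k), v

qkv = torch.randn(2, 2048, 2560, device="cuda", dtype=torch.bfloat16)
cos = torch.randn(2048, 64, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(2048, 64, device="cuda", dtype=torch.bfloat16)

# Path 1: plain torch.compile sees the original Python and lets Inductor
# decide how to decompose and fuse it.
tc_fn = torch.compile(qkv_split_rope)

# Path 2: Thunder first traces the function into its own prims and then hands
# the already-decomposed trace to its torch.compile executor. The executor
# name "torchcompile" is an assumption; check the registered executor names
# in your Thunder version.
th_fn = thunder.jit(qkv_split_rope, executors=["torchcompile"])

out_tc = tc_fn(qkv, cos, sin)
out_th = th_fn(qkv, cos, sin)
```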

The Thunder trace:

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(qkv, cos, sin):
  # qkv: "cuda:0 bf16[2, 2048, 2560]"
  # cos: "cuda:0 bf16[2048, 64]"
  # sin: "cuda:0 bf16[2048, 64]"
  [t24, t34, t39, t69, t72] = TorchCompile0(cos, qkv, sin)
    # t0 = prims.reshape(qkv, (2, 2048, 4, 10, 64))  # t0: "cuda:0 bf16[2, 2048, 4, 10, 64]"
    # t1 = prims.transpose(t0, (0, 2, 3, 1, 4))  # t1: "cuda:0 bf16[2, 4, 10, 2048, 64]"
    # (t2, t3, t4) = ltorch.split(t1, (8, 1, 1), 2)
      # t2 = prims.slice_prim(t1, [0, 0, 0, 0, 0], [2, 4, 8, 2048, 64], [1, 1, 1, 1, 1])  # t2: "cuda:0 bf16[2, 4, 8, 2048, 64]"
      # t3 = prims.slice_prim(t1, [0, 0, 8, 0, 0], [2, 4, 9, 2048, 64], [1, 1, 1, 1, 1])  # t3: "cuda:0 bf16[2, 4, 1, 2048, 64]"
      # t4 = prims.slice_prim(t1, [0, 0, 9, 0, 0], [2, 4, 10, 2048, 64], [1, 1, 1, 1, 1])  # t4: "cuda:0 bf16[2, 4, 1, 2048, 64]"
    # t5 = prims.broadcast_in_dim(t3, (2, 4, 8, 2048, 64), (0, 1, 2, 3, 4))  # t5: "cuda:0 bf16[2, 4, 8, 2048, 64]"
    # t11 = prims.broadcast_in_dim(t4, (2, 4, 8, 2048, 64), (0, 1, 2, 3, 4))  # t11: "cuda:0 bf16[2, 4, 8, 2048, 64]"
    # t12 = prims.reshape(t2, (2, 32, 2048, 64))  # t12: "cuda:0 bf16[2, 32, 2048, 64]"
    # t18 = prims.reshape(t5, (2, 32, 2048, 64))  # t18: "cuda:0 bf16[2, 32, 2048, 64]"
    # t24 = prims.reshape(t11, (2, 32, 2048, 64))  # t24: "cuda:0 bf16[2, 32, 2048, 64]"
    # t25 = prims.slice_prim(t12, [0, 0, 0, 0], [2, 32, 2048, 64], [1, 1, 1, 1])  # t25: "cuda:0 bf16[2, 32, 2048, 64]"
    # t26 = prims.slice_prim(t25, [0, 0, 0, 0], [2, 32, 2048, 32], [1, 1, 1, 1])  # t26: "cuda:0 bf16[2, 32, 2048, 32]"
    # t27 = prims.slice_prim(t25, [0, 0, 0, 32], [2, 32, 2048, 64], [1, 1, 1, 1])  # t27: "cuda:0 bf16[2, 32, 2048, 32]"
    # t28 = prims.convert_element_type(t27, dtypes.thunder.dtypes.float32)  # t28: "cuda:0 f32[2, 32, 2048, 32]"
    # t29 = prims.neg(t28)  # t29: "cuda:0 f32[2, 32, 2048, 32]"
    # t30 = prims.convert_element_type(t29, dtypes.thunder.dtypes.bfloat16)  # t30: "cuda:0 bf16[2, 32, 2048, 32]"
    # t31 = prims.cat((t30, t26), -1)  # t31: "cuda:0 bf16[2, 32, 2048, 64]"
    # t32 = prims.broadcast_in_dim(cos, (2, 32, 2048, 64), (2, 3))  # t32: "cuda:0 bf16[2, 32, 2048, 64]"
    # t33 = prims.convert_element_type(t25, dtypes.thunder.dtypes.float32)  # t33: "cuda:0 f32[2, 32, 2048, 64]"
    # t34 = prims.convert_element_type(t32, dtypes.thunder.dtypes.float32)  # t34: "cuda:0 f32[2, 32, 2048, 64]"
    # t35 = ltorch.mul(t33, t34)  # t35: "cuda:0 f32[2, 32, 2048, 64]"
      # t35 = prims.mul(t33, t34)  # t35: "cuda:0 f32[2, 32, 2048, 64]"
    # t36 = prims.convert_element_type(t35, dtypes.thunder.dtypes.bfloat16)  # t36: "cuda:0 bf16[2, 32, 2048, 64]"
    # t37 = prims.broadcast_in_dim(sin, (2, 32, 2048, 64), (2, 3))  # t37: "cuda:0 bf16[2, 32, 2048, 64]"
    # t38 = prims.convert_element_type(t31, dtypes.thunder.dtypes.float32)  # t38: "cuda:0 f32[2, 32, 2048, 64]"
    # t39 = prims.convert_element_type(t37, dtypes.thunder.dtypes.float32)  # t39: "cuda:0 f32[2, 32, 2048, 64]"
    # t40 = ltorch.mul(t38, t39)  # t40: "cuda:0 f32[2, 32, 2048, 64]"
      # t40 = prims.mul(t38, t39)  # t40: "cuda:0 f32[2, 32, 2048, 64]"
    # t41 = prims.convert_element_type(t40, dtypes.thunder.dtypes.bfloat16)  # t41: "cuda:0 bf16[2, 32, 2048, 64]"
    # t44 = ltorch.add(t35, t40, alpha=None)  # t44: "cuda:0 f32[2, 32, 2048, 64]"
      # t44 = prims.add(t35, t40)  # t44: "cuda:0 f32[2, 32, 2048, 64]"
    # t45 = prims.convert_element_type(t44, dtypes.thunder.dtypes.bfloat16)  # t45: "cuda:0 bf16[2, 32, 2048, 64]"
    # t46 = prims.slice_prim(t18, [0, 0, 0, 0], [2, 32, 2048, 64], [1, 1, 1, 1])  # t46: "cuda:0 bf16[2, 32, 2048, 64]"
    # t47 = prims.slice_prim(t46, [0, 0, 0, 0], [2, 32, 2048, 32], [1, 1, 1, 1])  # t47: "cuda:0 bf16[2, 32, 2048, 32]"
    # t48 = prims.slice_prim(t46, [0, 0, 0, 32], [2, 32, 2048, 64], [1, 1, 1, 1])  # t48: "cuda:0 bf16[2, 32, 2048, 32]"
    # t49 = prims.convert_element_type(t48, dtypes.thunder.dtypes.float32)  # t49: "cuda:0 f32[2, 32, 2048, 32]"
    # t50 = prims.neg(t49)  # t50: "cuda:0 f32[2, 32, 2048, 32]"
    # t51 = prims.convert_element_type(t50, dtypes.thunder.dtypes.bfloat16)  # t51: "cuda:0 bf16[2, 32, 2048, 32]"
    # t53 = prims.cat((t51, t47), -1)  # t53: "cuda:0 bf16[2, 32, 2048, 64]"
    # t55 = prims.convert_element_type(t46, dtypes.thunder.dtypes.float32)  # t55: "cuda:0 f32[2, 32, 2048, 64]"
    # t57 = ltorch.mul(t55, t34)  # t57: "cuda:0 f32[2, 32, 2048, 64]"
      # t57 = prims.mul(t55, t34)  # t57: "cuda:0 f32[2, 32, 2048, 64]"
    # t58 = prims.convert_element_type(t57, dtypes.thunder.dtypes.bfloat16)  # t58: "cuda:0 bf16[2, 32, 2048, 64]"
    # t60 = prims.convert_element_type(t53, dtypes.thunder.dtypes.float32)  # t60: "cuda:0 f32[2, 32, 2048, 64]"
    # t62 = ltorch.mul(t60, t39)  # t62: "cuda:0 f32[2, 32, 2048, 64]"
      # t62 = prims.mul(t60, t39)  # t62: "cuda:0 f32[2, 32, 2048, 64]"
    # t63 = prims.convert_element_type(t62, dtypes.thunder.dtypes.bfloat16)  # t63: "cuda:0 bf16[2, 32, 2048, 64]"
    # t66 = ltorch.add(t57, t62, alpha=None)  # t66: "cuda:0 f32[2, 32, 2048, 64]"
      # t66 = prims.add(t57, t62)  # t66: "cuda:0 f32[2, 32, 2048, 64]"
    # t67 = prims.convert_element_type(t66, dtypes.thunder.dtypes.bfloat16)  # t67: "cuda:0 bf16[2, 32, 2048, 64]"
    # t68 = prims.slice_prim(t12, [0, 0, 0, 0], [2, 32, 2048, 0], [1, 1, 1, 1])  # t68: "cuda:0 bf16[2, 32, 2048, 0]"
    # t69 = prims.cat((t45, t68), -1)  # t69: "cuda:0 bf16[2, 32, 2048, 64]"
    # t70 = prims.slice_prim(t18, [0, 0, 0, 0], [2, 32, 2048, 0], [1, 1, 1, 1])  # t70: "cuda:0 bf16[2, 32, 2048, 0]"
    # t72 = prims.cat((t67, t70), -1)  # t72: "cuda:0 bf16[2, 32, 2048, 64]"
  return {'output': (t69, t72, t24), 'flat_args': [qkv, cos, sin], 'flat_output': (t69, t72, t24)}, ((t34, t39), (2,))
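
For reference, a trace like the one above can be printed with thunder.last_traces once the jitted callable has been run at least once (a minimal sketch, assuming th_fn from the sketch above):

```python
# th_fn must have been executed already so that traces exist.
fwd_traces = thunder.last_traces(th_fn)
print(fwd_traces[-1])  # the final, executed trace (like the one quoted above)
```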

Thunder's torch.compile executor launches 4 Triton kernels:

[screenshot: profiler output showing the 4 Triton kernels launched by Thunder's torch.compile executor]

Plain torch.compile launches 2 kernels:

[screenshot: profiler output showing the 2 Triton kernels launched by torch.compile]

The likely reason is that Thunder passes already-decomposed operators to torch.compile, so Inductor sees a different graph, fuses it differently, and ends up with different performance.
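
One way to double-check the kernel counts is to profile a single call of each compiled function and list the launched Triton kernels (a sketch assuming tc_fn, th_fn, and the inputs from the earlier sketch; Inductor-generated kernel names typically start with triton_):

```python
from torch.profiler import profile, ProfilerActivity

def triton_kernels(fn, *args):
    fn(*args)                      # warm up so compilation is not profiled
    torch.cuda.synchronize()
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        fn(*args)
        torch.cuda.synchronize()
    # Keep only Inductor/Triton kernels; other CUDA kernels are ignored.
    return [e.key for e in prof.key_averages() if e.key.startswith("triton_")]

print("torch.compile:", triton_kernels(tc_fn, qkv, cos, sin))          # 2 kernels observed above
print("thunder+torch.compile:", triton_kernels(th_fn, qkv, cos, sin))  # 4 kernels observed above
```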

cc @crcrpar @apaz-cli

lantiga commented 2 months ago

Yes definitely, we noticed this in the past and you probably hit the nail on the head about the root cause.

Thanks for opening the issue so we can track progress @kiya00

tfogal commented 2 months ago

triage review: