NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")
Other
249 stars 49 forks source link

performance issue on dynamic shaped tensor #2795

Open jjsjann123 opened 1 month ago

jjsjann123 commented 1 month ago

The re-written rope example has quite different indexing depending on whether the inputs q / cos / sin are defined with static or dynamic shapes.

I think this is coming from the inconsistent fusion definition, i.e. when we switch to inputs defined with dynamic shapes, the follow-up slice operations aren't using the symbolic slice extent, so we cannot collapse indexing after the slice.

q_rope = fd.ops.slice(q, start_indices=[0, 0, 0, 0], end_indices=[bsz, n_head, block_size, rope_n_elem], strides=[1, 1, 1, 1])

Currently the Python API only allows static numbers at this point, not Val* yet. Supporting Val* would also require the definition in lowering to be updated.

I'm opening this issue for myself. I think it's worth re-writing the script below to see if we can get full perf back with symbolic shape when the proper definition is produced.

import torch
from nvfuser import FusionDefinition, DataType

# Problem sizes for the repro. The concrete q input is created with shape
# (bsz, n_head, block_size, head_size); cos and sin are created with shape
# (block_size, rope_n_elem). The same Python ints are also used as the static
# slice bounds inside rope_fusion.
bsz, n_head = 2, 16
block_size, head_size = 1024, 32
rope_n_elem = 8

def rope_fusion(fd: FusionDefinition) -> None:
    """Record the re-written rope computation on ``fd``.

    Three bfloat16 inputs are defined with dynamic (``-1``) extents: a 4-D
    ``q`` and 2-D ``cos`` / ``sin``. The commented-out ``shape=`` lines are
    the static-shape variant of the same definition.

    Every slice end index, pad width and broadcast shape below is a plain
    Python int taken from the module-level constants (``bsz``, ``n_head``,
    ``block_size``, ``head_size``, ``rope_n_elem``), even though the input
    extents are symbolic.

    Args:
        fd: The fusion definition that receives the tensor definitions, the
            ops and the single output ``q_out``.

    NOTE(review): the statement order here is the order in which ops are
    recorded on ``fd``; keep it stable when comparing generated kernels.
    """
    # 4-D input; stride_order=[3, 2, 1, 0] together with full contiguity.
    q = fd.define_tensor(
        #shape=[bsz, n_head, block_size, head_size],
        shape=[-1, -1, -1, -1],
        contiguity=[True, True, True, True],
        dtype=DataType.BFloat16,
        is_cpu=False,
        stride_order=[3, 2, 1, 0],
    )
    # 2-D input, sliced below along its last dimension up to rope_n_elem.
    cos = fd.define_tensor(
        #shape=[block_size, rope_n_elem],
        shape=[-1, -1],
        contiguity=[True, True],
        dtype=DataType.BFloat16,
        is_cpu=False,
        stride_order=[1, 0],
    )
    # Same definition as cos.
    sin = fd.define_tensor(
        #shape=[block_size, rope_n_elem],
        shape=[-1, -1],
        contiguity=[True, True],
        dtype=DataType.BFloat16,
        is_cpu=False,
        stride_order=[1, 0],
    )

    # Midpoint of the rope_n_elem range; splits it into a "left" and a
    # "right" half.
    offset_0 = rope_n_elem // 2

    # Split q along the last dimension into [0, rope_n_elem) and
    # [rope_n_elem, head_size).
    #
    # Pad widths are written outermost-dimension-first and reversed before
    # the call.
    # NOTE(review): this presumably means fd.ops.pad takes widths
    # innermost-dimension-first as (left, right) pairs, like
    # torch.nn.functional.pad — confirm. Under that reading q_remainder gets
    # rope_n_elem zeros on the left of its last dimension, i.e. it is widened
    # back to head_size.
    q_rope = fd.ops.slice(q, start_indices=[0, 0, 0, 0], end_indices=[bsz, n_head, block_size, rope_n_elem], strides=[1, 1, 1, 1])
    q_remainder = fd.ops.slice(q, start_indices=[0, 0, 0, rope_n_elem], end_indices=[bsz, n_head, block_size, head_size], strides=[1, 1, 1, 1])
    q_remainder = fd.ops.pad(q_remainder, list(reversed([0, 0, 0, 0, 0, 0, 0, rope_n_elem])))

    # Halves of q_rope along the last dimension, each padded out by a total
    # of head_size - offset_0 on that dimension. Under the pad-ordering
    # assumption above, q_left lands in columns [offset_0, rope_n_elem) and
    # q_right in columns [0, offset_0), i.e. the two halves swap places.
    q_left = fd.ops.slice(q_rope, start_indices=[0, 0, 0, 0], end_indices=[bsz, n_head, block_size, offset_0], strides=[1, 1, 1, 1])
    q_left = fd.ops.pad(q_left, list(reversed([0, 0, 0, 0, 0, 0, head_size - rope_n_elem, rope_n_elem - offset_0])))
    q_right = fd.ops.slice(q_rope, start_indices=[0, 0, 0, offset_0], end_indices=[bsz, n_head, block_size, rope_n_elem], strides=[1, 1, 1, 1])
    q_right = fd.ops.pad(q_right, list(reversed([0, 0, 0, 0, 0, 0, head_size - rope_n_elem + offset_0, 0])))

    # The slices here are identical to the ones used for q_left and q_right;
    # only the pad widths differ (they are exchanged between the two halves).
    # We should be able to merge the slices back.
    q_left_cos = fd.ops.slice(q_rope, start_indices=[0, 0, 0, 0], end_indices=[bsz, n_head, block_size, offset_0], strides=[1, 1, 1, 1])
    q_left_cos = fd.ops.pad(q_left_cos, list(reversed([0, 0, 0, 0, 0, 0, head_size - rope_n_elem + offset_0, 0])))
    q_right_cos = fd.ops.slice(q_rope, start_indices=[0, 0, 0, offset_0], end_indices=[bsz, n_head, block_size, rope_n_elem], strides=[1, 1, 1, 1])
    q_right_cos = fd.ops.pad(q_right_cos, list(reversed([0, 0, 0, 0, 0, 0, head_size - rope_n_elem, rope_n_elem - offset_0])))

    # slice cos/sin
    # Each half is padded by a total of head_size - offset_0 on the last
    # dimension and then broadcast from 2-D to [1, 1, block_size, head_size],
    # with the two input dims mapped to output dims 2 and 3.
    cos_left = fd.ops.slice(cos, start_indices=[0, 0], end_indices=[block_size, offset_0], strides=[1, 1])
    cos_left = fd.ops.pad(cos_left, list(reversed([0, 0, head_size - offset_0, 0])))
    cos_left = fd.ops.broadcast_in_dim(cos_left, shape=[1, 1, block_size, head_size], broadcast_dims=[2, 3])
    cos_right = fd.ops.slice(cos, start_indices=[0, offset_0], end_indices=[block_size, rope_n_elem], strides=[1, 1])
    cos_right = fd.ops.pad(cos_right, list(reversed([0, 0, head_size - rope_n_elem, offset_0])))
    cos_right = fd.ops.broadcast_in_dim(cos_right, shape=[1, 1, block_size, head_size], broadcast_dims=[2, 3])

    # Same slice / pad / broadcast pattern as cos, applied to sin.
    sin_left = fd.ops.slice(sin, start_indices=[0, 0], end_indices=[block_size, offset_0], strides=[1, 1])
    sin_left = fd.ops.pad(sin_left, list(reversed([0, 0, head_size - offset_0, 0])))
    sin_left = fd.ops.broadcast_in_dim(sin_left, shape=[1, 1, block_size, head_size], broadcast_dims=[2, 3])
    sin_right = fd.ops.slice(sin, start_indices=[0, offset_0], end_indices=[block_size, rope_n_elem], strides=[1, 1])
    sin_right = fd.ops.pad(sin_right, list(reversed([0, 0, head_size - rope_n_elem, offset_0])))
    sin_right = fd.ops.broadcast_in_dim(sin_right, shape=[1, 1, block_size, head_size], broadcast_dims=[2, 3])

    # Combine the padded pieces. Under the pad-ordering assumption above, q0
    # is non-zero only in last-dim columns [0, offset_0), q1 only in
    # [offset_0, rope_n_elem), and q_remainder only in
    # [rope_n_elem, head_size), so the sum assembles the full-width result.
    q0 = (-q_right) * sin_left + cos_left * q_left_cos
    q1 = q_left * sin_right + cos_right * q_right_cos
    q_out = q0 + q1 + q_remainder
    q_out = fd.ops.cast(q_out, dtype=DataType.BFloat16)
    # NOTE(review): the result of this cast is never used afterwards (q0 is
    # not added as an output) — looks like a leftover; confirm before removing.
    q0 = fd.ops.cast(q0, dtype=DataType.BFloat16)

    # q_out is the only output of the fusion.
    fd.add_output(q_out)

# Build the fusion definition by running rope_fusion inside the
# FusionDefinition context.
with FusionDefinition() as fd:
    rope_fusion(fd)

# Concrete input shapes, in the order the tensors are defined in rope_fusion:
# q first, then cos, then sin.
input_shapes = (
    (bsz, n_head, block_size, head_size),
    (block_size, rope_n_elem),
    (block_size, rope_n_elem),
)
inputs = [
    torch.randn(shape, dtype=torch.bfloat16, device="cuda:0")
    for shape in input_shapes
]

# Execute the fusion and keep its first output.
o = fd.execute(inputs)[0]
jjsjann123 commented 1 month ago

cc'ing @zasdfgbnm , I don't think there's any actionable item needed on your side at this moment. I'll update this after I've checked the performance with the new definition.

jacobhinkle commented 3 weeks ago

Here is a diff of the generated pointwise kernels on my 3090Ti:

--- static.cu   2024-08-19 10:23:21.977784983 -0400
+++ dynamic.cu  2024-08-19 10:23:58.741144923 -0400
@@ -10697,68 +10697,96 @@
 }

 } // namespace fused_reduction
-__global__ void nvfuser_pointwise_f0_c1_r0_g6(Tensor<__bfloat, 4, 4> T10, Tensor<__bfloat, 4, 4> T8, Tensor<__bfloat, 2, 2> T22, Tensor<__bfloat, 2, 2> T14, Tensor<__bfloat, 2, 2> T18, Tensor<__bfloat, 4, 4> T12, Tensor<__bfloat, 2, 2> T26, Tensor<__bfloat, 4, 4> T4, Tensor<__bfloat, 4, 4> T6, Tensor<__bfloat, 4, 4> T48) {
+__global__ void nvfuser_pointwise_f0_c1_r0_g6(Tensor<__bfloat, 4, 4> T10, Tensor<__bfloat, 4, 4> T8, Tensor<__bfloat, 2, 2> T22, Tensor<__bfloat, 2, 2> T14, Tensor<__bfloat, 2, 2> T18, Tensor<__bfloat, 4, 4> T12, Tensor<__bfloat, 2, 2> T26, Tensor<__bfloat, 4, 4> T4, Tensor<__bfloat, 4, 4> T6, nvfuser_index_t i0, nvfuser_index_t i1, nvfuser_index_t i2, Tensor<__bfloat, 4, 4> T48) {
   NVFUSER_DEFINE_MAGIC_ZERO;
-  nvfuser_index_t i0;
-  i0 = ((nvfuser_index_t)threadIdx.x) + (((nvfuser_index_t)blockDim.x) * ((nvfuser_index_t)blockIdx.y));
-  nvfuser_index_t i1;
-  i1 = 8 * (i0 % 4);
-  nvfuser_index_t i2;
-  i2 = i0 / 4;
   nvfuser_index_t i3;
-  i3 = -4 + i1;
+  i3 = 8 * ((nvfuser_index_t)threadIdx.x);
   nvfuser_index_t i4;
-  i4 = i3 + (T26.alloc_stride[0LL] * i2);
+  i4 = (8 * ((nvfuser_index_t)blockDim.x)) * ((nvfuser_index_t)blockIdx.y);
   nvfuser_index_t i5;
-  i5 = ((-4 + ((1024 * T6.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x))) + i1) + (T6.alloc_stride[2LL] * i2);
+  i5 = i3 + i4;
   nvfuser_index_t i6;
-  i6 = i1 + (T14.alloc_stride[0LL] * i2);
+  i6 = 28 + T26.logical_size[1LL];
   nvfuser_index_t i7;
-  i7 = (((1024 * T10.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x)) + i1) + (T10.alloc_stride[2LL] * i2);
+  i7 = ((nvfuser_index_t)blockIdx.x) / T10.logical_size[1LL];
   nvfuser_index_t i8;
-  i8 = (((1024 * T8.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x)) + i1) + (T8.alloc_stride[2LL] * i2);
+  i8 = ((nvfuser_index_t)blockIdx.x) % T10.logical_size[1LL];
   nvfuser_index_t i9;
-  i9 = i1 + (T22.alloc_stride[0LL] * i2);
+  i9 = (-4 + (T6.alloc_stride[0LL] * i7)) + (T6.alloc_stride[1LL] * i8);
   nvfuser_index_t i10;
-  i10 = ((-8 + ((1024 * T4.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x))) + i1) + (T4.alloc_stride[2LL] * i2);
+  i10 = 28 + T6.logical_size[3LL];
   nvfuser_index_t i11;
-  i11 = i3 + (T18.alloc_stride[0LL] * i2);
+  i11 = 28 + T14.logical_size[1LL];
   nvfuser_index_t i12;
-  i12 = ((-4 + ((1024 * T12.alloc_stride[2LL]) * ((nvfuser_index_t)blockIdx.x))) + i1) + (T12.alloc_stride[2LL] * i2);
+  i12 = (T10.alloc_stride[0LL] * i7) + (T10.alloc_stride[1LL] * i8);
   nvfuser_index_t i13;
-  i13 = 8 * ((nvfuser_index_t)threadIdx.x);
+  i13 = 28 + T10.logical_size[3LL];
   nvfuser_index_t i14;
-  i14 = (8 * ((nvfuser_index_t)blockDim.x)) * ((nvfuser_index_t)blockIdx.y);
+  i14 = (T8.alloc_stride[0LL] * i7) + (T8.alloc_stride[1LL] * i8);
   nvfuser_index_t i15;
-  i15 = (i13 + (32768 * ((nvfuser_index_t)blockIdx.x))) + i14;
-  bool b16;
-  b16 = (i13 + i14) < 32768;
-  if ((((i13 + 7) + i14) < 32768)) {
-    Array<__bfloat, 8, 8> T50;
+  i15 = 28 + T8.logical_size[3LL];
+  nvfuser_index_t i16;
+  i16 = 28 + T22.logical_size[1LL];
+  nvfuser_index_t i17;
+  i17 = (-8 + (T4.alloc_stride[0LL] * i7)) + (T4.alloc_stride[1LL] * i8);
+  nvfuser_index_t i18;
+  i18 = 8 + T4.logical_size[3LL];
+  nvfuser_index_t i19;
+  i19 = 28 + T18.logical_size[1LL];
+  nvfuser_index_t i20;
+  i20 = (-4 + (T12.alloc_stride[0LL] * i7)) + (T12.alloc_stride[1LL] * i8);
+  nvfuser_index_t i21;
+  i21 = 28 + T12.logical_size[3LL];
+  nvfuser_index_t i22;
+  i22 = (i3 + (32768 * ((nvfuser_index_t)blockIdx.x))) + i4;
+  nvfuser_index_t i23;
+  i23 = 24 * T10.logical_size[2LL];
+  nvfuser_index_t i24;
+  i24 = ((max(4, (min(i0, 8)))) * T10.logical_size[2LL]) + i23;
+  bool b25;
+  b25 = i5 < i24;
+  nvfuser_index_t i26;
+  i26 = 28 * T10.logical_size[2LL];
+  bool b27;
+  b27 = i5 < (i26 + (T10.logical_size[2LL] * T6.logical_size[3LL]));
+  bool b28;
+  b28 = i5 < (i26 + (T10.logical_size[2LL] * T14.logical_size[1LL]));
+  bool b29;
+  b29 = i5 < ((T10.logical_size[2LL] * T10.logical_size[3LL]) + i26);
+  bool b30;
+  b30 = i5 < (((max(4, (min((max(0LL, (min(i1, 8)))), 8)))) * T10.logical_size[2LL]) + i23);
+  bool b31;
+  b31 = i5 < (i26 + (T10.logical_size[2LL] * T22.logical_size[1LL]));
+  bool b32;
+  b32 = i5 < ((max(8, (min(i1, 32)))) * T10.logical_size[2LL]);
+  bool b33;
+  b33 = i5 < (((max(4, (min(i2, 8)))) * T10.logical_size[2LL]) + i23);
+  if ((((i3 + 7) + i4) < i24)) {
+    Array<__bfloat, 8, 8> T54;
     #pragma unroll
-    for(nvfuser_index_t i17 = 0; i17 < 8; ++i17) {
-      nvfuser_index_t i18;
-      i18 = i17 + nvfuser_zero;
+    for(nvfuser_index_t i34 = 0; i34 < 8; ++i34) {
+      nvfuser_index_t i35;
+      i35 = i5 + (i34 + nvfuser_zero);
       __bfloat T27[1];
       T27[0] = 0;
       T27[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T26[(i4 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T26.logical_size[1LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T26.logical_size[1LL] + 4) + 24)) - 4) < T26.logical_size[1LL])) ? T26[((-4 + (T26.alloc_stride[0LL] * (i35 / i6))) + (i35 % i6))] : 0.0000e+00;
       __bfloat T28[1];
       T28[0]
          = T27[0];
-      __bfloat T29[1];
-      T29[0]
+      __bfloat T52[1];
+      T52[0]
          = T28[0];
       __bfloat T7[1];
       T7[0] = 0;
       T7[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T6[(i5 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T6.logical_size[3LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T6.logical_size[3LL] + 4) + 24)) - 4) < T6.logical_size[3LL])) ? T6[((i9 + (T6.alloc_stride[2LL] * (i35 / i10))) + (i35 % i10))] : 0.0000e+00;
       float T38[1];
       T38[0]
          = __bfloat2float(T7[0]);
       float T39[1];
       T39[0]
-         = __bfloat2float(T29[0]);
+         = __bfloat2float(T52[0]);
       float T40[1];
       T40[0]
         = T38[0]
@@ -10766,23 +10794,23 @@
       __bfloat T15[1];
       T15[0] = 0;
       T15[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T14[(i6 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T14.logical_size[1LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T14.logical_size[1LL] + 28)) - 0) < T14.logical_size[1LL])) ? T14[((T14.alloc_stride[0LL] * (i35 / i11)) + (i35 % i11))] : 0.0000e+00;
       __bfloat T16[1];
       T16[0]
          = T15[0];
-      __bfloat T17[1];
-      T17[0]
+      __bfloat T51[1];
+      T51[0]
          = T16[0];
       __bfloat T11[1];
       T11[0] = 0;
       T11[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T10[(i7 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T10.logical_size[3LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T10.logical_size[3LL] + 28)) - 0) < T10.logical_size[3LL])) ? T10[((i12 + (T10.alloc_stride[2LL] * (i35 / i13))) + (i35 % i13))] : 0.0000e+00;
       float T35[1];
       T35[0]
          = __bfloat2float(T11[0]);
       float T34[1];
       T34[0]
-         = __bfloat2float(T17[0]);
+         = __bfloat2float(T51[0]);
       float T36[1];
       T36[0]
         = T34[0]
@@ -10790,7 +10818,7 @@
       __bfloat T9[1];
       T9[0] = 0;
       T9[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T8[(i8 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T8.logical_size[3LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T8.logical_size[3LL] + 28)) - 0) < T8.logical_size[3LL])) ? T8[((i14 + (T8.alloc_stride[2LL] * (i35 / i15))) + (i35 % i15))] : 0.0000e+00;
       float T30[1];
       T30[0]
          = __bfloat2float(T9[0]);
@@ -10800,16 +10828,16 @@
       __bfloat T23[1];
       T23[0] = 0;
       T23[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T22[(i9 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T22.logical_size[1LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T22.logical_size[1LL] + 28)) - 0) < T22.logical_size[1LL])) ? T22[((T22.alloc_stride[0LL] * (i35 / i16)) + (i35 % i16))] : 0.0000e+00;
       __bfloat T24[1];
       T24[0]
          = T23[0];
-      __bfloat T25[1];
-      T25[0]
+      __bfloat T50[1];
+      T50[0]
          = T24[0];
       float T32[1];
       T32[0]
-         = __bfloat2float(T25[0]);
+         = __bfloat2float(T50[0]);
       float T33[1];
       T33[0]
         = T31[0]
@@ -10817,30 +10845,30 @@
       __bfloat T5[1];
       T5[0] = 0;
       T5[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 8) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 8) < 24)) ? T4[(i10 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T4.logical_size[3LL] + 8)) - 8) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T4.logical_size[3LL] + 8)) - 8) < T4.logical_size[3LL])) ? T4[((i17 + (T4.alloc_stride[2LL] * (i35 / i18))) + (i35 % i18))] : 0.0000e+00;
       float T46[1];
       T46[0]
          = __bfloat2float(T5[0]);
       __bfloat T19[1];
       T19[0] = 0;
       T19[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T18[(i11 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T18.logical_size[1LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T18.logical_size[1LL] + 4) + 24)) - 4) < T18.logical_size[1LL])) ? T18[((-4 + (T18.alloc_stride[0LL] * (i35 / i19))) + (i35 % i19))] : 0.0000e+00;
       __bfloat T20[1];
       T20[0]
          = T19[0];
-      __bfloat T21[1];
-      T21[0]
+      __bfloat T53[1];
+      T53[0]
          = T20[0];
       __bfloat T13[1];
       T13[0] = 0;
       T13[0]
-         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T12[(i12 + i18)] : 0.0000e+00;
+         = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T12.logical_size[3LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T12.logical_size[3LL] + 4) + 24)) - 4) < T12.logical_size[3LL])) ? T12[((i20 + (T12.alloc_stride[2LL] * (i35 / i21))) + (i35 % i21))] : 0.0000e+00;
       float T42[1];
       T42[0]
          = __bfloat2float(T13[0]);
       float T41[1];
       T41[0]
-         = __bfloat2float(T21[0]);
+         = __bfloat2float(T53[0]);
       float T43[1];
       T43[0]
         = T41[0]
@@ -10861,78 +10889,78 @@
       T47[0]
         = T45[0]
         + T46[0];
-      T50[i17]
+      T54[i34]
          = __float2bfloat(T47[0]);
     }
     NVFUSER_UPDATE_MAGIC_ZERO;
-    loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i15], &T50[0]);
+    loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i22], &T54[0]);
   } else {
-    Array<__bfloat, 8, 8> T50;
+    Array<__bfloat, 8, 8> T54;
     #pragma unroll
-    for(nvfuser_index_t i17 = 0; i17 < 8; ++i17) {
-      nvfuser_index_t i19;
-      i19 = i17 + nvfuser_zero;
+    for(nvfuser_index_t i34 = 0; i34 < 8; ++i34) {
+      nvfuser_index_t i36;
+      i36 = i5 + (i34 + nvfuser_zero);
       __bfloat T27[1];
       T27[0] = 0;
-      if (b16) {
+      if (b25) {
         T27[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T26[(i4 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T26.logical_size[1LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T26.logical_size[1LL] + 4) + 24)) - 4) < T26.logical_size[1LL])) ? T26[((-4 + (T26.alloc_stride[0LL] * (i36 / i6))) + (i36 % i6))] : 0.0000e+00;
       }
       __bfloat T28[1];
       T28[0]
          = T27[0];
-      __bfloat T29[1];
-      T29[0]
+      __bfloat T52[1];
+      T52[0]
          = T28[0];
       __bfloat T7[1];
       T7[0] = 0;
-      if (b16) {
+      if (b27) {
         T7[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T6[(i5 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T6.logical_size[3LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T6.logical_size[3LL] + 4) + 24)) - 4) < T6.logical_size[3LL])) ? T6[((i9 + (T6.alloc_stride[2LL] * (i36 / i10))) + (i36 % i10))] : 0.0000e+00;
       }
       float T38[1];
       T38[0]
          = __bfloat2float(T7[0]);
       float T39[1];
       T39[0]
-         = __bfloat2float(T29[0]);
+         = __bfloat2float(T52[0]);
       float T40[1];
       T40[0]
         = T38[0]
         * T39[0];
       __bfloat T15[1];
       T15[0] = 0;
-      if (b16) {
+      if (b28) {
         T15[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T14[(i6 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T14.logical_size[1LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T14.logical_size[1LL] + 28)) - 0) < T14.logical_size[1LL])) ? T14[((T14.alloc_stride[0LL] * (i36 / i11)) + (i36 % i11))] : 0.0000e+00;
       }
       __bfloat T16[1];
       T16[0]
          = T15[0];
-      __bfloat T17[1];
-      T17[0]
+      __bfloat T51[1];
+      T51[0]
          = T16[0];
       __bfloat T11[1];
       T11[0] = 0;
-      if (b16) {
+      if (b29) {
         T11[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T10[(i7 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T10.logical_size[3LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T10.logical_size[3LL] + 28)) - 0) < T10.logical_size[3LL])) ? T10[((i12 + (T10.alloc_stride[2LL] * (i36 / i13))) + (i36 % i13))] : 0.0000e+00;
       }
       float T35[1];
       T35[0]
          = __bfloat2float(T11[0]);
       float T34[1];
       T34[0]
-         = __bfloat2float(T17[0]);
+         = __bfloat2float(T51[0]);
       float T36[1];
       T36[0]
         = T34[0]
         * T35[0];
       __bfloat T9[1];
       T9[0] = 0;
-      if (b16) {
+      if (b30) {
         T9[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T8[(i8 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T8.logical_size[3LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T8.logical_size[3LL] + 28)) - 0) < T8.logical_size[3LL])) ? T8[((i14 + (T8.alloc_stride[2LL] * (i36 / i15))) + (i36 % i15))] : 0.0000e+00;
       }
       float T30[1];
       T30[0]
@@ -10942,56 +10970,56 @@
          = -T30[0];
       __bfloat T23[1];
       T23[0] = 0;
-      if (b16) {
+      if (b31) {
         T23[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 0) < 4)) ? T22[(i9 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T22.logical_size[1LL] + 28)) - 0) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T22.logical_size[1LL] + 28)) - 0) < T22.logical_size[1LL])) ? T22[((T22.alloc_stride[0LL] * (i36 / i16)) + (i36 % i16))] : 0.0000e+00;
       }
       __bfloat T24[1];
       T24[0]
          = T23[0];
-      __bfloat T25[1];
-      T25[0]
+      __bfloat T50[1];
+      T50[0]
          = T24[0];
       float T32[1];
       T32[0]
-         = __bfloat2float(T25[0]);
+         = __bfloat2float(T50[0]);
       float T33[1];
       T33[0]
         = T31[0]
         * T32[0];
       __bfloat T5[1];
       T5[0] = 0;
-      if (b16) {
+      if (b32) {
         T5[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 8) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 8) < 24)) ? T4[(i10 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T4.logical_size[3LL] + 8)) - 8) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % (T4.logical_size[3LL] + 8)) - 8) < T4.logical_size[3LL])) ? T4[((i17 + (T4.alloc_stride[2LL] * (i36 / i18))) + (i36 % i18))] : 0.0000e+00;
       }
       float T46[1];
       T46[0]
          = __bfloat2float(T5[0]);
       __bfloat T19[1];
       T19[0] = 0;
-      if (b16) {
+      if (b33) {
         T19[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T18[(i11 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T18.logical_size[1LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T18.logical_size[1LL] + 4) + 24)) - 4) < T18.logical_size[1LL])) ? T18[((-4 + (T18.alloc_stride[0LL] * (i36 / i19))) + (i36 % i19))] : 0.0000e+00;
       }
       __bfloat T20[1];
       T20[0]
          = T19[0];
-      __bfloat T21[1];
-      T21[0]
+      __bfloat T53[1];
+      T53[0]
          = T20[0];
       __bfloat T13[1];
       T13[0] = 0;
-      if (b16) {
+      if (b30) {
         T13[0]
-           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i17 + nvfuser_zero)) % 32) - 4) < 4)) ? T12[(i12 + i19)] : 0.0000e+00;
+           = ((((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T12.logical_size[3LL] + 4) + 24)) - 4) >= 0) && (((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 8) + (i34 + nvfuser_zero)) % ((T12.logical_size[3LL] + 4) + 24)) - 4) < T12.logical_size[3LL])) ? T12[((i20 + (T12.alloc_stride[2LL] * (i36 / i21))) + (i36 % i21))] : 0.0000e+00;
       }
       float T42[1];
       T42[0]
          = __bfloat2float(T13[0]);
       float T41[1];
       T41[0]
-         = __bfloat2float(T21[0]);
+         = __bfloat2float(T53[0]);
       float T43[1];
       T43[0]
         = T41[0]
@@ -11012,12 +11040,12 @@
       T47[0]
         = T45[0]
         + T46[0];
-      T50[i17]
+      T54[i34]
          = __float2bfloat(T47[0]);
     }
     NVFUSER_UPDATE_MAGIC_ZERO;
-    if (b16) {
-      loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i15], &T50[0]);
+    if ((i5 < 32768)) {
+      loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i22], &T54[0]);
     }
   }
 }

The static kernel is just using the commented lines in the repro posted above. It achieves about 5x higher BW compared to dynamic (runs in 8 us vs 38).

There is more in the preamble for dynamic shapes, but inside the loops the expressions also have slightly more going on. For example (zoomed in and inserted line breaks):

--- static.cu   2024-08-19 10:23:21.977784983 -0400
+++ dynamic.cu  2024-08-19 10:23:58.741144923 -0400
@@ -10697,68 +10697,96 @@
       __bfloat T27[1];
       T27[0] = 0;
       T27[0] = ((((((((((nvfuser_index_t)blockIdx.y) *
                        ((nvfuser_index_t)blockDim.x)) +
                       ((nvfuser_index_t)threadIdx.x)) *
                      8) +
-                    (i17 + nvfuser_zero)) %
-                   32) -
+                    (i34 + nvfuser_zero)) %
+                   ((T26.logical_size[1LL] + 4) + 24)) -
                   4) >= 0) &&
                 (((((((((nvfuser_index_t)blockIdx.y) *
                        ((nvfuser_index_t)blockDim.x)) +
                       ((nvfuser_index_t)threadIdx.x)) *
                      8) +
-                    (i17 + nvfuser_zero)) %
-                   32) -
-                  4) < 4))
-          ? T26[(i4 + i18)]
+                    (i34 + nvfuser_zero)) %
+                   ((T26.logical_size[1LL] + 4) + 24)) -
+                  4) < T26.logical_size[1LL]))
+          ? T26[((-4 + (T26.alloc_stride[0LL] * (i35 / i6))) + (i35 % i6))]
           : 0.0000e+00;
       __bfloat T28[1];
       T28[0] = T27[0];

In this context i35 is a loop index, so we might not be able to simplify the last diff line much, but we are also not hoisting ((T26.logical_size[1LL] + 4) + 24) for some reason...

jacobhinkle commented 3 weeks ago

As for the preamble, there are lots of max and mins in the dynamic kernel, which could be avoided using #511 (I'm looking at updating this). As discussed last week, we could temporarily make all sliced input extents and all slice ranges constant at concretization, which I think would give us a kernel similar to static.cu above.

jjsjann123 commented 3 weeks ago

For our own sanity, here's a simplified cpp test. Indexing isn't being simplified even when the slice is passing in the correct extent val.

As for @jacobhinkle's WAR in #511: as I understand it, with that workaround we wouldn't need this form of definition, and the performance of the original python repro shouldn't regress with dynamic shapes.

I'm creating this repro for @zasdfgbnm. I'm assuming the definition here should be enough to tell us that we are doing two non-overlapping slices, so the indexing could perhaps be simplified even without concretization...

// Simplified repro for the dynamic-shape slice + pad indexing issue.
//
// Builds a fusion with a single 3-D input, takes two slices of it along the
// last dimension that split it at `val_slice` (8), pads each slice with the
// extent of the other one, negates one of them and adds the two results. The
// fusion is then run through FusionExecutorCache on a float CUDA tensor of
// shape {32, 1024, 16} and checked with testValidate.
//
// Flip the `#if 0` below to `#if 1` to build the same fusion with constant
// extents instead of the input's own extents, for comparing the generated
// kernels.
TEST_F(NVFuserTest, DynamicShapedPad) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  auto zero = fusion->zeroVal();
  auto one = fusion->oneVal();

  // Shape of the runtime input tensor (and of the concrete tensor when the
  // static branch below is enabled).
  std::vector<int64_t> shape{32, 1024, 16};

#if 0
  // Static variant: every extent and the slice split point are constants
  // matching `shape` (32, 1024, 16), with the last dimension split 8 / 8.
  auto tv0 = makeContigConcreteTensor(shape);
  auto dim0 = IrBuilder::create<Val>(32, DataType::Index);
  auto dim1 = IrBuilder::create<Val>(1024, DataType::Index);
  auto dim2 = IrBuilder::create<Val>(16, DataType::Index);
  auto val_slice = IrBuilder::create<Val>(8, DataType::Index);
  auto val_remain = IrBuilder::create<Val>(8, DataType::Index);
#else
  // Dynamic variant: the slice bounds reuse the extents of tv0's own axes, so
  // the remainder of the last dimension is the expression `dim2 - 8`.
  // NOTE(review): makeContigTensor(3) presumably creates a contiguous 3-D
  // tensor with symbolic extents -- confirm against the test utilities.
  auto tv0 = makeContigTensor(3);
  auto dim0 = tv0->axis(0)->extent();
  auto dim1 = tv0->axis(1)->extent();
  auto dim2 = tv0->axis(2)->extent();
  auto val_slice = IrBuilder::create<Val>(8, DataType::Index);
  auto val_remain = sub(dim2, val_slice);
#endif
  // NOTE(review): the Slice initializers are presumably {start, stop, step}
  // -- confirm against the Slice struct definition.
  // Dimensions 0 and 1 are taken in full; dimension 2 is split into a left
  // part bounded by (zero, val_slice) and a right part bounded by
  // (val_slice, dim2).
  Slice slice_dim_0{zero, dim0, one};
  Slice slice_dim_1{zero, dim1, one};
  Slice slice_dim_2_l{zero, val_slice, one};
  Slice slice_dim_2_r{val_slice, dim2, one};

  std::vector<Slice> slice_l_ind = {slice_dim_0, slice_dim_1, slice_dim_2_l};
  std::vector<Slice> slice_r_ind = {slice_dim_0, slice_dim_1, slice_dim_2_r};

  auto slice_l = slice(tv0, slice_l_ind);
  auto slice_r = slice(tv0, slice_r_ind);

  fusion->addInput(tv0);

  // NOTE(review): the pad widths look like (left, right) pairs listed from
  // the innermost dimension outwards -- confirm against pad(). Under that
  // reading slice_r is padded on the right by val_slice and slice_l on the
  // left by val_remain.
  auto rope_l = pad(slice_r, {zero, val_slice, zero, zero, zero, zero});
  // avoid segmentation.
  // The negation is applied after the pad; the alternative that negates
  // before padding is kept below for reference.
  // auto rope_r = pad(neg(slice_l), {val_remain, zero, zero, zero, zero, zero});
  auto rope_r = neg(pad(slice_l, {val_remain, zero, zero, zero, zero, zero}));

  auto o = add(rope_l, rope_r);

  fusion->addOutput(o);

  // Random float input on CUDA device 0 with the concrete `shape`.
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn(shape, options);
  std::vector<c10::IValue> aten_inputs({t0});

  // Run the fusion on the input through the executor cache.
  FusionExecutorCache executor_cache(std::move(fusion));
  auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs);

  // Check the generated outputs for the given inputs.
  testValidate(
      executor_cache.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__);
}
zasdfgbnm commented 3 weeks ago

Thanks @jjsjann123 for providing this repro. I do believe there are some opportunities to simplify symbolically here:

For example, the predicate of T3 = pad(T2) looks like:

Static shape:

-8 + (4 * (threadIdx.x % 4)) < -i0

Dynamic shape:

  ((((4 * threadIdx.x) + (512 * blockIdx.x)) + i0) %
   ((8 * T1.logical_size[1LL]) +
    (T1.logical_size[1LL] * T2.logical_size[2LL]))) %
          (8 + T2.logical_size[2LL]) <
      T2.logical_size[2LL]

For the dynamic shape case, note that let:

a = ((4 * threadIdx.x) + (512 * blockIdx.x)) + i0;
b = T1.logical_size[1LL];
c = 8 + T2.logical_size[2LL];

then the predicate is:

a % (b * c) % c < T2.logical_size[2LL]

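which clearly can be simplified (b * c is a multiple of c, so a % (b * c) % c == a % c for non-negative a) to:

a % c < T2.logical_size[2LL], i.e. (((4 * threadIdx.x) + (512 * blockIdx.x)) + i0) % (8 + T2.logical_size[2LL]) < T2.logical_size[2LL]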

Clearly not as good as the static shape case, but still an improvement.

Kernel diff: https://www.diffchecker.com/vK2pS9ak/

jjsjann123 commented 3 weeks ago

Yeah, if there's no low-hanging fruit, I don't think it matters at this point, since we are going down the path of @jacobhinkle's plan to use static shapes during concretization.

We can revisit this if we decide to push it further afterwards.

jacobhinkle commented 3 weeks ago

static shapes during concretization.

BTW, while implementing this I just noticed that a lot of the resizes are dynamic but, for the provided inputs, are actually trivial:

    ?S9{( fmax(0, ( fmin(i0, 2) )) )}rf (index=0) is a resize of input extent 2 with left_pad=0 and right_pad=0
    ?S11{( fmax(0, ( fmin(i1, 16) )) )}rf (index=1) is a resize of input extent 16 with left_pad=0 and right_pad=0
    ?S13{( fmax(0, ( fmin(i2, 1024) )) )}rf (index=2) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S15{( fmax(0, ( fmin(i3, 8) )) )}rf (index=3) is a resize of input extent 32 with left_pad=0 and right_pad=-24
    ?S43{( fmax(0, ( fmin(( fmax(0, ( fmin(i0, 2) )) ), 2) )) )}rf (index=4) is a resize of input extent 2 with left_pad=0 and right_pad=0
    ?S45{( fmax(0, ( fmin(( fmax(0, ( fmin(i1, 16) )) ), 16) )) )}rf (index=5) is a resize of input extent 16 with left_pad=0 and right_pad=0
    ?S47{( fmax(0, ( fmin(( fmax(0, ( fmin(i2, 1024) )) ), 1024) )) )}rf (index=6) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S49{( ( fmax(4, ( fmin(( fmax(0, ( fmin(i3, 8) )) ), 8) )) ) - 4 )}rf (index=7) is a resize of input extent 8 with left_pad=-4 and right_pad=0
    ?S112{( fmax(0, ( fmin(i7, 1024) )) )}rf (index=8) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S114{( fmax(0, ( fmin(i8, 4) )) )}rf (index=9) is a resize of input extent 8 with left_pad=0 and right_pad=-4
    ?S82{( fmax(0, ( fmin(i5, 1024) )) )}rf (index=10) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S84{( fmax(0, ( fmin(i6, 4) )) )}rf (index=11) is a resize of input extent 8 with left_pad=0 and right_pad=-4
    ?S56{( fmax(0, ( fmin(( fmax(0, ( fmin(i0, 2) )) ), 2) )) )}rf (index=12) is a resize of input extent 2 with left_pad=0 and right_pad=0
    ?S58{( fmax(0, ( fmin(( fmax(0, ( fmin(i1, 16) )) ), 16) )) )}rf (index=13) is a resize of input extent 16 with left_pad=0 and right_pad=0
    ?S60{( fmax(0, ( fmin(( fmax(0, ( fmin(i2, 1024) )) ), 1024) )) )}rf (index=14) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S62{( fmax(0, ( fmin(( fmax(0, ( fmin(i3, 8) )) ), 4) )) )}rf (index=15) is a resize of input extent 8 with left_pad=0 and right_pad=-4
    ?S30{( fmax(0, ( fmin(( fmax(0, ( fmin(i0, 2) )) ), 2) )) )}rf (index=16) is a resize of input extent 2 with left_pad=0 and right_pad=0
    ?S32{( fmax(0, ( fmin(( fmax(0, ( fmin(i1, 16) )) ), 16) )) )}rf (index=17) is a resize of input extent 16 with left_pad=0 and right_pad=0
    ?S34{( fmax(0, ( fmin(( fmax(0, ( fmin(i2, 1024) )) ), 1024) )) )}rf (index=18) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S36{( fmax(0, ( fmin(( fmax(0, ( fmin(i3, 8) )) ), 4) )) )}rf (index=19) is a resize of input extent 8 with left_pad=0 and right_pad=-4
    ?S127{( fmax(0, ( fmin(i7, 1024) )) )}rf (index=20) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S129{( ( fmax(4, ( fmin(i8, 8) )) ) - 4 )}rf (index=21) is a resize of input extent 8 with left_pad=-4 and right_pad=0
    ?S97{( fmax(0, ( fmin(i5, 1024) )) )}rf (index=22) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S99{( ( fmax(4, ( fmin(i6, 8) )) ) - 4 )}rf (index=23) is a resize of input extent 8 with left_pad=-4 and right_pad=0
    ?S69{( fmax(0, ( fmin(( fmax(0, ( fmin(i0, 2) )) ), 2) )) )}rf (index=24) is a resize of input extent 2 with left_pad=0 and right_pad=0
    ?S71{( fmax(0, ( fmin(( fmax(0, ( fmin(i1, 16) )) ), 16) )) )}rf (index=25) is a resize of input extent 16 with left_pad=0 and right_pad=0
    ?S73{( fmax(0, ( fmin(( fmax(0, ( fmin(i2, 1024) )) ), 1024) )) )}rf (index=26) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S75{( ( fmax(4, ( fmin(( fmax(0, ( fmin(i3, 8) )) ), 8) )) ) - 4 )}rf (index=27) is a resize of input extent 8 with left_pad=-4 and right_pad=0
    ?S17{( fmax(0, ( fmin(i0, 2) )) )}rf (index=28) is a resize of input extent 2 with left_pad=0 and right_pad=0
    ?S19{( fmax(0, ( fmin(i1, 16) )) )}rf (index=29) is a resize of input extent 16 with left_pad=0 and right_pad=0
    ?S21{( fmax(0, ( fmin(i2, 1024) )) )}rf (index=30) is a resize of input extent 1024 with left_pad=0 and right_pad=0
    ?S23{( ( fmax(8, ( fmin(i3, 32) )) ) - 8 )}rf (index=31) is a resize of input extent 32 with left_pad=-8 and right_pad=0

By my count 21 out of these 32 resized axes are not actually resized at all. Using static shapes, not only will the expressions be simpler, but we will catch every one of these trivial resizes and we will not predicate that access. I'll have that as part of concretization in a PR soon.

jjsjann123 commented 3 weeks ago

By my count 21 out of these 32 resized axes are not actually resized at all. Using static shapes, not only will the expressions be simpler, but we will catch every one of these trivial resizes and we will not predicate that access. I'll have that as part of concretization in a PR soon.

Thanks for pointing that out. Yes, that's expected; this is one of the mismatches in thunder's static program. Slices with [..., ] are baked in as constants as well... We need to rewrite that logic later.

jacobhinkle commented 2 weeks ago

As of #2714, with the repro in the description of this issue, we went from 35 us on main to 13 us. Here is the diff of the generated kernel:

 __global__ void nvfuser_pointwise_f0_c1_r0_g6(Tensor<__bfloat, 4, 4> T10, Tensor<__bfloat, 4, 4> T8, Tensor<__bfloat, 2, 2> T22, Tensor<__bfloat, 2, 2> T14, Tensor<__bfloat, 2, 2> T18, Tensor<__bfloat, 4, 4> T12, Tensor<__bfloat, 2, 2> T26, Tensor<__bfloat, 4, 4> T4, Tensor<__bfloat, 4, 4> T6, nvfuser_index_t i0, nvfuser_index_t i1, nvfuser_index_t i2, Tensor<__bfloat, 4, 4> T48) {
   NVFUSER_DEFINE_MAGIC_ZERO;
   nvfuser_index_t i3;
   i3 = 8 * ((nvfuser_index_t)threadIdx.x);
   nvfuser_index_t i4;
   i4 = (8 * ((nvfuser_index_t)blockDim.x)) * ((nvfuser_index_t)blockIdx.y);
   nvfuser_index_t i5;
   i5 = i3 + i4;
   nvfuser_index_t i6;
-  i6 = 28 + T26.logical_size[1LL];
+  i6 = 28 + T18.logical_size[1LL];
   nvfuser_index_t i7;
   i7 = ((nvfuser_index_t)blockIdx.x) / T10.logical_size[1LL];
   nvfuser_index_t i8;
   i8 = ((nvfuser_index_t)blockIdx.x) % T10.logical_size[1LL];
   nvfuser_index_t i9;
   i9 = (-4 + (T6.alloc_stride[0LL] * i7)) + (T6.alloc_stride[1LL] * i8);
   nvfuser_index_t i10;
-  i10 = 28 + T6.logical_size[3LL];
+  i10 = 28 + T10.logical_size[3LL];
   nvfuser_index_t i11;
-  i11 = 28 + T14.logical_size[1LL];
+  i11 = (T10.alloc_stride[0LL] * i7) + (T10.alloc_stride[1LL] * i8);
   nvfuser_index_t i12;
-  i12 = (T10.alloc_stride[0LL] * i7) + (T10.alloc_stride[1LL] * i8);
+  i12 = (T8.alloc_stride[0LL] * i7) + (T8.alloc_stride[1LL] * i8);
   nvfuser_index_t i13;
-  i13 = 28 + T10.logical_size[3LL];
+  i13 = (-8 + (T4.alloc_stride[0LL] * i7)) + (T4.alloc_stride[1LL] * i8);
   nvfuser_index_t i14;
-  i14 = (T8.alloc_stride[0LL] * i7) + (T8.alloc_stride[1LL] * i8);
+  i14 = 8 + T4.logical_size[3LL];
   nvfuser_index_t i15;
-  i15 = 28 + T8.logical_size[3LL];
+  i15 = (-4 + (T12.alloc_stride[0LL] * i7)) + (T12.alloc_stride[1LL] * i8);
   nvfuser_index_t i16;
-  i16 = 28 + T22.logical_size[1LL];
-  nvfuser_index_t i17;
-  i17 = (-8 + (T4.alloc_stride[0LL] * i7)) + (T4.alloc_stride[1LL] * i8);
-  nvfuser_index_t i18;
-  i18 = 8 + T4.logical_size[3LL];
-  nvfuser_index_t i19;
-  i19 = 28 + T18.logical_size[1LL];
-  nvfuser_index_t i20;
-  i20 = (-4 + (T12.alloc_stride[0LL] * i7)) + (T12.alloc_stride[1LL] * i8);
-  nvfuser_index_t i21;
-  i21 = 28 + T12.logical_size[3LL];
-  nvfuser_index_t i22;
-  i22 = (i3 + (32768 * ((nvfuser_index_t)blockIdx.x))) + i4;
+  i16 = (i3 + (32768 * ((nvfuser_index_t)blockIdx.x))) + i4;
-  nvfuser_index_t i23;
-  i23 = 24 * T10.logical_size[2LL];
-  nvfuser_index_t i24;
-  i24 = ((max(4, (min(i0, 8)))) * T10.logical_size[2LL]) + i23;
-  nvfuser_index_t i25;
-  i25 = (7 + i3) + i4;
-  bool b26;
-  b26 = i25 < i24;
-  nvfuser_index_t i27;
-  i27 = 28 * T10.logical_size[2LL];
-  bool b28;
-  b28 = i25 < (i27 + (T10.logical_size[2LL] * T6.logical_size[3LL]));
-  bool b29;
-  b29 = i25 < (i27 + (T10.logical_size[2LL] * T14.logical_size[1LL]));
-  bool b30;
-  b30 = i25 < ((T10.logical_size[2LL] * T10.logical_size[3LL]) + i27);
-  bool b31;
+  bool b17;
-  b31 = i25 < (((max(4, (min((max(0LL, (min(i1, 8)))), 8)))) * T10.logical_size[2LL]) + i23);
-  bool b32;
-  b32 = i25 < (i27 + (T10.logical_size[2LL] * T22.logical_size[1LL]));
-  bool b33;
-  b33 = i25 < ((max(8, (min(i1, 32)))) * T10.logical_size[2LL]);
-  bool b34;
-  b34 = i25 < (((max(4, (min(i2, 8)))) * T10.logical_size[2LL]) + i23);
+  b17 = ((7 + i3) + i4) < 32768;
-  if ((((i3 + 7) + i4) < i24)) {
+  if ((((i3 + 7) + i4) < 32768)) {
     Array<__bfloat, 8, 8> T54;
     #pragma unroll
-    for(nvfuser_index_t i35 = 0; i35 < 8; ++i35) {
+    for(nvfuser_index_t i18 = 0; i18 < 8; ++i18) {
-      nvfuser_index_t i36;
+      nvfuser_index_t i19;
-      i36 = i5 + (i35 + nvfuser_zero);
+      i19 = i5 + (i18 + nvfuser_zero);
+      nvfuser_index_t i20;
+      i20 = i19 % i6;
+      nvfuser_index_t i21;
+      i21 = -4 + i20;
+      nvfuser_index_t i22;
+      i22 = i19 / i6;
+      bool b23;
+      b23 = (i21 >= 0) && (i21 < T18.logical_size[1LL]);
+      nvfuser_index_t i24;
+      i24 = i19 % i10;
+      nvfuser_index_t i25;
+      i25 = i19 / i10;
+      bool b26;
+      b26 = i24 < T10.logical_size[3LL];
-      nvfuser_index_t i37;
+      nvfuser_index_t i27;
-      i37 = i36 % i6;
+      i27 = i19 % i14;
-      nvfuser_index_t i38;
+      nvfuser_index_t i28;
-      i38 = -4 + i37;
+      i28 = -8 + i27;
-      nvfuser_index_t i39;
-      i39 = i36 % i10;
-      nvfuser_index_t i40;
-      i40 = -4 + i39;
-      nvfuser_index_t i41;
-      i41 = i36 % i11;
-      nvfuser_index_t i42;
-      i42 = i36 % i13;
-      nvfuser_index_t i43;
-      i43 = i36 % i15;
-      nvfuser_index_t i44;
-      i44 = i36 % i16;
-      nvfuser_index_t i45;
-      i45 = i36 % i18;
-      nvfuser_index_t i46;
-      i46 = -8 + i45;
-      nvfuser_index_t i47;
-      i47 = i36 % i19;
-      nvfuser_index_t i48;
-      i48 = -4 + i47;
-      nvfuser_index_t i49;
-      i49 = i36 % i21;
-      nvfuser_index_t i50;
-      i50 = -4 + i49;
       __bfloat T27[1];
       T27[0] = 0;
       T27[0]
-         = ((i38 >= 0) && (i38 < T26.logical_size[1LL])) ? T26[((-4 + (T26.alloc_stride[0LL] * (i36 / i6))) + i37)] : 0.0000e+00;
+         = b23 ? T26[(i21 + (T26.alloc_stride[0LL] * i22))] : 0.0000e+00;
       __bfloat T28[1];
       T28[0]
          = T27[0];
       __bfloat T52[1];
       T52[0]
          = T28[0];
       __bfloat T7[1];
       T7[0] = 0;
       T7[0]
-         = ((i40 >= 0) && (i40 < T6.logical_size[3LL])) ? T6[((i9 + (T6.alloc_stride[2LL] * (i36 / i10))) + i39)] : 0.0000e+00;
+         = b23 ? T6[((i9 + i20) + (T6.alloc_stride[2LL] * i22))] : 0.0000e+00;
       float T38[1];
       T38[0]
          = __bfloat2float(T7[0]);
       float T39[1];
       T39[0]
          = __bfloat2float(T52[0]);
       float T40[1];
       T40[0]
         = T38[0]
         * T39[0];
       __bfloat T15[1];
       T15[0] = 0;
       T15[0]
-         = (i41 < T14.logical_size[1LL]) ? T14[((T14.alloc_stride[0LL] * (i36 / i11)) + i41)] : 0.0000e+00;
+         = b26 ? T14[(i24 + (T14.alloc_stride[0LL] * i25))] : 0.0000e+00;
       __bfloat T16[1];
       T16[0]
          = T15[0];
       __bfloat T51[1];
       T51[0]
          = T16[0];
       __bfloat T11[1];
       T11[0] = 0;
       T11[0]
-         = (i42 < T10.logical_size[3LL]) ? T10[((i12 + (T10.alloc_stride[2LL] * (i36 / i13))) + i42)] : 0.0000e+00;
+         = b26 ? T10[((i11 + (T10.alloc_stride[2LL] * i25)) + i24)] : 0.0000e+00;
       float T35[1];
       T35[0]
          = __bfloat2float(T11[0]);
       float T34[1];
       T34[0]
          = __bfloat2float(T51[0]);
       float T36[1];
       T36[0]
         = T34[0]
         * T35[0];
       __bfloat T9[1];
       T9[0] = 0;
       T9[0]
-         = (i43 < T8.logical_size[3LL]) ? T8[((i14 + (T8.alloc_stride[2LL] * (i36 / i15))) + i43)] : 0.0000e+00;
+         = b26 ? T8[((i12 + i24) + (T8.alloc_stride[2LL] * i25))] : 0.0000e+00;
       float T30[1];
       T30[0]
          = __bfloat2float(T9[0]);
       float T31[1];
       T31[0]
          = -T30[0];
       __bfloat T23[1];
       T23[0] = 0;
       T23[0]
-         = (i44 < T22.logical_size[1LL]) ? T22[((T22.alloc_stride[0LL] * (i36 / i16)) + i44)] : 0.0000e+00;
+         = b26 ? T22[(i24 + (T22.alloc_stride[0LL] * i25))] : 0.0000e+00;
       __bfloat T24[1];
       T24[0]
          = T23[0];
       __bfloat T50[1];
       T50[0]
          = T24[0];
       float T32[1];
       T32[0]
          = __bfloat2float(T50[0]);
       float T33[1];
       T33[0]
         = T31[0]
         * T32[0];
       __bfloat T5[1];
       T5[0] = 0;
       T5[0]
-         = ((i46 >= 0) && (i46 < T4.logical_size[3LL])) ? T4[((i17 + (T4.alloc_stride[2LL] * (i36 / i18))) + i45)] : 0.0000e+00;
+         = ((i28 >= 0) && (i28 < T4.logical_size[3LL])) ? T4[((i13 + (T4.alloc_stride[2LL] * (i19 / i14))) + i27)] : 0.0000e+00;
       float T46[1];
       T46[0]
          = __bfloat2float(T5[0]);
       __bfloat T19[1];
       T19[0] = 0;
       T19[0]
-         = ((i48 >= 0) && (i48 < T18.logical_size[1LL])) ? T18[((-4 + (T18.alloc_stride[0LL] * (i36 / i19))) + i47)] : 0.0000e+00;
+         = b23 ? T18[((-4 + (T18.alloc_stride[0LL] * i22)) + i20)] : 0.0000e+00;
       __bfloat T20[1];
       T20[0]
          = T19[0];
       __bfloat T53[1];
       T53[0]
          = T20[0];
       __bfloat T13[1];
       T13[0] = 0;
       T13[0]
-         = ((i50 >= 0) && (i50 < T12.logical_size[3LL])) ? T12[((i20 + (T12.alloc_stride[2LL] * (i36 / i21))) + i49)] : 0.0000e+00;
+         = b23 ? T12[((i15 + i20) + (T12.alloc_stride[2LL] * i22))] : 0.0000e+00;
       float T42[1];
       T42[0]
          = __bfloat2float(T13[0]);
       float T41[1];
       T41[0]
          = __bfloat2float(T53[0]);
       float T43[1];
       T43[0]
         = T41[0]
         * T42[0];
       float T44[1];
       T44[0]
         = T40[0]
         + T43[0];
       float T37[1];
       T37[0]
         = T33[0]
         + T36[0];
       float T45[1];
       T45[0]
         = T37[0]
         + T44[0];
       float T47[1];
       T47[0]
         = T45[0]
         + T46[0];
-      T54[i35]
+      T54[i18]
          = __float2bfloat(T47[0]);
     }
     NVFUSER_UPDATE_MAGIC_ZERO;
-    loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i22], &T54[0]);
+    loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i16], &T54[0]);
   } else {
     Array<__bfloat, 8, 8> T54;
     #pragma unroll
-    for(nvfuser_index_t i35 = 0; i35 < 8; ++i35) {
+    for(nvfuser_index_t i18 = 0; i18 < 8; ++i18) {
+      nvfuser_index_t i29;
+      i29 = i5 + (i18 + nvfuser_zero);
+      nvfuser_index_t i30;
+      i30 = i29 % i6;
-      nvfuser_index_t i51;
+      nvfuser_index_t i31;
-      i51 = i5 + (i35 + nvfuser_zero);
+      i31 = -4 + i30;
-      nvfuser_index_t i52;
+      nvfuser_index_t i32;
-      i52 = i51 % i6;
+      i32 = i29 / i6;
+      bool b33;
+      b33 = (i31 >= 0) && (i31 < T18.logical_size[1LL]);
-      nvfuser_index_t i53;
+      nvfuser_index_t i34;
-      i53 = -4 + i52;
+      i34 = i29 % i10;
-      nvfuser_index_t i54;
+      nvfuser_index_t i35;
-      i54 = i51 % i10;
+      i35 = i29 / i10;
-      nvfuser_index_t i55;
-      i55 = -4 + i54;
-      nvfuser_index_t i56;
-      i56 = i51 % i11;
+      bool b36;
+      b36 = i34 < T10.logical_size[3LL];
-      nvfuser_index_t i57;
+      nvfuser_index_t i37;
-      i57 = i51 % i13;
+      i37 = i29 % i14;
-      nvfuser_index_t i58;
+      nvfuser_index_t i38;
-      i58 = i51 % i15;
-      nvfuser_index_t i59;
-      i59 = i51 % i16;
-      nvfuser_index_t i60;
-      i60 = i51 % i18;
-      nvfuser_index_t i61;
-      i61 = -8 + i60;
+      i38 = -8 + i37;
-      nvfuser_index_t i62;
-      i62 = i51 % i19;
-      nvfuser_index_t i63;
-      i63 = -4 + i62;
-      nvfuser_index_t i64;
-      i64 = i51 % i21;
-      nvfuser_index_t i65;
-      i65 = -4 + i64;
       __bfloat T27[1];
       T27[0] = 0;
-      if (b26) {
+      if (b17) {
         T27[0]
-           = ((i53 >= 0) && (i53 < T26.logical_size[1LL])) ? T26[((-4 + (T26.alloc_stride[0LL] * (i51 / i6))) + i52)] : 0.0000e+00;
+           = b33 ? T26[(i31 + (T26.alloc_stride[0LL] * i32))] : 0.0000e+00;
       }
       __bfloat T28[1];
       T28[0]
          = T27[0];
       __bfloat T52[1];
       T52[0]
          = T28[0];
       __bfloat T7[1];
       T7[0] = 0;
-      if (b28) {
+      if (b17) {
         T7[0]
-           = ((i55 >= 0) && (i55 < T6.logical_size[3LL])) ? T6[((i9 + (T6.alloc_stride[2LL] * (i51 / i10))) + i54)] : 0.0000e+00;
+           = b33 ? T6[((i9 + i30) + (T6.alloc_stride[2LL] * i32))] : 0.0000e+00;
       }
       float T38[1];
       T38[0]
          = __bfloat2float(T7[0]);
       float T39[1];
       T39[0]
          = __bfloat2float(T52[0]);
       float T40[1];
       T40[0]
         = T38[0]
         * T39[0];
       __bfloat T15[1];
       T15[0] = 0;
-      if (b29) {
+      if (b17) {
         T15[0]
-           = (i56 < T14.logical_size[1LL]) ? T14[((T14.alloc_stride[0LL] * (i51 / i11)) + i56)] : 0.0000e+00;
+           = b36 ? T14[(i34 + (T14.alloc_stride[0LL] * i35))] : 0.0000e+00;
       }
       __bfloat T16[1];
       T16[0]
          = T15[0];
       __bfloat T51[1];
       T51[0]
          = T16[0];
       __bfloat T11[1];
       T11[0] = 0;
-      if (b30) {
+      if (b17) {
         T11[0]
-           = (i57 < T10.logical_size[3LL]) ? T10[((i12 + (T10.alloc_stride[2LL] * (i51 / i13))) + i57)] : 0.0000e+00;
+           = b36 ? T10[((i11 + (T10.alloc_stride[2LL] * i35)) + i34)] : 0.0000e+00;
       }
       float T35[1];
       T35[0]
          = __bfloat2float(T11[0]);
       float T34[1];
       T34[0]
          = __bfloat2float(T51[0]);
       float T36[1];
       T36[0]
         = T34[0]
         * T35[0];
       __bfloat T9[1];
       T9[0] = 0;
-      if (b31) {
+      if (b17) {
         T9[0]
-           = (i58 < T8.logical_size[3LL]) ? T8[((i14 + (T8.alloc_stride[2LL] * (i51 / i15))) + i58)] : 0.0000e+00;
+           = b36 ? T8[((i12 + i34) + (T8.alloc_stride[2LL] * i35))] : 0.0000e+00;
       }
       float T30[1];
       T30[0]
          = __bfloat2float(T9[0]);
       float T31[1];
       T31[0]
          = -T30[0];
       __bfloat T23[1];
       T23[0] = 0;
-      if (b32) {
+      if (b17) {
         T23[0]
-           = (i59 < T22.logical_size[1LL]) ? T22[((T22.alloc_stride[0LL] * (i51 / i16)) + i59)] : 0.0000e+00;
+           = b36 ? T22[(i34 + (T22.alloc_stride[0LL] * i35))] : 0.0000e+00;
       }
       __bfloat T24[1];
       T24[0]
          = T23[0];
       __bfloat T50[1];
       T50[0]
          = T24[0];
       float T32[1];
       T32[0]
          = __bfloat2float(T50[0]);
       float T33[1];
       T33[0]
         = T31[0]
         * T32[0];
       __bfloat T5[1];
       T5[0] = 0;
-      if (b33) {
+      if (b17) {
         T5[0]
-           = ((i61 >= 0) && (i61 < T4.logical_size[3LL])) ? T4[((i17 + (T4.alloc_stride[2LL] * (i51 / i18))) + i60)] : 0.0000e+00;
+           = ((i38 >= 0) && (i38 < T4.logical_size[3LL])) ? T4[((i13 + (T4.alloc_stride[2LL] * (i29 / i14))) + i37)] : 0.0000e+00;
       }
       float T46[1];
       T46[0]
          = __bfloat2float(T5[0]);
       __bfloat T19[1];
       T19[0] = 0;
-      if (b34) {
+      if (b17) {
         T19[0]
-           = ((i63 >= 0) && (i63 < T18.logical_size[1LL])) ? T18[((-4 + (T18.alloc_stride[0LL] * (i51 / i19))) + i62)] : 0.0000e+00;
+           = b33 ? T18[((-4 + (T18.alloc_stride[0LL] * i32)) + i30)] : 0.0000e+00;
       }
       __bfloat T20[1];
       T20[0]
          = T19[0];
       __bfloat T53[1];
       T53[0]
          = T20[0];
       __bfloat T13[1];
       T13[0] = 0;
-      if (b31) {
+      if (b17) {
         T13[0]
-           = ((i65 >= 0) && (i65 < T12.logical_size[3LL])) ? T12[((i20 + (T12.alloc_stride[2LL] * (i51 / i21))) + i64)] : 0.0000e+00;
+           = b33 ? T12[((i15 + i30) + (T12.alloc_stride[2LL] * i32))] : 0.0000e+00;
       }
       float T42[1];
       T42[0]
          = __bfloat2float(T13[0]);
       float T41[1];
       T41[0]
          = __bfloat2float(T53[0]);
       float T43[1];
       T43[0]
         = T41[0]
         * T42[0];
       float T44[1];
       T44[0]
         = T40[0]
         + T43[0];
       float T37[1];
       T37[0]
         = T33[0]
         + T36[0];
       float T45[1];
       T45[0]
         = T37[0]
         + T44[0];
       float T47[1];
       T47[0]
         = T45[0]
         + T46[0];
-      T54[i35]
+      T54[i18]
          = __float2bfloat(T47[0]);
     }
     NVFUSER_UPDATE_MAGIC_ZERO;
-    if ((i25 < 32768)) {
+    if (b17) {
-      loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i22], &T54[0]);
+      loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T48[i16], &T54[0]);
     }
   }
 }

I think there are simpler conditions for the padding predicates, and they're getting hoisted. With static shapes, this runtime can be improved further to about 8 us.

BTW looking at larger problem size (changing head size from 32 to 256 and bsz from 2 to 256), we have 28.1 ms vs 10.7 ms (2.6x speedup similar to the smaller problem size).

jjsjann123 commented 1 week ago

Patching this from the thunder side per our earlier conversation: https://github.com/Lightning-AI/lightning-thunder/pull/1096. That PR would effectively change the program to be statically shaped in thunder, until we pull through dynamic shape support in thunder.