NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

alias analysis missing out on aliasing opportunities within fusion segments #2577

Open jjsjann123 opened 1 month ago

jjsjann123 commented 1 month ago

In the full RoPE example, we have a fusion starting with something like this:

Inputs:
  T0_g[ iS467{1024}, iS468{8} ], __bfloat
  T1_g[ iS471{1024}, iS472{8} ], __bfloat
  T2_g[ iS463{2}, iS464{1024}, iS6{i7} ], __bfloat
Outputs:
  T60_g[ iS324{2}, iS328{( 4 * 4 )}rf, iS326{1024}, iS327{32} ], __bfloat
  T39_g[ iS390{2}, iS391{16}, iS188{1024}, iS189{32} ], __bfloat
  T58_g[ iS461{2}, iS462{16}, iS272{1024}, iS273{32} ], __bfloat

%kernel_math {
T59_l[ iS465{2}, iS466{1024}, iS278{4}rf, iS280{6}rf, iS281{( ceilDiv(( ceilDiv(i7, 4) ), 6) )}rf ] = view( T2_g[ iS463{2}, iS464{1024}, iS6{i7} ] )
T4_g[ iS306{2}, iS308{4}, iS309{6}, iS307{1024}, iS310{32} ]
   = Set.Permute( T59_l[ iS465{2}, iS466{1024}, iS278{4}rf, iS280{6}rf, iS281{( ceilDiv(( ceilDiv(i7, 4) ), 6) )}rf ], cache_op=Streaming )
T7_g[ iS311{2}rf, iS312{4}rf, bS35{1}rf, iS314{1024}rf, iS315{32}rf ]
   = slice( T4_g[ iS306{2}, iS308{4}, iS309{6}, iS307{1024}, iS310{32} ], { {0, 2, 1} {0, 4, 1} {5, 6, 1} {0, 1024, 1} {0, 32, 1} } )
T10_l[ iS316{2}, iS317{4}, bS50{1}, iS318{1024}, iS319{32} ]
   = Set( T7_g[ iS311{2}rf, iS312{4}rf, bS35{1}rf, iS314{1024}rf, iS315{32}rf ], cache_op=Streaming )
T11_g[ iS320{2}, iS321{4}, bS55{1 ex 4}, iS322{1024}, iS323{32} ] = expand( T10_l[ iS316{2}, iS317{4}, bS50{1}, iS318{1024}, iS319{32} ], {2, 4, 4, 1024, 32} )
T60_g[ iS324{2}, iS328{( 4 * 4 )}rf, iS326{1024}, iS327{32} ] = view( T11_g[ iS320{2}, iS321{4}, bS55{1 ex 4}, iS322{1024}, iS323{32} ] )
T5_g[ iS329{2}rf, iS330{4}rf, iS23{4}rf, iS332{1024}rf, iS333{32}rf ]
   = slice( T4_g[ iS306{2}, iS308{4}, iS309{6}, iS307{1024}, iS310{32} ], { {0, 2, 1} {0, 4, 1} {0, 4, 1} {0, 1024, 1} {0, 32, 1} } )
T61_g[ iS334{2}, iS338{( 4 * 4 )}rf, iS336{1024}, iS337{32} ] = view( T5_g[ iS329{2}rf, iS330{4}rf, iS23{4}rf, iS332{1024}rf, iS333{32}rf ] )
i243 = 4 * 4;
i281 = fmin(i243, 16);
i283 = fmax(0, i281);
T15_g[ iS339{2}rf, iS340{( 4 * 4 )}rf, iS341{1024}rf, iS86{8}rf ]
   = slice( T61_g[ iS334{2}, iS338{( 4 * 4 )}rf, iS336{1024}, iS337{32} ], { {0, 2, 1} {0, i283, 1} {0, 1024, 1} {0, 8, 1} } )
T26_g[ iS343{2}, iS344{( 4 * 4 )}, iS345{1024}, iS134{8} ]
   = __bfloat2float(T15_g[ iS339{2}rf, iS340{( 4 * 4 )}rf, iS341{1024}rf, iS86{8}rf ]);
T24_l[ bS123{1}, bS124{1}, iS469{1024}, iS470{8} ]
   = broadcast( T0_g[ iS467{1024}, iS468{8} ] )
T25_g[ bS127{1 ex 2}, bS128{1 ex 16}, iS129{1024}, iS130{8} ] = expand( T24_l[ bS123{1}, bS124{1}, iS469{1024}, iS470{8} ], {2, 16, 1024, 8} )
T27_g[ bS135{1 ex 2}, bS136{1 ex 16}, iS137{1024}, iS138{8} ]
   = __bfloat2float(T25_g[ bS127{1 ex 2}, bS128{1 ex 16}, iS129{1024}, iS130{8} ]);
T28_g[ iS346{2}, iS347{( 4 * 4 )}, iS141{1024}, iS142{8} ]
   = T26_g[ iS343{2}, iS344{( 4 * 4 )}, iS345{1024}, iS134{8} ]
   * T27_g[ bS135{1 ex 2}, bS136{1 ex 16}, iS137{1024}, iS138{8} ];
i387 = fmin(i243, 16);
i389 = fmax(0, i387);
T17_g[ iS348{2}rf, iS349{( 4 * 4 )}rf, iS350{1024}rf, iS96{4}rf ]
   = slice( T15_g[ iS339{2}rf, iS340{( 4 * 4 )}rf, iS341{1024}rf, iS86{8}rf ], { {0, 2, 1} {0, i389, 1} {0, 1024, 1} {4, 8, 1} } )
T18_l[ iS351{2}, iS352{( 4 * 4 )}, iS353{1024}, iS100{4} ]
   = __bfloat2float(T17_g[ iS348{2}rf, iS349{( 4 * 4 )}rf, iS350{1024}rf, iS96{4}rf ]);
T19_g[ iS354{2}, iS355{( 4 * 4 )}, iS356{1024}, iS104{4} ]
   = -T18_l[ iS351{2}, iS352{( 4 * 4 )}, iS353{1024}, iS100{4} ];
T20_g[ iS357{2}, iS358{( 4 * 4 )}, iS359{1024}, iS108{4} ]
   = __float2bfloat(T19_g[ iS354{2}, iS355{( 4 * 4 )}, iS356{1024}, iS104{4} ]);
T21_l[ iS360{2}, iS361{( 4 * 4 )}, iS362{1024}, iS113{8}rf ]
   = pad( T20_g[ iS357{2}, iS358{( 4 * 4 )}, iS359{1024}, iS108{4} ], {0, 0, 0, 0, 0, 0, 0, 4} )
i334 = fmin(i243, 16);
i336 = fmax(0, i334);
T16_g[ iS363{2}rf, iS364{( 4 * 4 )}rf, iS365{1024}rf, iS91{4}rf ]
   = slice( T15_g[ iS339{2}rf, iS340{( 4 * 4 )}rf, iS341{1024}rf, iS86{8}rf ], { {0, 2, 1} {0, i336, 1} {0, 1024, 1} {0, 4, 1} } )
i448 = 0 + 4;
T22_g[ iS366{2}, iS367{( 4 * 4 )}, iS368{1024}, iS118{( ( 0 + 4 ) + 4 )}rf ]
   = pad( T16_g[ iS363{2}rf, iS364{( 4 * 4 )}rf, iS365{1024}rf, iS91{4}rf ], {0, 0, 0, 0, 0, 0, i448, 0} )
...

Unfortunately, nvFuser ended up segmenting on the slice, which gives us a segment like the one below, where we miss out on leveraging alias analysis and instead generate a kernel for a single slice op.

g{(pointwise)
inputs:
T15_g[ iS339{2}rf, iS340{( 4 * 4 )}rf, iS341{1024}rf, iS86{8}rf ] __bfloat
outputs:
T16_g[ iS363{2}rf, iS364{( 4 * 4 )}rf, iS365{1024}rf, iS91{4}rf ] __bfloat

i334 = fmin(i243, 16);
(29)
i336 = fmax(0, i334);
(30)
T16_g[ iS363{2}rf, iS364{( 4 * 4 )}rf, iS365{1024}rf, iS91{4}rf ]
   = slice( T15_g[ iS339{2}rf, iS340{( 4 * 4 )}rf, iS341{1024}rf, iS86{8}rf ], { {0, 2, 1} {0, i336, 1} {0, 1024, 1} {0, 4, 1} } )
(34)
}

The issue here is that we might not always segment at the right position. The current alias analysis runs as a pre-segmenter pass, so it misses aliasing opportunities that only show up within segments, and we unfortunately end up generating kernels for the example above.
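
To make the missed opportunity concrete: the slice in the segment above only narrows the innermost axis, so its output could be described as a strided view of T15 instead of a freshly materialized tensor. Below is a minimal PyTorch analogy (not nvFuser code; the shapes are chosen to mirror T15 and T16 from the dump) showing that such a slice is a pure metadata operation, and that the resulting view is no longer contiguous.

import torch

# Shapes mirroring T15 (input) and T16 (output) of the single-slice segment.
t15 = torch.randn(2, 16, 1024, 8)

# The segment's only op: keep elements [0, 4) of the last axis.
t16 = t15[:, :, :, :4]

# No data is copied: the slice shares storage with t15 and only changes metadata.
assert t16.data_ptr() == t15.data_ptr()
assert t16.stride() == t15.stride()   # strides are inherited from t15
assert not t16.is_contiguous()        # but the view is no longer contiguous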

repro code

import torch
import thunder

# operations to prepare q, k before sending it into rope
def split_qkv(x, n_head, n_query_groups, head_size):
    B, T, C = x.size()
    q_per_kv = n_head // n_query_groups
    total_qkv = q_per_kv + 2
    qkv = x.view(
        B, T, n_query_groups, total_qkv, head_size)
    qkv = qkv.permute(0, 2, 3, 1, 4)
    q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
    k = k.expand(B, n_query_groups, q_per_kv, T, head_size)
    v = v.expand(B, n_query_groups, q_per_kv, T, head_size)
    q = q.reshape(B, -1, T, head_size)
    k = k.reshape(B, -1, T, head_size)
    v = v.reshape(B, -1, T, head_size)
    return q, k, v

def rope_one_entry(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, rope_n_elem: int) -> torch.Tensor:
    x_rope = x[..., : rope_n_elem]
    x1 = x_rope[..., : rope_n_elem // 2]  # (B, nh, T, hs/2)
    x2 = x_rope[..., rope_n_elem // 2 :]  # (B, nh, T, hs/2)
    rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
    roped = (x_rope * cos) + (rotated * sin)
    roped = roped.to(dtype=x.dtype)
    return torch.cat((roped, x[..., rope_n_elem :]), dim=-1)

# full rope and surrounding operations to compute q, k, v
def rope_it_all(x, cos, sin, n_head, n_query_groups, head_size, rope_n_elem):
    q, k, v = split_qkv(x, n_head, n_query_groups, head_size)
    q = rope_one_entry(q, cos, sin, rope_n_elem)
    k = rope_one_entry(k, cos, sin, rope_n_elem)
    return q, k, v

dtype = torch.bfloat16
device = "cuda"
bsz = 2
block_size = 1024
n_head = 16
head_size = 32
n_query_groups = 4
rope_n_elem = 8

x = torch.randn([bsz, block_size, (n_head + 2 * n_query_groups) * head_size], device=device, dtype=dtype)
cos = torch.randn(block_size, rope_n_elem, device=device, dtype=dtype)
sin = torch.randn(block_size, rope_n_elem, device=device, dtype=dtype)

jit_rope = thunder.jit(rope_it_all, nv_enable_bookend=False)
q, k, v = jit_rope(x, cos, sin, n_head, n_query_groups, head_size, rope_n_elem)
print(thunder.last_traces(jit_rope)[-1])

Note that this needs to run with bookend disabled in thunder, as in @wujingyue's PR: https://github.com/Lightning-AI/lightning-thunder/pull/731

jacobhinkle commented 1 month ago

Could this be solved by accepting slice-only segments (and metadata-only reshapes) in the ExprEval scheduler?

jjsjann123 commented 1 month ago

Could this be solved by accepting slice-only segments (and metadata-only reshapes) in the ExprEval scheduler?

That's a great alternative.

My wild guess is that the ExprEval scheduler is not populating/updating the output TensorView definition for its consumer segments, so there's still that bit of plumbing work to be done? (A slice from ExprEval might leave the output tensor non-contiguous, unlike the default behavior where everything is assumed contiguous.)
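
As a small PyTorch analogy of that concern (not nvFuser code): if the slice is evaluated by aliasing, the tensor handed to the consumer segment is non-contiguous, and a consumer compiled under the default all-contiguous assumption would effectively read that storage with the wrong strides.

import torch

t15 = torch.randn(2, 16, 1024, 8)
t16_alias = t15[..., :4]            # what an aliasing "slice" would hand over
t16_copy = t16_alias.contiguous()   # what a generated slice kernel would produce

# A consumer that wrongly assumes the default fully contiguous layout would, in
# effect, reinterpret the alias's storage with contiguous strides:
wrong_view = torch.as_strided(t16_alias, t16_copy.size(), t16_copy.stride())

print(torch.equal(t16_copy, t16_alias))   # True:  same logical values
print(torch.equal(t16_copy, wrong_view))  # False: misread under the wrong strides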

jacobhinkle commented 1 month ago

My wild guess is that the ExprEval scheduler is not populating/updating the output TensorView definition for its consumer segments, so there's still that bit of plumbing work to be done? (A slice from ExprEval might leave the output tensor non-contiguous, unlike the default behavior where everything is assumed contiguous.)

Oh, that's a good point; this is the real challenge in mixing alias analysis with segmentation, isn't it. We could envision a system where, during segmentation, a scheduler can optionally set some options on its output tensors when it accepts a segment, so ExprEval could set the contiguity flags at that point. If the allocation domain is already set on those outputs, this call could fail and the scheduler could either adjust or reject the segment, resetting the output TV.
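
To sketch the shape of that handshake, here is a hypothetical Python outline; none of these names (try_accept_segment, can_schedule, propose_output_layouts, has_allocation_domain, allocation, set_allocation, reset_output_tvs) exist in nvFuser, and this only illustrates the control flow being proposed, not an implementation.

# Hypothetical sketch only; every helper name here is made up.
def try_accept_segment(scheduler, segment):
    if not scheduler.can_schedule(segment):
        return False
    # The accepting scheduler proposes a layout (allocation domain / contiguity
    # flags) for each output TV it will hand to consumer segments.
    for tv, layout in scheduler.propose_output_layouts(segment):
        if tv.has_allocation_domain() and tv.allocation() != layout:
            # Layout already pinned by someone else: adjust or reject the
            # segment, resetting whatever was set on its output TVs so far.
            segment.reset_output_tvs()
            return False
        tv.set_allocation(layout)  # e.g. mark a sliced output as non-contiguous
    return True

Under such a scheme, ExprEval could accept the single-slice segment and record the non-contiguous output layout at that point, and consumer segments would then be scheduled against the recorded layout instead of the default.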

wujingyue commented 1 month ago

I believe the following is what's happening at the moment:

  1. AliasFinder does add slice's input and output to alias_to_source_, which is good. https://github.com/NVIDIA/Fuser/blob/1026f676b370a7ca5c6472daf1de061c8755bc87/csrc/alias_analysis.cpp#L273
  2. AliasAnalysisResult::finalize only sets the allocation domain of an intermediate tensor when that tensor can reach a fusion input/output by walking upwards. https://github.com/NVIDIA/Fuser/blob/1026f676b370a7ca5c6472daf1de061c8755bc87/csrc/alias_analysis.cpp#L431-L452.

Once we fix (2), NoOpScheduler should be able to pick up the single-slice segment. Caveat: it may be risky to aggressively set the allocation domain of an intermediate tensor; I don't know how schedulers treat them.
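
For reference, a loose Python paraphrase of the condition described in (2) is sketched below; the real code is C++, and the helper names (reaches_fusion_io_upwards, set_allocation_domain_from_source) are made up for illustration.

# Loose paraphrase of the described behavior; not the real implementation.
def finalize(alias_to_source, fusion):
    for tv, source in alias_to_source.items():
        # Only tensors that can reach a fusion input/output by walking upwards
        # get their allocation domain set, so an intermediate like T16 above is
        # skipped and its segment still falls back to a generated kernel.
        if not reaches_fusion_io_upwards(tv, fusion):
            continue
        set_allocation_domain_from_source(tv, source)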

jjsjann123 commented 1 month ago

where, during segmentation, a scheduler can optionally set some options on its output tensors when it accepts a segment,

That would be great, though it's a bigger hammer.

Caveat: it may be risky to aggressively set the allocation domain of an intermediate tensor; I don't know how schedulers treat them.

Yeah, this part is also a bit concerning to me... It feels like it's hard to make this decision on an unsegmented fusion.