NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Adding resize(PadOp) vectorization analysis #3321

Closed jjsjann123 closed 1 week ago

jjsjann123 commented 3 weeks ago

Adding conditional support of resize in vectorization analysis. This PR allows vectorized loads on PadOp directly, without going through a cache load, which improves the performance of the generated kernels.

What's in this PR:

  1. Add a propagation rule for resize in vectorization analysis (a minimal sketch of the extent computation is included after this list). The propagation rule works as follows:
     i. For a supported resize: a) project the resize op to the frontier and clear (frontier.begin(), resize_position); b) add the projected extent of the new resize op as gcd(id_from, resize_op->leftExpand(), resize_op->rightExpand()).
     ii. For an unsupported resize: clear [frontier.begin(), resize_position]; no behavior change.

  2. Update TensorView::cacheAfter to opt in a set of uses to cache while leaving other uses unchanged. This is necessary for cases where an input is used by a PadOp as well as by other operations that rely on a cached load for vectorization (a torch-level sketch of this pattern also follows the list).
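
As a rough illustration of the extent computation in (1), here is a minimal standalone sketch. The function name is illustrative, not the nvFuser API; the point is that the vectorization factor has to divide the original extent as well as both pad extents, hence the gcd:

```python
from math import gcd

def projected_resize_extent(id_from_extent: int, left_expand: int, right_expand: int) -> int:
    # Projected extent for a supported resize: the gcd of the producer ID's
    # extent and the left/right pad extents.
    return gcd(id_from_extent, gcd(left_expand, right_expand))

# Padding an inner extent of 8 by (4, 4) still allows a vectorization factor of 4.
assert projected_resize_extent(8, 4, 4) == 4
# Padding by (1, 3) limits the factor to 1, i.e. no vectorized access across the pad.
assert projected_resize_extent(8, 1, 3) == 1
```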
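
For context on (2), a minimal torch-level sketch of the motivating pattern, where the same input feeds a PadOp and a non-pad consumer (shapes and ops here are only illustrative, not taken from a real fusion):

```python
import torch
import torch.nn.functional as F

def shared_input_pattern(x: torch.Tensor) -> torch.Tensor:
    # x feeds a pad, which this PR can now vectorize directly...
    padded = F.pad(x, (0, 4))
    # ...and a second consumer that still wants a cached (and vectorized) load,
    # so only this use should be redirected to the cache tensor.
    scaled = x * 2.0
    return padded[..., : x.shape[-1]] + scaled
```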

Follow-up to #3261, part of the work on supporting RoPE performance. Design doc:

jjsjann123 commented 2 weeks ago

!test

jjsjann123 commented 2 weeks ago

~~Hmm, this one isn't functional yet. The vectorize factor isn't computed correctly.~~

nvm, there was a small typo. I should really add a vectorization factor check in the tests first.

Note to self, a couple of tests are still not working properly:

- PadAndCacheUses: validation fails here.
- VectorizePadNonInnermost: the vectorization factor is 4 and the kernel runs correctly; maybe my own understanding isn't correct. Double-check the kernel.

jjsjann123 commented 2 weeks ago

!test

naoyam commented 2 weeks ago

!test --pybench

naoyam commented 2 weeks ago

Initiated testing with python benchmarks just in case.

jjsjann123 commented 2 weeks ago

Thanks, I'll address the issues you brought up as well as run through some real-size problems so we get a taste of the perf impact. :bow:

jjsjann123 commented 1 week ago

!test --pybench

jjsjann123 commented 1 week ago

!test --pybench

jjsjann123 commented 1 week ago

Took a quick look at the perf. The end-to-end time looks very noisy, and I'm a bit unsure about my measuring script, so I instead ran nsys to get the kernel time.

On A100 80GB PCIe, peak bandwidth is 2TB/s.

Something like this vvv.

```python
import torch
import thunder


def rope_one_entry(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, rope_n_elem: int) -> torch.Tensor:
    x_rope = x[..., : rope_n_elem]
    x1 = x_rope[..., : rope_n_elem // 2]  # (B, nh, T, hs/2)
    x2 = x_rope[..., rope_n_elem // 2 :]  # (B, nh, T, hs/2)
    rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
    roped = (x_rope * cos) + (rotated * sin)
    roped = roped.to(dtype=x.dtype)
    return torch.cat((roped, x[..., rope_n_elem :]), dim=-1)


dtype = torch.bfloat16
device = "cuda"
bsz = 256
block_size = 1024
n_head = 16
head_size = 32
n_query_groups = 4
rope_n_elem = 8
WARMUP_ITER = 5
MEASURE_ITER = 20

cos = torch.randn(block_size, rope_n_elem, device=device, dtype=dtype)
sin = torch.randn(block_size, rope_n_elem, device=device, dtype=dtype)

thunder_rope_one = thunder.jit(rope_one_entry, executors=("nvfuser",), nv_enable_bookend=False)

x = torch.randn([bsz, n_head, block_size, head_size], device=device, dtype=dtype)

# reference full-precision run
o_ref = rope_one_entry(x.float(), cos.float(), sin.float(), rope_n_elem).to(dtype=dtype)

# buffer used to clear L2 between measured iterations
l2_clear_buffer = torch.empty(80, 1024, 1024, dtype=torch.float, device="cuda")

# warm up
for i in range(WARMUP_ITER):
    o = thunder_rope_one(x, cos, sin, rope_n_elem)

# measurement
for i in range(MEASURE_ITER):
    l2_clear_buffer.zero_()
    o = thunder_rope_one(x, cos, sin, rope_n_elem)

assert o.allclose(o_ref)
```
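
For reference, a back-of-the-envelope roofline for this problem size, assuming the kernel only needs to read x and write the output once in bf16 (reusing the sizes from the script above; cos/sin traffic is negligible):

```python
bytes_per_elem = 2  # bfloat16
x_bytes = bsz * n_head * block_size * head_size * bytes_per_elem
traffic_bytes = 2 * x_bytes  # read the input once, write the output once
ideal_ms = traffic_bytes / 2e12 * 1e3  # at the ~2 TB/s peak quoted above
print(f"minimum traffic: {traffic_bytes / 2**20:.0f} MiB, ideal kernel time: {ideal_ms:.3f} ms")
```
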
jjsjann123 commented 1 week ago

I think the review comments have been addressed as well. CI was green before my benchmark. Ready for a final review.

jjsjann123 commented 1 week ago

!test --pybench

jjsjann123 commented 1 week ago

The build failure seems to be flaky. I restarted the CI on that one and it passed internally.

Unfortunately it didn't update the GitHub status here. Not a big issue, but cc'ing @xwang233 in case this is something you are not aware of.

jjsjann123 commented 1 week ago

!test --pybench