pytorch / FBGEMM

FB (Facebook) + GEMM (General Matrix-Matrix Multiplication) - https://code.fb.com/ml-applications/fbgemm/

Make CUTLASS rowwise fp8 faster #2764

Open lw opened 3 weeks ago

lw commented 3 weeks ago

Summary: This speeds up the CUTLASS rowwise fp8 kernel by telling CUTLASS to output in column-major (which, somehow, is faster) and transposing the inputs so that the end result is unchanged.
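The transposition is valid because of the identity (B Aᵀ)ᵀ = A Bᵀ: asking the kernel for the column-major (i.e. transposed) output of the swapped operands yields the same row-major result. A minimal sketch in plain PyTorch (shapes here are illustrative, not the benchmarked ones):

```python
import torch

# C = A @ B.T, computed directly...
a = torch.randn(128, 64)
b = torch.randn(96, 64)
direct = a @ b.t()
# ...and via the swapped/transposed formulation: produce the transposed
# (column-major) result from swapped operands, then view it back.
via_swap = (b @ a.t()).t()
assert torch.allclose(direct, via_swap, atol=1e-4)
```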

Here are the benchmark results for the sizes I was interested in (kernel times in microseconds, as reported by the PyTorch profiler):

m=4176 n=12288 k=4096
  cuBLAS:         min=260.414, max=316.191, median=310.302, mean=306.742 +/- 11.8098
  FBGEMM, before: min=359.87, max=379.134, median=373.949, mean=372.939 +/- 3.81074
  FBGEMM, now:    min=268.318, max=320.51, median=317.47, mean=313.336 +/- 11.3555

m=4176 n=4096 k=12288
  cuBLAS:         min=269.567, max=334.398, median=330.399, mean=325.656 +/- 13.7557
  FBGEMM, before: min=327.551, max=365.47, median=361.535, mean=359.815 +/- 4.8227
  FBGEMM, now:    min=318.815, max=381.853, median=363.486, mean=361.386 +/- 11.5182

m=3456 n=6144 k=4096
  cuBLAS:         min=112.512, max=133.856, median=131.935, mean=128.974 +/- 6.02215
  FBGEMM, before: min=157.632, max=165.151, median=161.407, mean=161.407 +/- 1.22362
  FBGEMM, now:    min=113.727, max=135.775, median=132.063, mean=129.611 +/- 5.33297

m=3456 n=4096 k=6144
  cuBLAS:         min=117.023, max=142.847, median=138.687, mean=134.55 +/- 8.34265
  FBGEMM, before: min=154.11, max=163.232, median=160.383, mean=159.458 +/- 2.28465
  FBGEMM, now:    min=119.904, max=143.198, median=140.447, mean=137.979 +/- 5.44675

m=3456 n=4096 k=4096
  cuBLAS:         min=78.271, max=91.583, median=89.279, mean=86.9895 +/- 4.35227
  FBGEMM, before: min=102.655, max=108.255, median=106.079, mean=105.866 +/- 1.09653
  FBGEMM, now:    min=78.912, max=94.559, median=91.583, mean=88.8953 +/- 4.85437

m=3456 n=12288 k=4096
  cuBLAS:         min=218.591, max=262.335, median=255.646, mean=252.38 +/- 10.1951
  FBGEMM, before: min=302.783, max=319.998, median=313.662, mean=313.112 +/- 3.32206
  FBGEMM, now:    min=226.654, max=270.59, median=264.734, mean=260.988 +/- 8.79772

m=3456 n=4096 k=12288
  cuBLAS:         min=249.406, max=297.022, median=285.151, mean=283.93 +/- 6.82518
  FBGEMM, before: min=305.982, max=346.558, median=338.015, mean=335.894 +/- 7.76891
  FBGEMM, now:    min=246.75, max=287.87, median=282.91, mean=280.271 +/- 8.47942

m=4176 n=6144 k=4096
  cuBLAS:         min=133.151, max=160.224, median=156.543, mean=153.465 +/- 7.28038
  FBGEMM, before: min=187.071, max=194.399, median=191.967, mean=191.66 +/- 1.28559
  FBGEMM, now:    min=135.295, max=163.742, median=158.719, mean=155.969 +/- 7.42937

m=4176 n=4096 k=6144
  cuBLAS:         min=138.367, max=171.231, median=167.487, mean=164.666 +/- 8.29714
  FBGEMM, before: min=165.407, max=187.135, median=183.968, mean=182.002 +/- 4.73605
  FBGEMM, now:    min=164.638, max=185.471, median=180.542, mean=178.779 +/- 4.95891
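Summarizing the medians above (a small helper over the transcribed numbers, same units as the profiler output):

```python
# Median kernel times from the results above, as
# (cuBLAS, FBGEMM before, FBGEMM now) per (m, n, k).
medians = {
    (4176, 12288, 4096): (310.302, 373.949, 317.47),
    (4176, 4096, 12288): (330.399, 361.535, 363.486),
    (3456, 6144, 4096): (131.935, 161.407, 132.063),
    (3456, 4096, 6144): (138.687, 160.383, 140.447),
    (3456, 4096, 4096): (89.279, 106.079, 91.583),
    (3456, 12288, 4096): (255.646, 313.662, 264.734),
    (3456, 4096, 12288): (285.151, 338.015, 282.91),
    (4176, 6144, 4096): (156.543, 191.967, 158.719),
    (4176, 4096, 6144): (167.487, 183.968, 180.542),
}
speedups = {mnk: before / now for mnk, (_, before, now) in medians.items()}
for (m, n, k), s in speedups.items():
    cublas, _, now = medians[(m, n, k)]
    print(f"m={m} n={n} k={k}: {s:.2f}x vs before, {now / cublas:.3f}x of cuBLAS time")
```

On most of these shapes the new kernel lands within a few percent of cuBLAS; the main exceptions are m=4176 n=4096 with k=6144 and k=12288, where a larger gap remains.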

https://pxl.cl/566wB

This is the code I used to get the above numbers:

import torch

# Problem sizes to benchmark (the (m, n, k) triples reported above).
mnks = [
    (4176, 12288, 4096), (4176, 4096, 12288),
    (3456, 6144, 4096), (3456, 4096, 6144), (3456, 4096, 4096),
    (3456, 12288, 4096), (3456, 4096, 12288),
    (4176, 6144, 4096), (4176, 4096, 6144),
]

for m, n, k in mnks:
    print(f"{m=} {n=} {k=}")
    a = torch.randn((m, k), device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn((n, k), device="cuda").to(torch.float8_e4m3fn)
    scale_a = torch.randn((m,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((n,), device="cuda", dtype=torch.float32)
    # Warm up both kernels once before profiling.
    torch._scaled_mm(a, b.t(), scale_a=scale_a[0], scale_b=scale_b[0], out_dtype=torch.bfloat16, use_fast_accum=True)
    torch.ops.fbgemm.f8f8bf16_rowwise(a, b, scale_a, scale_b, use_fast_accum=True)
    with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
        for _ in range(1000):
            torch._scaled_mm(a, b.t(), scale_a=scale_a[0], scale_b=scale_b[0], out_dtype=torch.bfloat16, use_fast_accum=True)
            torch.ops.fbgemm.f8f8bf16_rowwise(a, b, scale_a, scale_b, use_fast_accum=True)
    # Bucket the CUDA times (microseconds) per kernel and print summary stats.
    stats = {}
    for event in prof.events():
        if event.cuda_time > 0:
            stats.setdefault(event.key, []).append(event.cuda_time)
    for key, times in stats.items():
        times.sort()
        t = torch.tensor(times)
        std, mean = torch.std_mean(t)
        print(f"{key[:100]}: min={times[0]:g}, max={times[-1]:g}, median={times[len(times)//2]:g}, mean={mean:g} +/- {std:g}")
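For reference, a sketch of what the rowwise-scaled GEMM computes, assuming the usual rowwise-scaling semantics (one scale per row of each operand); the real `f8f8bf16_rowwise` kernel consumes fp8 inputs and fuses the scaling into the epilogue, whereas this reference runs in full precision on small CPU tensors:

```python
import torch

# Rowwise-scaled GEMM reference (assumed semantics, not the fbgemm kernel):
# out[i, j] = (sum_k a[i, k] * b[j, k]) * scale_a[i] * scale_b[j], cast to bf16.
m, n, k = 8, 6, 16
a = torch.randn(m, k)
b = torch.randn(n, k)    # b is (n, k), so the contraction over k goes through b.t()
scale_a = torch.rand(m)  # one scale per row of a
scale_b = torch.rand(n)  # one scale per row of b
ref = ((a @ b.t()) * scale_a[:, None] * scale_b[None, :]).to(torch.bfloat16)
```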

I ran the benchmarks on devgpu002.eag5 which has 700W 80GB H100 GPUs.

Differential Revision: D58821928

netlify[bot] commented 3 weeks ago

Deploy Preview for pytorch-fbgemm-docs ready!

Latest commit: d4dfa64c9267b8c8db3f24c9b6bed779c3cbd68c
Latest deploy log: https://app.netlify.com/sites/pytorch-fbgemm-docs/deploys/66755ee109af390008bc8500
Deploy Preview: https://deploy-preview-2764--pytorch-fbgemm-docs.netlify.app

facebook-github-bot commented 3 weeks ago

This pull request was exported from Phabricator. Differential Revision: D58821928