pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.

Question about the generated code of `WeightOnlyInt8Linear` #114

Open feiyuvl opened 7 months ago

feiyuvl commented 7 months ago

I wrote a simple test to get the Triton code of `WeightOnlyInt8Linear`; the test code is as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightOnlyInt8Linear(torch.nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer("weight", torch.randn(out_features, in_features).to(dtype=torch.int8))
        self.register_buffer("scales", torch.randn(out_features, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales

linear = WeightOnlyInt8Linear(4096, 8192)
linear.to(device='cuda', dtype=torch.bfloat16)
linear = torch.compile(linear.eval(), mode='reduce-overhead', fullgraph=True)

input = torch.randn(1, 4096, device='cuda').to(dtype=torch.bfloat16)
with torch.no_grad():
    linear(input)

I expected the generated code to fuse the weight conversion (int8 -> bfloat16) into the GEMV kernel. However, I got the following code instead:

(screenshot of the generated Triton code)

The weight-conversion kernel is not fused. Loading the full bf16 weight after the conversion will hurt GEMV performance badly. Is the generated code reasonable, or have I made a mistake?
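
As a side note, the generated code can be dumped to the console without digging through /tmp, e.g. via Inductor's output-code logging (a minimal sketch; the exact output depends on the PyTorch version):

import torch._logging

# Print the Triton/C++ wrapper code Inductor generates for each compiled graph.
torch._logging.set_logs(output_code=True)

# Equivalently, set the environment variable before launching the script:
#   TORCH_LOGS="output_code" python <your_script>.py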

conway-abacus commented 7 months ago

I was having trouble reproducing the int8 speedup. I didn't look into the generated code to verify, but it turns out I needed the following:

import torch._inductor.config
torch._inductor.config.coordinate_descent_tuning = True

Can you try that?

Chillee commented 7 months ago

Yes, you need to set coordinate_descent_tuning to True.

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch._inductor.config
torch._inductor.config.coordinate_descent_tuning = True

D = 8192

def bench(f, name=None):
    from triton.testing import do_bench

    # do_bench returns milliseconds per iteration; convert to microseconds.
    us_per_iter = do_bench(lambda: f()) * 1000
    # Effective bandwidth, counting only the D*D int8 weight bytes per call.
    print(f"{name}: {(1e6/us_per_iter) * D * D / 1e9} GB/s")

    return 0

class WeightOnlyInt8Linear(torch.nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer("weight", torch.randn(out_features, in_features).to(dtype=torch.int8))
        self.register_buffer("scales", torch.randn(out_features, dtype=dtype))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales

linear = WeightOnlyInt8Linear(D, D).cuda()  # move the int8 weight and scales to the GPU
compiled_linear = torch.compile(linear.eval(), fullgraph=True)

input = torch.randn(1, D, device='cuda').to(dtype=torch.bfloat16)

with torch.no_grad():
    bench(lambda: linear(input), "eager")
    bench(lambda: compiled_linear(input), "compiled")
eager: 163.15397820521537 GB/s
compiled: 1277.1920906791986 GB/s
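
To put those numbers in perspective: the GB/s figure counts only the D*D int8 weight bytes moved per call, so it converts directly into per-call latency (a rough back-of-the-envelope check using the measurements above):

weight_bytes = 8192 * 8192                           # int8 weight: ~67.1 MB per forward pass
eager_us = weight_bytes / (163.15 * 1e9) * 1e6       # ~411 us per call
compiled_us = weight_bytes / (1277.19 * 1e9) * 1e6   # ~52.5 us per call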

And this is the generated file:

from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()

# kernel path: /tmp/torchinductor_chilli/26/c26sd3tlnwesrvfxag3lrdmeofbgukgfrbuj2lus7rhhd7madjg6.py
# Source Nodes: [linear, mul], Original ATen: [aten.mm, aten.mul]
# linear => convert_element_type_3, mul, sum_1
# mul => mul_1
triton_red_fused_mm_mul_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, reduction
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
from triton.compiler.compiler import AttrsDescriptor

@reduction(
    size_hints=[8192, 8192],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*bf16', 2: '*i8', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(4, 5))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mm_mul_0', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 8192
    rnumel = 8192
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp7 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1), None, eviction_policy='evict_last').to(tl.float32)
        tmp2 = tl.load(in_ptr1 + (r1 + (8192*x0)), None, eviction_policy='evict_first')
        tmp1 = tmp0.to(tl.float32)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp3.to(tl.float32)
        tmp5 = tmp1 * tmp4
        tmp6 = tl.broadcast_to(tmp5, [XBLOCK, RBLOCK])
        tmp8 = _tmp7 + tmp6
        _tmp7 = tmp8
    tmp7 = tl.sum(_tmp7, 1)[:, None]
    tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
    tmp9 = tmp7.to(tl.float32)
    tmp10 = tmp9.to(tl.float32)
    tmp12 = tmp10 * tmp11
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (x0), tmp12, None)
''', device_str='cuda')

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, arg1_1, arg2_1 = args
    args.clear()
    assert_size_stride(arg0_1, (8192, 8192), (8192, 1))
    assert_size_stride(arg1_1, (8192, ), (1, ))
    assert_size_stride(arg2_1, (1, 8192), (8192, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((1, 8192), (8192, 1), torch.float32)
        buf1 = buf0; del buf0  # reuse
        # Source Nodes: [linear, mul], Original ATen: [aten.mm, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused_mm_mul_0.run(buf1, arg2_1, arg0_1, arg1_1, 8192, 8192, grid=grid(8192), stream=stream0)
        del arg0_1
        del arg1_1
        del arg2_1
    return (buf1, )

def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((8192, 8192), (8192, 1), device='cuda:0', dtype=torch.int8)
    arg1_1 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg2_1 = rand_strided((1, 8192), (8192, 1), device='cuda:0', dtype=torch.bfloat16)
    fn = lambda: call([arg0_1, arg1_1, arg2_1])
    return print_performance(fn, times=times, repeat=repeat)

if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)

Chillee commented 7 months ago

I think we can actually relax coordinate_descent_tuning, although we still need the BS=1 (batch size 1) restriction.
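
For reference, the math that the BS=1 path boils down to (and which the fused kernel above implements) is a broadcasted multiply plus a reduction; a minimal eager-mode sketch, with the function name purely illustrative:

def decomposed_int8_gemv(x, weight_int8, scales):
    # x: (1, K) bf16, weight_int8: (N, K) int8, scales: (N,)
    # Under torch.compile, the int8 -> bf16 cast is fused into the reduction
    # loop (as in the generated kernel above), so the full bf16 weight is
    # never materialized; in eager mode this line would materialize it.
    # Returns shape (N,); F.linear would return (1, N).
    return (weight_int8.to(x.dtype) * x).sum(dim=-1) * scales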

feiyuvl commented 7 months ago

@Chillee @conway-abacus Thank you, coordinate_descent_tuning=True generates the expected code.

Chillee commented 7 months ago

https://github.com/pytorch/pytorch/pull/120954

This PR always turns on the decomposition.

conway-abacus commented 1 month ago

Hey @Chillee, any tips for generating a fused kernel for BS > 1? Is it related at all to https://github.com/pytorch/pytorch/issues/127056?