Open feiyuvl opened 7 months ago
I was having trouble reproducing the int8 speedup. I didn't look into the generated code to verify, but it turns out I needed the following:
import torch._inductor.config
torch._inductor.config.coordinate_descent_tuning = True
Can you try that?
Yes, you need to set coordinate_descent_tuning to True.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch._inductor.config
# Turn on Inductor's coordinate-descent tuning; the discussion in this thread
# reports that the fused int8 kernel is only generated with this flag enabled.
torch._inductor.config.coordinate_descent_tuning = True
# Side length of the square weight matrix; bench() also uses it to compute
# the number of bytes counted per call.
D = 8192
def bench(f, name=None):
    """Time ``f`` with Triton's ``do_bench`` and print an effective bandwidth.

    The value returned by ``do_bench`` is scaled by 1000 and treated as
    microseconds per call. The printed figure counts ``D * D`` bytes per
    call (``D`` is the module-level matrix size), i.e. one byte per weight
    element, and reports it in GB/s.

    Args:
        f: Zero-argument callable to benchmark.
        name: Label printed in front of the measurement.

    Returns:
        Always ``0``; the measurement is only printed.
    """
    # Imported lazily so the module can be imported without Triton installed.
    from triton.testing import do_bench

    us_per_iter = do_bench(f) * 1000
    print(f"{name}: {(1e6/us_per_iter) * D * D / 1e9} GB/s")
    return 0
class WeightOnlyInt8Linear(torch.nn.Module):
    """Linear layer that stores its weight as int8 plus one scale per output.

    The weight is kept as an ``int8`` buffer of shape
    ``(out_features, in_features)`` and is cast to the activation dtype at
    call time; the matmul result is then multiplied by the ``scales`` buffer
    of shape ``(out_features,)``. Both buffers are filled with random values,
    so the module is meant for benchmarking rather than for real inference.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Accepted for signature compatibility with ``nn.Linear``; it is
            ignored and no bias is ever created or applied.
        device: Device on which the buffers are allocated (``None`` uses the
            current default device).
        dtype: Dtype of the ``scales`` buffer (``None`` uses the default
            floating-point dtype). The weight is always ``torch.int8``.
    """

    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Honour ``device`` so the buffers are created where the caller asked
        # for them instead of always landing on the default device.
        self.register_buffer(
            "weight",
            torch.randn(out_features, in_features, device=device).to(dtype=torch.int8),
        )
        self.register_buffer(
            "scales",
            torch.randn(out_features, device=device, dtype=dtype),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``linear(input, weight cast to input.dtype) * scales``."""
        # The int8 -> activation-dtype cast happens on every call; only the
        # int8 copy of the weight is stored.
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
# Move the module to the GPU: its buffers must live on the same device as the
# CUDA input below, otherwise the eager call fails with a device mismatch.
linear = WeightOnlyInt8Linear(D, D).to('cuda')
compiled_linear = torch.compile(linear.eval(), fullgraph=True)
# Batch size 1, bfloat16 activations.
input = torch.randn(1, D, device='cuda').to(dtype=torch.bfloat16)
with torch.no_grad():
    bench(lambda: linear(input), "eager")
    bench(lambda: compiled_linear(input), "compiled")
eager: 163.15397820521537 GB/s
compiled: 1277.1920906791986 GB/s
And this is the generated file:
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
# Short module-level aliases for the operator namespaces and helper functions
# that the rest of this generated wrapper refers to.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
# Handle used below to register the Triton kernel source and later wait on it.
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_chilli/26/c26sd3tlnwesrvfxag3lrdmeofbgukgfrbuj2lus7rhhd7madjg6.py
# Source Nodes: [linear, mul], Original ATen: [aten.mm, aten.mul]
# linear => convert_element_type_3, mul, sum_1
# mul => mul_1
# Fused reduction kernel for the whole forward pass. Per the positional
# signature in triton_meta, in_ptr0 is bf16, in_ptr1 is int8, and in_ptr2 and
# in_out_ptr0 are fp32. Inside the reduction loop the int8 values loaded from
# in_ptr1 are converted to float32, multiplied by the values loaded from
# in_ptr0 and accumulated; after the loop the row sums are multiplied by the
# values loaded from in_ptr2 and stored to in_out_ptr0. The int8 -> float
# conversion therefore happens inside this kernel, not in a separate one.
# NOTE(review): the kernel source below is a string literal and is shown
# exactly as captured; its indentation was lost when the listing was pasted.
triton_red_fused_mm_mul_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, reduction
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
from triton.compiler.compiler import AttrsDescriptor
@reduction(
size_hints=[8192, 8192],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*bf16', 2: '*i8', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(4, 5))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mm_mul_0', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 8192
rnumel = 8192
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
_tmp7 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp0 = tl.load(in_ptr0 + (r1), None, eviction_policy='evict_last').to(tl.float32)
tmp2 = tl.load(in_ptr1 + (r1 + (8192*x0)), None, eviction_policy='evict_first')
tmp1 = tmp0.to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp3.to(tl.float32)
tmp5 = tmp1 * tmp4
tmp6 = tl.broadcast_to(tmp5, [XBLOCK, RBLOCK])
tmp8 = _tmp7 + tmp6
_tmp7 = tmp8
tmp7 = tl.sum(_tmp7, 1)[:, None]
tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
tmp9 = tmp7.to(tl.float32)
tmp10 = tmp9.to(tl.float32)
tmp12 = tmp10 * tmp11
tl.debug_barrier()
tl.store(in_out_ptr0 + (x0), tmp12, None)
''', device_str='cuda')
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# Wait on the kernel registered above, handing it this module's globals, then
# drop the compile handle since nothing below refers to it.
async_compile.wait(globals())
del async_compile
def call(args):
    """Run the fused int8 matmul-and-scale kernel on the given inputs.

    Args:
        args: List of three tensors, in order: a contiguous ``(8192, 8192)``
            tensor, a ``(8192,)`` tensor and a contiguous ``(1, 8192)``
            tensor. The list is emptied by this call.

    Returns:
        One-element tuple holding the ``(1, 8192)`` float32 output buffer.
    """
    weight_i8, scales, activations = args
    # Empty the caller's list so it no longer keeps the inputs alive.
    args.clear()
    assert_size_stride(weight_i8, (8192, 8192), (8192, 1))
    assert_size_stride(scales, (8192, ), (1, ))
    assert_size_stride(activations, (1, 8192), (8192, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        out = empty_strided_cuda((1, 8192), (8192, 1), torch.float32)
        # Source Nodes: [linear, mul], Original ATen: [aten.mm, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused_mm_mul_0.run(
            out, activations, weight_i8, scales, 8192, 8192,
            grid=grid(8192), stream=stream0,
        )
        # Release the local references to the inputs once the kernel is launched.
        del weight_i8, scales, activations
    return (out, )
def benchmark_compiled_module(times=10, repeat=10):
    """Benchmark ``call`` on randomly generated CUDA inputs.

    Args:
        times: Forwarded to ``print_performance`` as ``times``.
        repeat: Forwarded to ``print_performance`` as ``repeat``.

    Returns:
        Whatever ``print_performance`` returns.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    # Inputs have the sizes, strides and dtypes that ``call`` is invoked with.
    weight_i8 = rand_strided((8192, 8192), (8192, 1), device='cuda:0', dtype=torch.int8)
    scales = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
    activations = rand_strided((1, 8192), (8192, 1), device='cuda:0', dtype=torch.bfloat16)

    def run_once():
        # ``call`` clears the list it receives, so build a fresh one per run.
        return call([weight_i8, scales, activations])

    return print_performance(run_once, times=times, repeat=repeat)
if __name__ == "__main__":
    # Only needed when this generated file is run directly as a script.
    from torch._inductor.wrapper_benchmark import compiled_module_main

    compiled_module_main('None', benchmark_compiled_module)
I think we can actually relax the coordinate_descent_tuning requirement, although we still need the BS=1 restriction.
@Chillee @conway-abacus Thank you, coordinate_descent_tuning=True generates the expected code.
https://github.com/pytorch/pytorch/pull/120954
This PR always turns on the decomposition.
Hey @Chillee, any tips for generating a fused kernel for BS > 1? Is it related at all to https://github.com/pytorch/pytorch/issues/127056?
I wrote a simple test to get the Triton code of WeightOnlyInt8Linear. The test code is as follows: I expect the generated code to fuse the weight conversion (int8 -> bfloat16) into the gemv function. However, I get the following code:
The weight conversion kernel is not fused. Loading the full bf16 weight after conversion will hurt the gemv performance badly. Is the generated code reasonable, or have I made a mistake?