triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License
12.76k stars 1.54k forks source link

triton reduce_sum is much slower than torch native kernel #1172

Open phlrain opened 1 year ago

phlrain commented 1 year ago

I am testing some cases in torch inductor, for example, a reduce_sum operation with

input shape = [256, 14, 14, 256], reduce_axis = [0,1,2]

The code below is generated by torch inductor.

The torch native kernel is about 2x faster than the Triton kernel.

I tried to analyze the Triton code: in each program, the loaded data is not contiguous. Is there any way to speed up the Triton code?

from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile

# Preamble of the inductor-generated wrapper: short aliases for helpers.
# `aten` and `assert_size_stride` are not referenced elsewhere in this snippet.
aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
# Compiler for the kernel source strings passed to async_compile.triton(...) below.
# NOTE(review): presumably compiles in the background, given the later
# async_compile.wait(globals()) call — confirm.
async_compile = AsyncCompile()

import triton
import triton.language as tl
from torch._inductor.triton_ops.autotune import grid
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream

# Triton reduction kernel generated by inductor for sum over axes [0, 1, 2] of a
# (256, 14, 14, 256) float32 tensor with strides (50176, 3584, 256, 1).
#
# Indexing inside the kernel: output element x0 (0..255) accumulates
# in_ptr0[x0 + 256*r1] for r1 in [0, 50176). Along the reduced index r, loads are
# 256 elements apart, while neighbouring x0 values are adjacent in memory. Each
# program loads an [XBLOCK, RBLOCK] tile whose rows are strided. This is the
# non-contiguous access pattern discussed in this issue.
# NOTE(review): reduction_hint=ReductionHint.INNER is passed even though the
# reduced index is the strided (outer) one. Confirm whether an OUTER hint was intended.
# The kernel body rebinds xnumel/rnumel to the constants 256/50176, shadowing
# the values passed at launch time.
# The source string below is runtime data and must not be edited casually.
triton_fused_sum_1_0 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import reduction
from torch._inductor.utils import instance_descriptor

@reduction(size_hints=[256, 65536],
              reduction_hint=ReductionHint.INNER,
              filename=__file__,
              meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': set(), 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]})
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 256
    rnumel = 50176
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    rbase = tl.reshape(tl.arange(0, RBLOCK), [1, RBLOCK])
    x0 = xindex
    _tmp1 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (256*r1)), xmask & rmask, eviction_policy='evict_last')
        _tmp1 = tl.where(xmask & rmask, _tmp1 + tmp0, _tmp1)
    tmp1 = tl.reshape(tl.sum(_tmp1, 1), [XBLOCK, 1])
    tl.store(out_ptr0 + x0, tmp1, xmask)
''')

# Wait for the kernel compilation submitted above, then drop the compiler object.
# NOTE(review): passing globals() presumably rebinds triton_fused_sum_1_0 to the
# compiled kernel, which call() uses via .run(...) — confirm.
async_compile.wait(globals())
del async_compile

def call(args):
    """Run the generated sum-reduction kernel on ``args[0]``.

    ``args`` must hold exactly one tensor and is emptied by this call.
    Returns a 1-tuple containing a newly allocated float32 CUDA buffer of
    shape (1, 1, 1, 256) that ``triton_fused_sum_1_0`` writes into.
    """
    (arg0_1,) = args
    # Empty the caller's list (presumably so the input can be released once
    # this function drops its own reference below).
    args.clear()
    out_shape = (1, 1, 1, 256)
    out_strides = (256, 256, 256, 1)
    buf0 = empty_strided(out_shape, out_strides, device='cuda', dtype=torch.float32)
    stream0 = get_cuda_stream(0)
    # Number of outputs, and number of elements reduced into each output.
    xnumel, rnumel = 256, 50176
    triton_fused_sum_1_0.run(
        arg0_1, buf0, xnumel, rnumel, grid=grid(xnumel), stream=stream0
    )
    del arg0_1
    return (buf0,)

if __name__ == "__main__":
    # Benchmark the generated kernel against eager PyTorch on the same input.
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    shape, strides = (256, 14, 14, 256), (50176, 3584, 256, 1)
    arg0_1 = rand_strided(shape, strides, device='cuda', dtype=torch.float32)
    # Inductor-generated Triton kernel (reported timing: 0.002817)
    print_performance(lambda: call([arg0_1]))
    # Eager torch.sum over the same axes (reported timing: 0.001500)
    print_performance(lambda: torch.sum(arg0_1, axis=[0, 1, 2], keepdim=True))
Jokeren commented 1 year ago

I do think we can optimize triton_ internally, but I'm confused about the testing result.

Are they the same shape?

    # Compare output shapes: call() returns the kernel's (1, 1, 1, 256) buffer,
    # while axis=0 reduces only the first dimension of the (256, 14, 14, 256) input.
    print(call([arg0_1])[0].shape)
    print(torch.sum(arg0_1, axis = 0).shape)
Jokeren commented 1 year ago

The slowness might be caused by the fact that the reduction isn't performed along the fastest-changing dimension.

Before deriving a conclusion, I think we should have a valid test first.

phlrain commented 1 year ago

I do think we can optimize triton_ internally, but I'm confused about the testing result.

Are they the same shape?

    # Compare output shapes: call() returns the kernel's (1, 1, 1, 256) buffer,
    # while axis=0 reduces only the first dimension of the (256, 14, 14, 256) input.
    print(call([arg0_1])[0].shape)
    print(torch.sum(arg0_1, axis = 0).shape)

My mistake. The right code is

arg0_1.sum(dim=(0, 1, 2), keepdim=True)  # same axes the generated kernel reduces

It's about 2x faster than the Triton kernel.

Jokeren commented 1 year ago

Thanks, will take a look

Jokeren commented 1 year ago

@phlrain May I know how triton_fused_sum_1_0 is generated? I assume you compiled a pytorch model and extract the code here, right? I asked because I want to have an end-to-end test.

phlrain commented 1 year ago

@Jokeren The code is here:

import torch
import numpy as np

def test_f( x ):
    return torch.sum( x, axis = [0, 1, 2], keepdim=False)

x = torch.randn(256, 14, 14, 256).cuda()

compile_f = torch.compile(test_f)
out = compile_f(x)  # the first call triggers compilation

But after updating PyTorch from 1.14.0.dev20221204+cu117 to 2.0.0.dev20230201+cu117, Triton is faster than the torch native kernel.

triton    0.000960
torch    0.001460

I compared the code generated by torch inductor. The new version uses two kernels to speed up the reduction.

It shows Triton has powerful capabilities.

Actually, I am testing batch norm in NHWC memory format. I decomposed batch norm myself (not including the running mean and variance updates).

import torch
import numpy as np
import time

def composite_batchnorm(
    x,
    run_mean,
    variance,
    scale,
    bias,
    is_test=False,
    momentum=0.9,
    epsilon=1e-5,
    data_layout="NCHW",
    use_global_stats=None,
    trainable_statistics=None,
):
    """Batch-normalize a channels-last 4-D tensor using batch statistics.

    Hand-written decomposition used to benchmark inductor. Statistics are
    reduced over dims 0, 1 and 2, so each index of the last dimension of *x*
    is normalized independently (e.g. an ``(N, H, W, C)`` tensor). Running
    statistics are neither read nor updated.

    Args:
        x: 4-D input tensor.
        run_mean, variance: accepted for signature compatibility; unused.
        scale, bias: affine parameters, broadcast against ``x``.
        is_test, momentum, data_layout, use_global_stats, trainable_statistics:
            accepted for signature compatibility; unused. (``data_layout``
            defaults to "NCHW" but the reduction always treats the last dim
            as the channel dim.)
        epsilon: added to the variance before taking the square root.

    Returns:
        ``(x - mean) / sqrt(var + epsilon) * scale + bias``, same shape as ``x``.

    Raises:
        ValueError: if ``x`` is not 4-D.
    """
    if x.dim() != 4:
        raise ValueError(f"composite_batchnorm expects a 4-D tensor, got {x.dim()}-D")

    reduce_dims = [0, 1, 2]
    mean = torch.mean(x, reduce_dims, keepdim=True)
    mean_of_squares = torch.mean(x * x, reduce_dims, keepdim=True)

    # Biased variance via the single-pass form E[x^2] - E[x]^2. In float32 this
    # can cancel catastrophically and come out negative (for large-magnitude
    # inputs, by more than epsilon), which would make sqrt() produce NaN.
    # Clamping at zero keeps the result well-defined.
    var = torch.clamp(mean_of_squares - mean * mean, min=0.0)
    std = torch.sqrt(var + epsilon)

    y = (x - mean) / std * scale + bias
    return y

# NHWC benchmark: compiled decomposition vs. eager batch norm on the same data.
# A single channels_last NCHW tensor is created. Permuting it to (N, H, W, C)
# gives a contiguous view of the same storage, so both paths read identical memory.
x_channels_last = torch.randn([256, 256, 14, 14]).cuda().to(memory_format=torch.channels_last)
x = x_channels_last.permute(0, 2, 3, 1)

running_mean = torch.ones([256]).cuda()
running_variance = torch.ones([256]).cuda()
weight = torch.ones([256]).cuda()
bias = torch.ones([256]).cuda()

compile_batchnorm = torch.compile(composite_batchnorm)

# Warm-up call so compilation happens before timing.
out = compile_batchnorm(x, running_mean, running_variance, weight, bias)

from torch._inductor.utils import print_performance

print_performance(lambda: compile_batchnorm(x, running_mean, running_variance, weight, bias))
# F.batch_norm defaults to training=False (inference mode). In that mode it
# normalizes with the running buffers and never computes batch statistics.
# training=True makes it compute mean/variance like the decomposition does.
# Passing None for the running buffers skips their in-place update, which the
# decomposition does not perform either.
print_performance(
    lambda: torch.nn.functional.batch_norm(
        x_channels_last, None, None, weight, bias, training=True, eps=1e-5
    )
)

The result is:

triton  0.015794
torch  0.002325

It seems the calculation of the mean and variance is slower than the torch native kernel named "batch_norm_collect_statistics_channels_last_kernel". Is there any way to speed this up?

Jokeren commented 1 year ago

Thanks for the code. It will be useful

lezcano commented 1 year ago

Note that a comparable test in master using the decompositions in core yields the same performance as the eager function:

import torch
import numpy as np
import time

def composite_batchnorm(input, running_mean, running_variance, weight, bias, *, training=False):
    """Thin wrapper around ``torch.nn.functional.batch_norm`` for torch.compile benchmarks.

    ``training`` is forwarded unchanged and defaults to False, the same default
    as ``F.batch_norm``. In that mode the supplied running statistics are used
    and no batch statistics are computed. Pass ``training=True`` to normalize
    with batch statistics. If running buffers are given in that mode,
    ``F.batch_norm`` also updates them in place.
    """
    return torch.nn.functional.batch_norm(
        input, running_mean, running_variance, weight, bias, training=training
    )

# Benchmark: torch.compile'd wrapper vs. eager F.batch_norm on one NHWC tensor.
input = (
    torch.randn([256, 256, 14, 14])
    .cuda()
    .to(memory_format=torch.channels_last)
)  # channels_last (NHWC) is the layout under test

x = torch.randn([256, 14, 14, 256]).cuda()  # NOTE(review): not used below
running_mean, running_variance, weight, bias = (
    torch.ones([256]).cuda() for _ in range(4)
)

compile_batchnorm = torch.compile(composite_batchnorm)

from torch._inductor.utils import print_performance

# NOTE(review): F.batch_norm defaults to training=False, so both timings below
# normalize with the supplied running statistics (inference mode) and compute
# no batch statistics. Pass training=True to benchmark statistics collection.
print_performance(lambda: compile_batchnorm(input, running_mean, running_variance, weight, bias))
print_performance(lambda: torch.nn.functional.batch_norm(input, running_mean, running_variance, weight, bias))

returns

0.003762
0.003755
lezcano commented 1 year ago

Also, I think the issue with the repro in https://github.com/openai/triton/issues/1172#issuecomment-1429335406 was that one implementation was being tested with a contiguous input `x` and the other one with a channels_last input `input`.