Open phlrain opened 1 year ago
I do think we can optimize triton_ internally, but I'm confused about the testing result.
Are they the same shape?
print(call([arg0_1])[0].shape)
print(torch.sum(arg0_1, axis = 0).shape)
The slowness might be caused by the fact that the reduction isn't performed over the fastest-changing dimension.
Before drawing a conclusion, I think we should have a valid test first.
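To make the point about the fastest-changing dimension concrete, here is a minimal sketch in plain Python, assuming a contiguous (row-major) tensor of shape [256, 14, 14, 256] that is summed over axes 0, 1 and 2. The helper names row_major_strides and reduction_offsets are hypothetical and only illustrate the access pattern; they are not part of Triton or PyTorch.

```python
# Sketch: flat memory offsets read when summing a contiguous
# [256, 14, 14, 256] tensor over axes 0, 1, 2.
# Helper names are hypothetical, for illustration only.

def row_major_strides(shape):
    """Element strides of a contiguous (row-major) tensor of this shape."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def reduction_offsets(shape, channel, count):
    """First `count` flat offsets read to produce one output channel."""
    s0, s1, s2, s3 = row_major_strides(shape)
    offsets = []
    for n in range(shape[0]):
        for h in range(shape[1]):
            for w in range(shape[2]):
                offsets.append(n * s0 + h * s1 + w * s2 + channel * s3)
                if len(offsets) == count:
                    return offsets
    return offsets


shape = (256, 14, 14, 256)
print(row_major_strides(shape))                       # (50176, 3584, 256, 1)
print(reduction_offsets(shape, channel=0, count=4))   # [0, 256, 512, 768]
```

Consecutive elements that feed one output value sit 256 elements apart, while neighbouring outputs sit next to each other in memory, which is why the way work is split across programs matters for this reduction.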
My mistake. The right code is
torch.sum( arg0_1, axis = [0, 1, 2], keepdim=True)
With that, the torch native kernel is about 2x faster than the Triton kernel.
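The shape question above comes down to which axes are reduced. As a minimal sketch in plain Python (the helper reduced_shape is hypothetical and only mirrors the usual shape rule for sum-style reductions), summing over axis 0 and summing over axes [0, 1, 2] with keepdim=True give different result shapes:

```python
# Sketch: output shape of a sum-style reduction.
# `reduced_shape` is a hypothetical helper, not a PyTorch function.

def reduced_shape(shape, axes, keepdim=False):
    """Shape left after reducing `axes`; reduced dims become 1 if keepdim."""
    axes = {a % len(shape) for a in axes}  # allow negative axes
    out = []
    for i, dim in enumerate(shape):
        if i in axes:
            if keepdim:
                out.append(1)
        else:
            out.append(dim)
    return tuple(out)


shape = (256, 14, 14, 256)
print(reduced_shape(shape, [0]))                      # (14, 14, 256)
print(reduced_shape(shape, [0, 1, 2], keepdim=True))  # (1, 1, 1, 256)
print(reduced_shape(shape, [0, 1, 2]))                # (256,)
```

Printing both shapes before timing, as suggested earlier in the thread, is a cheap way to confirm that two benchmarks compute the same reduction.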
Thanks, will take a look
@phlrain May I know how triton_fused_sum_1_0 was generated? I assume you compiled a PyTorch model and extracted the code here, right? I'm asking because I want to have an end-to-end test.
@Jokeren The code is here:
import torch
import numpy as np
def test_f( x ):
    return torch.sum( x, axis = [0, 1, 2], keepdim=False)
x = torch.randn( [256, 14, 14, 256]).cuda()
compile_f = torch.compile( test_f )
out = compile_f( x )
But after updating PyTorch from 1.14.0.dev20221204+cu117 to 2.0.0.dev20230201+cu117, Triton is faster than the torch native kernel:
triton 0.000960
torch 0.001460
I compared the code generated by TorchInductor. The new version uses two kernels to speed up the reduction.
This shows that Triton is quite capable.
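The generated kernels are not shown here, so the following is only a sketch of the general idea behind a two-stage (split) reduction, written in plain Python under the assumption that the first stage produces partial sums over chunks of rows and the second stage combines them. The function names are hypothetical and this is not TorchInductor's actual output.

```python
# Sketch of a two-stage column reduction (an assumed general scheme,
# not the code TorchInductor generates). Assumes a non-empty input.

def column_sums(rows):
    """Sum a list of equal-length rows column by column."""
    return [sum(col) for col in zip(*rows)]


def two_stage_column_sums(rows, num_chunks):
    """Stage 1: partial sums per chunk of rows. Stage 2: sum the partials."""
    chunk = -(-len(rows) // num_chunks)  # ceiling division
    partials = [
        column_sums(rows[start:start + chunk])
        for start in range(0, len(rows), chunk)
    ]
    return column_sums(partials)


# 12 rows of 4 "channels"; integers keep the comparison exact.
rows = [[r * 4 + c for c in range(4)] for r in range(12)]
print(column_sums(rows))               # [264, 276, 288, 300]
print(two_stage_column_sums(rows, 3))  # [264, 276, 288, 300]
```

On a GPU, each chunk in the first stage can be handled by an independent program, which exposes more parallelism than having one program walk the whole reduction dimension.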
Actually, I am testing batch norm in the NHWC memory format. I decomposed batch norm myself (not including the running mean and variance update):
import torch
import numpy as np
import time
def composite_batchnorm(
    x,
    run_mean,
    variance,
    scale,
    bias,
    is_test = False,
    momentum = 0.9,
    epsilon=1e-5,
    data_layout="NCHW",
    use_global_stats=None,
    trainable_statistics=None,
):
    n, c, h, w = x.shape
    mean = torch.mean( x, [0,1, 2], keepdim=True )
    pow_mean = torch.mean( x*x, [0, 1, 2], keepdim=True )
    var2 = x - mean
    var1 = pow_mean - mean*mean
    var1 = torch.sqrt( var1 + epsilon)
    t1 = var2 / var1
    y = t1 * scale + bias
    return y
input = torch.randn( [256, 256, 14, 14]).cuda().to(memory_format=torch.channels_last) # we need to test NHWC
x = torch.randn( [256, 14, 14, 256]).cuda()
running_mean = torch.ones( [256] ).cuda()
running_variance = torch.ones( [256] ).cuda()
weight = torch.ones( [256] ).cuda()
bias = torch.ones( [256] ).cuda()
compile_batchnorm = torch.compile( composite_batchnorm )
out = compile_batchnorm( x,running_mean, running_variance, weight, bias )
from torch._inductor.utils import print_performance
print_performance( lambda : compile_batchnorm( x,running_mean, running_variance, weight, bias ) )
print_performance( lambda : torch.nn.functional.batch_norm( input, running_mean, running_variance, weight, bias) )
The result is:
triton 0.015794
torch 0.002325
It seems the calculation of the mean and variance is slower than the torch native kernel named "batch_norm_collect_statistics_channels_last_kernel". Is there any way to speed this up?
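For reference, the statistics step in the decomposition above uses the identity Var[x] = E[x^2] - E[x]^2 per channel. The following is a small sketch of that step in plain Python on a tiny list-of-rows input, where each row holds one value per channel. The helper names channel_stats and normalize are hypothetical, and the code makes no claim about how the native kernel computes the same quantities.

```python
import math
import statistics

# Sketch: per-channel batch statistics via E[x^2] - E[x]^2.
# Helper names are hypothetical; each row holds one value per channel.

def channel_stats(rows):
    """Per-channel mean and biased variance over all rows."""
    count = len(rows)
    columns = list(zip(*rows))
    means = [sum(col) / count for col in columns]
    sq_means = [sum(v * v for v in col) / count for col in columns]
    variances = [sq - m * m for sq, m in zip(sq_means, means)]
    return means, variances


def normalize(rows, epsilon=1e-5):
    """Subtract the channel mean and divide by sqrt(var + epsilon)."""
    means, variances = channel_stats(rows)
    scales = [math.sqrt(var + epsilon) for var in variances]
    return [
        [(v - m) / s for v, m, s in zip(row, means, scales)]
        for row in rows
    ]


rows = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]
means, variances = channel_stats(rows)
print(means)      # [2.5, 25.0]
print(variances)  # [1.25, 125.0]
# Cross-check against the standard library's population variance.
print(statistics.pvariance([row[0] for row in rows]))  # 1.25
```

Both means reduce over the same axes, so the two reductions can share a single pass over the data.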
Thanks for the code. It will be useful
Note that a comparable test in master using the decompositions in core yields the same performance as the eager function:
import torch
import numpy as np
import time
def composite_batchnorm(input, running_mean, running_variance, weight, bias):
    return torch.nn.functional.batch_norm( input, running_mean, running_variance, weight, bias)
input = torch.randn( [256, 256, 14, 14]).cuda().to(memory_format=torch.channels_last) # we need to test NHWC
x = torch.randn( [256, 14, 14, 256]).cuda()
running_mean = torch.ones( [256] ).cuda()
running_variance = torch.ones( [256] ).cuda()
weight = torch.ones( [256] ).cuda()
bias = torch.ones( [256] ).cuda()
compile_batchnorm = torch.compile( composite_batchnorm )
from torch._inductor.utils import print_performance
print_performance( lambda : compile_batchnorm( input, running_mean, running_variance, weight, bias) )
print_performance( lambda : torch.nn.functional.batch_norm( input, running_mean, running_variance, weight, bias) )
returns
0.003762
0.003755
Also, I think the issue with the repro in https://github.com/openai/triton/issues/1172#issuecomment-1429335406 was that one implementation was being tested with a contiguous input, input, and the other one was being tested with a channels_last input, x.
I am testing some cases in TorchInductor, for example, a reduce_sum operation with
input shape = [256, 14, 14, 256], reduce_axis = [0,1,2]
The code below was generated by TorchInductor.
The torch native kernel is about 2x faster than the Triton kernel.
I am trying to analyze the Triton code. In each program, the loaded data is not contiguous. Is there any way to speed up the Triton code?