triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

triton generates unnecessary shared memory stores/loads #3491

Open isuruf opened 3 months ago

isuruf commented 3 months ago

For the following Triton kernels generated by PyTorch (Inductor), Triton emits shared memory stores and loads in the LLVM IR and PTX just before the atomic add operations.

```python
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()


# kernel path: /tmp/torchinductor_isuruf/5e/c5ehw64oxeoeqqjnqn6v3gfy6z5ukksktwihp7jgzg6sujz5umto.py
# Source Nodes: [], Original ATen: []
triton_poi_fused_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[16777216],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '173ccbefad6764ffc6a32cfd80b0e0decca95dcaaab807475db0bd6fd7f94813'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 8750000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = 0.0
    tl.store(out_ptr0 + (x0), tmp0, xmask)
''', device_str='cuda')

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream


# kernel path: /tmp/torchinductor_isuruf/qg/cqgmsmgdzivumf2gmksclwbmyrwpfpouuv3s5suqkeg4j4cdmpjr.py
# Source Nodes: [], Original ATen: []
triton_poi_fused_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[67108864],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_1', 'mutated_arg_names': ['out_ptr0'], 'no_x_dim': False, 'backend_hash': '173ccbefad6764ffc6a32cfd80b0e0decca95dcaaab807475db0bd6fd7f94813'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 35000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 1000) % 1000
    x0 = xindex % 1000
    x3 = xindex
    x2 = (xindex // 1000000)
    tmp22 = tl.load(in_ptr0 + (x3), xmask)
    tmp0 = x1
    tmp1 = tmp0.to(tl.float32)
    tmp2 = 0.5
    tmp3 = tmp1 + tmp2
    tmp4 = tmp3 * tmp2
    tmp5 = tmp4 - tmp2
    tmp6 = tmp5.to(tl.int32)
    tmp7 = x0
    tmp8 = tmp7.to(tl.float32)
    tmp9 = tmp8 + tmp2
    tmp10 = tmp9 * tmp2
    tmp11 = tmp10 - tmp2
    tmp12 = tmp11.to(tl.int32)
    tmp13 = tmp6.to(tl.float32)
    tmp14 = tmp5 - tmp13
    tmp15 = 1.0
    tmp16 = tmp15 - tmp14
    tmp17 = tmp15 * tmp16
    tmp18 = tmp12.to(tl.float32)
    tmp19 = tmp11 - tmp18
    tmp20 = tmp15 - tmp19
    tmp21 = tmp17 * tmp20
    tmp23 = tmp21 * tmp22
    tmp24 = tl.full([1], 1, tl.int32)
    tmp25 = tmp12 + tmp24
    tmp26 = tl.full([1], 499, tl.int32)
    tmp27 = triton_helpers.minimum(tmp25, tmp26)
    tmp28 = tmp17 * tmp19
    tmp29 = tmp28 * tmp22
    tmp30 = tmp6 + tmp24
    tmp31 = triton_helpers.minimum(tmp30, tmp26)
    tmp32 = tmp15 * tmp14
    tmp33 = tmp32 * tmp20
    tmp34 = tmp33 * tmp22
    tmp35 = tmp32 * tmp19
    tmp36 = tmp35 * tmp22
    tl.atomic_add(out_ptr0 + (tmp12 + (500*tmp6) + (250000*x2)), tmp23, xmask)
    tl.atomic_add(out_ptr0 + (tmp27 + (500*tmp6) + (250000*x2)), tmp29, xmask)
    tl.atomic_add(out_ptr0 + (tmp12 + (500*tmp31) + (250000*x2)), tmp34, xmask)
    tl.atomic_add(out_ptr0 + (tmp27 + (500*tmp31) + (250000*x2)), tmp36, xmask)
''', device_str='cuda')


async_compile.wait(globals())
del async_compile

def call(args):
    args_1, = args
    args.clear()
    assert_size_stride(args_1, (7, 5, 1000, 1000), (5000000, 1000000, 1000, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((7, 5, 500, 500), (1250000, 250000, 500, 1), torch.float32)
        # Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_0.run(buf0, 8750000, grid=grid(8750000), stream=stream0)
        # Source Nodes: [], Original ATen: []
        triton_poi_fused_1.run(args_1, buf0, 35000000, grid=grid(35000000), stream=stream0)
        del args_1
    return (buf0, )


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    args_1 = rand_strided((7, 5, 1000, 1000), (5000000, 1000000, 1000, 1), device='cuda:0', dtype=torch.float32)
    fn = lambda: call([args_1])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
```
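To make the scatter pattern of `triton_poi_fused_1` easier to follow, here is a NumPy sketch that restates it, mirroring the kernel's `tmp` variables. It is a reading aid, not code from Inductor: the name `fused_1_reference` is hypothetical, and generalizing the hard-coded 1000/500/499 constants to any even `H` and `W` (clamping with `size_out - 1`) is an assumption. It shows why the kernel needs atomics at all: several input elements land on the same output cell (for example, input rows 0, 1 and 2 all add into output row 0).

```python
import numpy as np

def fused_1_reference(grad: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy restatement of triton_poi_fused_1: every input element
    is scatter-added into four output cells with bilinear-style weights."""
    grad = np.asarray(grad, dtype=np.float32)
    n, c, h_in, w_in = grad.shape
    h_out, w_out = h_in // 2, w_in // 2
    out = np.zeros((n, c, h_out, w_out), dtype=np.float32)  # zero-fill, like triton_poi_fused_0

    def coords(size_in: int, size_out: int):
        src = (np.arange(size_in, dtype=np.float32) + 0.5) * 0.5 - 0.5  # tmp5 / tmp11
        lo = src.astype(np.int32)                  # tmp6 / tmp12 (truncates toward zero)
        frac = src - lo.astype(np.float32)         # tmp14 / tmp19
        hi = np.minimum(lo + 1, size_out - 1)      # tmp31 / tmp27 (clamp, 499 in the kernel)
        return lo, hi, frac

    r0, r1, fy = coords(h_in, h_out)
    c0, c1, fx = coords(w_in, w_out)
    # Broadcast the per-row and per-column values over the (h_in, w_in) plane.
    R0, C0 = np.meshgrid(r0, c0, indexing="ij")
    R1, C1 = np.meshgrid(r1, c1, indexing="ij")
    FY, FX = np.meshgrid(fy, fx, indexing="ij")

    nc = (slice(None), slice(None))  # keep the batch and channel axes
    # np.add.at is unbuffered, so repeated targets accumulate like tl.atomic_add.
    np.add.at(out, nc + (R0, C0), (1 - FY) * (1 - FX) * grad)  # tmp23
    np.add.at(out, nc + (R0, C1), (1 - FY) * FX * grad)        # tmp29
    np.add.at(out, nc + (R1, C0), FY * (1 - FX) * grad)        # tmp34
    np.add.at(out, nc + (R1, C1), FY * FX * grad)              # tmp36
    return out

ones = np.ones((1, 1, 4, 4), dtype=np.float32)
print(fused_1_reference(ones)[0, 0].tolist())  # [[5.0625, 3.9375], [3.9375, 3.0625]]
```

`np.add.at` is used instead of `out[idx] += vals` because fancy-index `+=` is buffered and keeps only one write per repeated target, which is exactly the collision case the atomics exist to handle.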

Shared memory loads/stores are unnecessary in this case. cc @peterbell10

isuruf commented 3 months ago

Based on a suggestion from @peterbell10, I removed AtomicRMWOp from the layout anchors at https://github.com/openai/triton/blob/0ba87e2ff35f703f84040400554702ee55476cdb/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp#L192, which eliminated the shared memory loads/stores from the PTX. With that change, the Triton-generated kernel matches the performance of the PyTorch eager backend; with the shared stores and loads, it was previously 50% slower.
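What it means for an op to be a layout anchor can be illustrated with a deliberately tiny, hypothetical model. The `Op` class and `remaining_conversions` function are invented for this sketch and are not how `RemoveLayoutConversions.cpp` is implemented: anchor ops keep the layout assigned to them, every other op adopts its operand's layout, and any remaining mismatch needs a `convert_layout`, which in this issue presumably shows up as the shared memory round trip. The two layouts are the ones reported later in this thread: `sizePerThread = [4]` for the load and `sizePerThread = [1]` for the atomic.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set

@dataclass
class Op:
    name: str                     # SSA value produced by the op
    kind: str                     # e.g. "load", "arith", "atomic_rmw"
    layout: Optional[str] = None  # layout an earlier pass assigned, if any
    operands: List[str] = field(default_factory=list)

def remaining_conversions(ops: List[Op], anchor_kinds: Set[str]) -> int:
    """Toy model: anchors keep their assigned layout, other ops adopt the layout
    of their first operand. Counts operand edges whose layouts still disagree,
    i.e. places where a convert_layout would remain."""
    layout_of: Dict[str, Optional[str]] = {}
    conversions = 0
    for op in ops:
        if op.kind in anchor_kinds or not op.operands:
            mine = op.layout
        else:
            mine = layout_of[op.operands[0]]
        conversions += sum(layout_of[v] != mine for v in op.operands)
        layout_of[op.name] = mine
    return conversions

# Simplified shape of triton_poi_fused_1: load -> arithmetic -> atomic add.
kernel = [
    Op("tmp22", "load", layout="blocked<sizePerThread=[4]>"),
    Op("tmp23", "arith", operands=["tmp22"]),
    Op("rmw0", "atomic_rmw", layout="blocked<sizePerThread=[1]>", operands=["tmp23"]),
]
print(remaining_conversions(kernel, {"load", "atomic_rmw"}))  # 1
print(remaining_conversions(kernel, {"load"}))                # 0
```

In the toy, dropping `atomic_rmw` from the anchor set lets the atomic inherit the load's layout, so the conversion disappears; whether that is always profitable is the performance question raised in the replies.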

isuruf commented 3 months ago

Is there a case where removing AtomicRMWOp as a layout anchor can result in incorrect code?

manman-ren commented 3 months ago

I don't think it will result in incorrect code, but I may be wrong. It can affect performance, though, so we will likely need to run the benchmark suites to verify the performance impact. Which version of PyTorch are you on? I tried to run your code, but it failed with `AttributeError: type object 'torch._C.Generator' has no attribute 'graphsafe_set_state'`.

isuruf commented 3 months ago

I'm using pytorch v2.3.0-rc6

peterbell10 commented 3 months ago

> Which version of pytorch are you on? I tried to run your code, but failed. AttributeError: type object 'torch._C.Generator' has no attribute 'graphsafe_set_state'

Given that graphsafe_set_state doesn't appear in the generated code, you probably just need to rebuild pytorch.

manman-ren commented 3 months ago

You are right. I thought I built it after the source pull.

manman-ren commented 3 months ago

I looked at this, but I'm not sure what the best solution is :] Instead, I noticed a few things that I will try to understand:

1. It is not clear to me why the atomic op has a different layout, `sizePerThread = [1]` (versus `sizePerThread = [4]` for the load op).
2. Why the atomic op is an anchor for remove-layout (the RemoveLayoutConversions pass).
3. With both `sizePerThread = [1]` and `sizePerThread = [4]`, at the PTX level the atomic op uses the same instruction, `atom.global.gpu.acq_rel.add.f32`, 8 times. In the first case there are two different predicates, while in the latter there is only one. So it looks like `sizePerThread = [4]` is more efficient?
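Regarding item 1, a small NumPy check of the atomic's column indices (mirroring `tmp12` and `tmp27` from `triton_poi_fused_1`) shows that consecutive elements do not map to consecutive addresses: within a 4-element chunk some targets repeat and others skip. One possible explanation, which is an assumption and has not been verified against Triton's layout-assignment passes, is that the compiler cannot prove contiguity for these data-dependent addresses and therefore gives the atomic `sizePerThread = [1]`. The helper name `target_columns` is hypothetical.

```python
import numpy as np

def target_columns(x0: np.ndarray, w_out: int = 500):
    """Mirror tmp11/tmp12/tmp27 of triton_poi_fused_1 for column indices x0."""
    src = (x0.astype(np.float32) + 0.5) * 0.5 - 0.5  # tmp11
    col0 = src.astype(np.int32)                      # tmp12 (truncates toward zero)
    col1 = np.minimum(col0 + 1, w_out - 1)           # tmp27
    return col0, col1

col0, col1 = target_columns(np.arange(8))
print(col0.tolist())  # [0, 0, 0, 1, 1, 2, 2, 3]
print(col1.tolist())  # [1, 1, 1, 2, 2, 3, 3, 4]
```

Grouped as `sizePerThread = [4]` chunks, the first thread's targets would be `[0, 0, 0, 1]`, so the four values cannot be treated as one contiguous vector access.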

lezcano commented 3 months ago

cc @ThomasRaoux @Jokeren for visibility.