pytorch / torchdynamo

A Python-level JIT compiler designed to make unmodified PyTorch programs faster.
BSD 3-Clause "New" or "Revised" License

Triton segfaults when running on torchinductor generated code #1515

Closed: fdrocha closed this issue 1 year ago

fdrocha commented 2 years ago

Came across this while working on a decomp for torch.bucketize.

Here's code for an MRE (minimal reproducible example):

```python
import torch
import torchdynamo
import torchinductor
import math

@torchdynamo.optimize('inductor')
def simple_bucketize(a, boundaries):
    n_boundaries = boundaries.shape[-1]
    start = torch.zeros(a.shape, device=a.device, dtype=torch.int64)
    end = start + n_boundaries
    # Max depth of the binary search
    niters = int(math.log2(n_boundaries)) + 1
    for _ in range(niters):
        cond_update = start < end
        # start might end up pointing to 1 past the end, we guard against that
        mid = torch.where(cond_update, start + (end - start) // 2, 0)
        mid_val = boundaries[mid]
        cond_mid = mid_val >= a
        start = torch.where((~cond_mid) & cond_update, mid + 1, start)
        end = torch.where(cond_mid & cond_update, mid, end)
    return start

dtype = torch.float32
test_func = torch.ops.aten.bucketize
device = "cuda"
a = torch.randn((1,32), dtype=dtype, device=device)
b = torch.randn((7,), dtype=dtype, device=device)
torchinductor.config.debug = True
simple_bucketize(a, b)
```

If you run this, it segfaults on the last line, during Triton compilation. If you compile Triton in debug mode, you hit an assert:

```
triton/lib/ir/type.cc:74: triton::ir::type::block_shapes_t triton::ir::type::get_block_shapes() const: Assertion `is_block_ty()' failed.
```
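For reference, the loop in simple_bucketize is a fixed-depth binary search: it always runs int(log2(n)) + 1 steps, so the traced program unrolls into straight-line code (the kernel below has three loads from in_ptr0, one per step for n = 7). The following is a minimal pure-Python scalar sketch of the same steps. bucketize_scalar is a hypothetical helper, not part of torch. The sketch checks that, for sorted boundaries, the result matches bisect.bisect_left, which corresponds to torch.bucketize with its default right=False. The repro's random b isn't sorted, but that shouldn't matter for triggering a compile-time crash.

```python
import bisect
import math


def bucketize_scalar(x, boundaries):
    """Scalar mirror of the unrolled binary search in simple_bucketize.

    Assumes `boundaries` is sorted ascending and non-empty. Returns the
    first index i with boundaries[i] >= x (len(boundaries) if none),
    which is what bisect.bisect_left computes.
    """
    n = len(boundaries)
    start, end = 0, n
    # Fixed trip count, so the loop can be fully unrolled when traced:
    # each step at least halves the remaining range [start, end).
    niters = int(math.log2(n)) + 1
    for _ in range(niters):
        cond_update = start < end
        # Once the range is empty, start may equal n (one past the end),
        # so read boundaries[0] instead; the comparison is then ignored.
        mid = start + (end - start) // 2 if cond_update else 0
        cond_mid = boundaries[mid] >= x
        if not cond_mid and cond_update:
            start = mid + 1
        if cond_mid and cond_update:
            end = mid
    return start


if __name__ == "__main__":
    bounds = [-1.5, -0.3, 0.0, 0.2, 0.9, 1.4, 2.0]  # 7 sorted boundaries
    for x in [-2.0, -0.3, 0.5, 2.0, 3.0]:
        assert bucketize_scalar(x, bounds) == bisect.bisect_left(bounds, x)
        print(x, bucketize_scalar(x, bounds))
```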
This is the (slightly cleaned up) generated Triton code:

```python
from torchinductor.codecache import AsyncCompile
async_compile = AsyncCompile()
import triton
import triton.language as tl

kernel0 = async_compile.triton('''
import triton
import triton.language as tl
from torchinductor.ir import ReductionHint
from torchinductor.triton_ops.autotune import pointwise
from torchinductor.utils import instance_descriptor

@pointwise(size_hints=[32], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*i64', 3: 'i32'}, 'device': 0, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())], 'constants': {}})
@triton.jit
def kernel(in_ptr0, in_ptr1, out_ptr3, xnumel, XBLOCK : tl.constexpr):
    xnumel = 32
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x0 = xindex
    tmp10 = tl.load(in_ptr1 + (x0), xmask)
    tmp0 = 0
    tmp1 = 7
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 < tmp2
    tmp4 = tmp2 - tmp0
    tmp5 = 2
    tmp6 = tl.where((tmp4 < 0) != (tmp5 < 0), tl.where(tmp4 % tmp5 != 0, tmp4 // tmp5 - 1, tmp4 // tmp5), tmp4 // tmp5)
    tmp7 = tmp0 + tmp6
    tmp8 = tl.where(tmp3, tmp7, tmp0)
    tmp9 = tl.load(in_ptr0 + (tmp8), xmask)
    tmp11 = tmp9 >= tmp10
    tmp12 = tmp11 == 0
    tmp13 = tmp12 & tmp3
    tmp14 = 1
    tmp15 = tmp8 + tmp14
    tmp16 = tl.where(tmp13, tmp15, tmp0)
    tmp17 = tmp11 & tmp3
    tmp18 = tl.where(tmp17, tmp8, tmp2)
    tmp19 = tmp16 < tmp18
    tmp20 = tmp18 - tmp16
    tmp21 = tl.where((tmp20 < 0) != (tmp5 < 0), tl.where(tmp20 % tmp5 != 0, tmp20 // tmp5 - 1, tmp20 // tmp5), tmp20 // tmp5)
    tmp22 = tmp16 + tmp21
    tmp23 = tl.where(tmp19, tmp22, tmp0)
    tmp24 = tl.load(in_ptr0 + (tmp23), xmask)
    tmp25 = tmp24 >= tmp10
    tmp26 = tmp25 == 0
    tmp27 = tmp26 & tmp19
    tmp28 = tmp23 + tmp14
    tmp29 = tl.where(tmp27, tmp28, tmp16)
    tmp30 = tmp25 & tmp19
    tmp31 = tl.where(tmp30, tmp23, tmp18)
    tmp32 = tmp29 < tmp31
    tmp33 = tmp31 - tmp29
    tmp34 = tl.where((tmp33 < 0) != (tmp5 < 0), tl.where(tmp33 % tmp5 != 0, tmp33 // tmp5 - 1, tmp33 // tmp5), tmp33 // tmp5)
    tmp35 = tmp29 + tmp34
    tmp36 = tl.where(tmp32, tmp35, tmp0)
    tmp37 = tl.load(in_ptr0 + (tmp36), xmask)
    tmp38 = tmp37 >= tmp10
    tmp39 = tmp38 == 0
    tmp40 = tmp39 & tmp32
    tmp41 = tmp36 + tmp14
    tmp42 = tl.where(tmp40, tmp41, tmp29)
    tl.store(out_ptr3 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp42, xmask)
''')

def call(arg0_1, arg1_1):
    arg0_1_size = arg0_1.size()
    s0 = arg0_1_size[1]
    arg1_1_size = arg1_1.size()
    s1 = arg1_1_size[0]
    buf3 = empty_strided((1, 32), (32, 1), device='cuda', dtype=torch.int64)
    stream0 = get_cuda_stream(0)
    kernel0.run(arg1_1, arg0_1, buf3, 32, grid=grid(32), stream=stream0)
    return (buf3, )
```

You get the same error if you run this. One strange thing about the Triton code is the xmask argument in the tl.load(in_ptr0 + ...) calls. The pointer in_ptr0 corresponds to tensor b (7 elements), while xmask is the mask for tensor a (32 elements), which has a completely different shape. It seems like it could be a torchinductor bug that xmask appears in these loads. Note also that all the tl.load(in_ptr0 + ...) calls are for a single element, while xmask is a block.

In fact, if you manually remove the xmask argument from the tl.load(in_ptr0 + ...) calls, the code will no longer segfault.
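Separately from the mask problem, the nested tl.where expressions in the kernel (tmp6, tmp21, tmp34) look like the lowering of (end - start) // 2: they rebuild Python's floor division on top of a division that rounds toward zero. A pure-Python sketch, assuming Triton's integer // and % behave like C (truncating quotient, remainder with the sign of the dividend); trunc_div, trunc_rem and floordiv_via_trunc are illustrative names only:

```python
def trunc_div(a: int, b: int) -> int:
    """C-style integer division: the quotient is rounded toward zero."""
    q = abs(a) // abs(b)
    return q if (a < 0) == (b < 0) else -q


def trunc_rem(a: int, b: int) -> int:
    """C-style remainder: sign of the dividend, so a == b * q + r."""
    return a - b * trunc_div(a, b)


def floordiv_via_trunc(a: int, b: int) -> int:
    """Mirrors tl.where((a < 0) != (b < 0),
                        tl.where(a % b != 0, a // b - 1, a // b), a // b)."""
    q = trunc_div(a, b)
    if (a < 0) != (b < 0):
        # Signs differ: truncation rounded the quotient up, so step down
        # unless the division was exact.
        return q - 1 if trunc_rem(a, b) != 0 else q
    return q


if __name__ == "__main__":
    print(trunc_div(-7, 2), floordiv_via_trunc(-7, 2), -7 // 2)
```

In this kernel end - start should never be negative, so the correction branch never fires at runtime; presumably it is emitted because codegen does not track the sign of the operands.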

@jansel @ngimel @lezcano

voznesenskym commented 2 years ago

Maybe not perfectly related to this issue, but we have a known list of segfaulting ops in tests. We should make use of that in inductor itself. https://github.com/pytorch/pytorch/issues/93712

fdrocha commented 2 years ago

@voznesenskym could you point me to the list of segfaulting ops?

lezcano commented 2 years ago

I think there's none yet, but you can see in https://github.com/pytorch/torchdynamo/pull/1522/files that voz is marking a few ops as segfaulting there. His linked issue suggests it would be good to have one.

Now, I think that this segfault should not stop you from having your decomposition merged in core.

voznesenskym commented 2 years ago

> @voznesenskym could you point me to the list of segfaulting ops?

https://github.com/pytorch/torchdynamo/blob/main/test/inductor/test_torchinductor_opinfo.py

Closest thing we have. As @lezcano correctly said, it's not a real list in inductor yet, more of a test exclusion list.

We are actively working to burn this down.

lezcano commented 2 years ago

If you know of any that could be tagged as "good first issue", please tell us. @fdrocha and I would be happy to help :)

ngimel commented 2 years ago

This is a real codegen bug. The problematic line is

```python
tmp9 = tl.load(in_ptr0 + (tmp8), xmask)
```

Here tmp8 is computed from constants and is itself a constant (a single number), but the mask is a Triton tensor, and Triton expects offsets and mask to have the same shape. Unfortunately, our codegen doesn't distinguish well between cases where the offsets are the result of indirect indexing (where tmp8 would be a Triton tensor of the proper shape) and this case, where it's just a number. The fix should be to teach codegen to distinguish between these situations, so that either the offsets have the proper shape or the mask is not added.

The code in question is around here: https://github.com/pytorch/torchdynamo/blob/5ca06364417c06f0fe80e38d4548080ff9a6dff5/torchinductor/codegen/triton.py#L684-L697. is_indirect_indexing is a very simple function that just tests for tmp in the name; perhaps we need to generate smarter variable names? @fdrocha let me know if you are interested in looking into this.
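One way to tell the two situations apart without relying on names is to track, for each generated temporary, whether it holds a block or a scalar, propagate that through the expression graph, and pick the load form from the result. The sketch below is an assumption for illustration, not what torchinductor does; STATEMENTS, infer_block_vars and emit_load are hypothetical names. It models the first two search steps of the kernel above in pure Python and implements both fixes mentioned in the comment above: broadcasting the offset to the block shape, or dropping the mask.

```python
from typing import Dict, Tuple

# (result, operands) for the first two search steps of the kernel above.
# A load is listed with its offset only, i.e. this models fixed codegen
# in which a scalar offset does not inherit the block mask.
STATEMENTS: Tuple[Tuple[str, Tuple[str, ...]], ...] = (
    ("tmp10", ("x0",)),                    # tl.load(in_ptr1 + (x0), xmask)
    ("tmp0", ()),                          # 0
    ("tmp1", ()),                          # 7
    ("tmp2", ("tmp0", "tmp1")),
    ("tmp3", ("tmp0", "tmp2")),
    ("tmp4", ("tmp2", "tmp0")),
    ("tmp5", ()),                          # 2
    ("tmp6", ("tmp4", "tmp5")),
    ("tmp7", ("tmp0", "tmp6")),
    ("tmp8", ("tmp3", "tmp7", "tmp0")),
    ("tmp9", ("tmp8",)),                   # load from in_ptr0 at tmp8
    ("tmp11", ("tmp9", "tmp10")),
    ("tmp12", ("tmp11",)),
    ("tmp13", ("tmp12", "tmp3")),
    ("tmp14", ()),                         # 1
    ("tmp15", ("tmp8", "tmp14")),
    ("tmp16", ("tmp13", "tmp15", "tmp0")),
    ("tmp17", ("tmp11", "tmp3")),
    ("tmp18", ("tmp17", "tmp8", "tmp2")),
    ("tmp19", ("tmp16", "tmp18")),
    ("tmp20", ("tmp18", "tmp16")),
    ("tmp21", ("tmp20", "tmp5")),
    ("tmp22", ("tmp16", "tmp21")),
    ("tmp23", ("tmp19", "tmp22", "tmp0")),
)


def infer_block_vars(statements, block_roots=("x0",)) -> Dict[str, bool]:
    """A value is a block if any of its operands is a block."""
    is_block: Dict[str, bool] = {name: True for name in block_roots}
    for name, operands in statements:
        is_block[name] = any(is_block.get(op, False) for op in operands)
    return is_block


def emit_load(ptr, index, is_block, mask="xmask", broadcast=False):
    """Emit a tl.load whose offset and mask shapes agree."""
    if is_block[index]:
        return f"tl.load({ptr} + ({index}), {mask})"
    if broadcast:
        # Give the scalar offset the block's shape, the same trick the
        # generated tl.store uses with tl.zeros([XBLOCK], tl.int32).
        return f"tl.load({ptr} + ({index} + tl.zeros([XBLOCK], tl.int32)), {mask})"
    # A scalar offset reads a single element: leave the mask off.
    return f"tl.load({ptr} + ({index}))"


if __name__ == "__main__":
    blocks = infer_block_vars(STATEMENTS)
    print(emit_load("in_ptr0", "tmp8", blocks))   # tl.load(in_ptr0 + (tmp8))
    print(emit_load("in_ptr0", "tmp23", blocks))  # tl.load(in_ptr0 + (tmp23), xmask)
```

Under this model only the first step's offset (tmp8) is a scalar. tmp23 depends, through the comparison in tmp11, on tmp10, which is loaded from a with x0, so it is already block-shaped and can keep xmask.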
To @voznesenskym's point, we can block existing ops from being codegened, but the above is legal code that can show up in users' programs. Unless we fix our codegen bugs, we can't prevent users from writing code like this and ultimately segfaulting, so I think issues like this merit fixes, not just blocking the original ops from codegen.

fdrocha commented 2 years ago

Thanks for the comments @ngimel. I'd actually started looking into this today, your observations are very helpful.

I think it's more than just figuring out if it's a tensor or a number though. Take the following example:

```python
def test(x, y):
    return x[y]
```

Call it with x of shape (S1,) and y of shape (S2,). The generated Triton code looks like this:

```python
def kernel(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (tmp0), xmask)
    tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

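Note that this is still wrong, even though tmp0 is now a tensor: it's wrong to use the same mask for in_ptr0 and in_ptr1. xmask only guarantees that x0 is a valid index into y (in_ptr0); it says nothing about whether tmp0 is a valid index into x (in_ptr1). For instance, if y = tensor([2**20]) and S1 = 5, you are accessing invalid memory and you'll get a CUDA error.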
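To make the failure mode concrete, here is a pure-Python model of the x[y] kernel above, run lane by lane. simulate_gather_kernel and its XBLOCK default are illustrative only, not real Triton. The model reports the lanes where xmask lets the second load through even though tmp0 falls outside x.

```python
def simulate_gather_kernel(x, y, XBLOCK=4):
    """Pure-Python model of the x[y] kernel, one lane at a time.

    in_ptr0 is y, in_ptr1 is x, and xnumel == len(y). Returns the
    (xindex, tmp0) pairs where xmask is True, so the kernel issues the
    load from in_ptr1, but tmp0 is not a valid index into x.
    """
    xnumel = len(y)
    unsafe_lanes = []
    num_programs = (xnumel + XBLOCK - 1) // XBLOCK  # grid size
    for pid in range(num_programs):
        for lane in range(XBLOCK):
            xindex = pid * XBLOCK + lane
            xmask = xindex < xnumel
            if not xmask:
                continue  # masked-off lane: neither load happens
            # tl.load(in_ptr0 + (x0), xmask): always in bounds for y
            tmp0 = y[xindex]
            # tl.load(in_ptr1 + (tmp0), xmask): xmask says nothing about tmp0
            if not 0 <= tmp0 < len(x):
                unsafe_lanes.append((xindex, tmp0))
    return unsafe_lanes


if __name__ == "__main__":
    print(simulate_gather_kernel([0.0] * 5, [2**20]))
```

With x of length 5 and y = [2**20], lane 0 is reported: xmask is True because 0 < xnumel = 1, yet index 2**20 is far past the end of x.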

One simple solution is to just never use a mask for indirect indexing, and make it the caller's responsibility to make sure the indices are all in bounds. Does that sound reasonable?

ngimel commented 2 years ago

Yeah, we were just discussing possible invalid values in the input with @jansel, and it's OK to ignore them for now (assume that all values in y are valid). Going forward, we'd need to error out in the kernel (that's today's eager behavior), not just mask invalid values. With the assumption of valid values in y, using the same mask is OK.

desertfire commented 2 years ago

> https://github.com/pytorch/torchdynamo/blob/5ca06364417c06f0fe80e38d4548080ff9a6dff5/torchinductor/codegen/triton.py#L684-L697, is_indirect_indexing is a very simple function that just tests for tmp in the name, and perhaps we need to generate smarter variable names? @fdrocha let me know if you are interested in looking into this.

@fdrocha, to follow up on Natalia's suggestion here, are you planning to work on it?

fdrocha commented 2 years ago

Yes, I've started looking into this.