Closed jansel closed 4 months ago
Investigating
As an aside: Halide indexing is column major, which is the convention in math for functions (not matrices), so you may want to be saying tmp[xindex, yindex]. We use this convention because it's familiar to people used to graphics shaders, and because Halide Funcs are intended to be thought of as functions over an infinite integer domain, not matrices.
Thanks!
The yindex/xindex naming is an artifact of torch.compile's Triton codegen and how we map to GPU grids there. I'll swap it for the Halide output, though it shouldn't matter in this example since both the input and output are 1D.
I raised it because it would be a source of inefficiency if tmp is ever compute_at somewhere.
Ok, so this is an instance of #3245. What's happening is that the extent required of the input as a function of the extent of the output is:
((extent - 1) / N + (N - 1) * 64) + 1
When N = 5794 this becomes:
(extent - 1) / 5794 + 370753
Halide then "cleverly" simplifies this to:
(extent + 2148142881) / 5794
which overflows a signed integer. Even for 5793 that's a dumb simplification, because extent being large could easily cause overflow at runtime instead. As I say in #3245, the simplifier should not be allowed to introduce new signed integer overflow, but we haven't been sufficiently strict about that because it's really convenient to pretend that our indices are infinite precision integers.
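To see the overflow concretely, here is the arithmetic in plain Python (where integers are arbitrary precision, so nothing wraps):

```python
INT32_MAX = 2**31 - 1  # 2147483647

N = 5794
# Required input extent as a function of the output extent:
#   (extent - 1) / N + (N - 1) * 64 + 1
# The additive constant outside the division is:
const = (N - 1) * 64 + 1          # 370753

# Folding that constant into the division rewrites the expression as
#   (extent + const * N - 1) / N
folded = const * N - 1            # 2148142881

# The folded constant alone already exceeds INT32_MAX, so the rewritten
# expression overflows for *any* non-negative extent.
print(folded, folded > INT32_MAX)
```

So the simplified form overflows before the extent is even added, whereas the original form only overflows for extents near INT32_MAX.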
I'll try to figure out why we need this rule, and if there's an alternative
Makes sense. I am hitting this error somewhat frequently, so a fix would be very helpful! If there is a way to force 64-bit indexing, that might also fix it (and help in the case where we have >2 GB tensors, which happens for large embedding tables).
In this case in_ptr0 is column major, while out_ptr0 is row major, so an efficient schedule likely involves some tiling so that both the load and the store can be vectorized/coalesced. The original kernel has 7 inputs, which were a mix of layouts and broadcasting in one of the two dimensions.
For our Halide backend, I'm currently mapping inputs/outputs to 1D arrays, since the indexing/view support in PyTorch is more flexible than Halide's, which makes it easier to express things as 1D buffers plus indexing formulas. With views in PyTorch (or ops like torch.as_strided) the same buffer can be read multiple times with different strides/dimensions, and inlining view operations in TorchInductor can create more complex indexing formulas that can't be expressed with linear strides.
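As a small illustration of the "1D buffer plus indexing formula" idea (a sketch using NumPy's as_strided, which mirrors torch.as_strided): the same flat storage can be read through two different stride patterns, and each view is equivalent to the flat buffer indexed by a linear formula:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

buf = np.arange(12, dtype=np.int64)   # one flat 1D buffer

# Two views of the same storage with different strides/shapes
# (NumPy strides are in bytes; int64 is 8 bytes wide).
row_major = as_strided(buf, shape=(3, 4), strides=(4 * 8, 1 * 8))
col_major = as_strided(buf, shape=(3, 4), strides=(1 * 8, 3 * 8))

# Each view is just the flat buffer plus a linear indexing formula:
assert row_major[2, 1] == buf[2 * 4 + 1 * 1]   # index = i*4 + j*1
assert col_major[2, 1] == buf[2 * 1 + 1 * 3]   # index = i*1 + j*3
```

Both views alias buf without copying, which is why a backend that keeps buffers 1D and carries the indexing formula separately can express any such view.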
I'm creating the 2D hl.Func for the body of the kernel in cases where we would generate a tiled kernel in our Triton backend. This happens when we have stride=1 memory accesses in two or more different dimensions, so it will always be the case that at least one memory access will be swapped with respect to what Halide wants when we generate 2D hl.Funcs. If everything is contiguous in memory, right now we just generate a 1D hl.Func.
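A rough sketch of that decision rule (my paraphrase for illustration, not TorchInductor's actual code; the function name and input encoding are hypothetical):

```python
def func_dims(accesses):
    """Decide 1D vs 2D given each buffer access's per-dimension strides.

    `accesses` is a list of stride tuples, one per load/store in the
    kernel body. Illustrative only -- the real heuristic lives in
    TorchInductor's Halide codegen.
    """
    # Collect every dimension index in which some access has stride 1.
    unit_stride_dims = {
        dim
        for strides in accesses
        for dim, s in enumerate(strides)
        if s == 1
    }
    # stride=1 accesses in two or more distinct dimensions -> tiled
    # 2D Func; otherwise everything coalesces along one axis -> 1D Func.
    return 2 if len(unit_stride_dims) >= 2 else 1

# A column-major input (stride 1 in dim 0) plus a row-major output
# (stride 1 in dim 1) triggers the 2D case:
print(func_dims([(1, 64), (64, 1)]))   # -> 2
# All accesses contiguous in the same dimension stay 1D:
print(func_dims([(64, 1), (128, 1)]))  # -> 1
```

With this rule, the in_ptr0/out_ptr0 mix described above always lands in the 2D case, which is exactly when one of the accesses ends up transposed relative to Halide's preferred layout.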
I'd love to chat more about whether there are better ways to do this mapping, and I'm happy to collaborate, since I'm sure PyTorch will expose some areas for improvement in Halide. I had reached out to @jrk about this a while back, since I worked closely with him when we were at MIT.
repro.py:
Throws the following error when run:
However, this passes if I reduce N by 1:

In my original program N was larger, but I reduced it to find the error threshold. I observed this error in a much larger program generated by torch.compile, but minified it to find the smallest reproducer for the error.

I am a bit surprised by this, because 5794*64 is nowhere close to INT32_MAX, and I am only dealing with buffers of a few MB.