jansel opened 3 weeks ago
I've investigated, and I believe this can't really be fixed without making existing code worse unless we fix #3245 entirely. The specific issue is:
x2 = xindex // 152960
Given #3245, this sort of code in indexing expressions is going to cause problems. Halide tries to solve equations involving these constants to do things like loop partitioning, and in this case 152960 gets multiplied by 30592 at some point. The product (4,679,352,320) overflows a 32-bit signed integer, so the compiler essentially panics. A workaround is to use Params instead of int literals for your tensor shapes, which avoids all these large constants, but it's an ugly workaround: it doesn't fix #3245, and it doesn't solve the other big problem here:
This 1D indexing is terrible for performance. It adds expensive divs and mods; it can break vectorization (because dense loads become vector gathers); it makes the autoscheduler useless (because there's a single specific scheduling idiom you need to apply to remove the div and mod); it makes large_buffers useless (because large_buffers means each dim is still 31-bit, but the product of the dims may be 63-bit); and it generally makes a mess in the IR. I strongly recommend not using 1D inputs and outputs.
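A small illustration of the difference, using a hypothetical contiguous 3-D shape (D0, D1, D2) = (2, 3, 4) rather than the real tensor sizes. The 1D form walks a single flat index and has to recover every coordinate with // and %, just like x2 = xindex // 152960. The multidimensional form gets each coordinate directly from its own loop. This is plain Python modeling only the index arithmetic, not Halide or generated kernel code.

```python
# Hypothetical contiguous (row-major) 3-D shape, chosen small for illustration.
D0, D1, D2 = 2, 3, 4


def flat_coords():
    """1D style: one loop over all elements; each coordinate needs a // and/or %."""
    coords = []
    for xindex in range(D0 * D1 * D2):
        x0 = xindex % D2             # innermost dimension
        x1 = (xindex // D2) % D1     # middle dimension
        x2 = xindex // (D1 * D2)     # outermost, same form as `xindex // 152960`
        coords.append((x2, x1, x0))
    return coords


def nested_coords():
    """Multidimensional style: one loop per dimension, no division or modulo."""
    return [(i, j, k) for i in range(D0) for j in range(D1) for k in range(D2)]


# Both visit the same coordinates in the same order; only the index arithmetic differs.
assert flat_coords() == nested_coords()
```

In the nested form, the innermost loop over k touches consecutive addresses, which is the dense-load pattern a scheduler can vectorize directly. In the flat form, that structure is only recoverable by reasoning through the % and //. Each per-dimension extent also stays small even when D0 * D1 * D2 is large, which is the situation the large_buffers point is about.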
You said earlier that you're doing it because there are views that PyTorch supports that can't be expressed with affine indexing. Can you elaborate?
I'll write a pass to convert to multidimensional indexing where possible. It is often possible, but not always.
A good example to illustrate the issue is torch.as_strided(), which takes the memory of an existing tensor and reinterprets it as a new tensor with a provided list of sizes/strides, possibly permuted in any order with respect to memory order or with a different number of dimensions.
There are lots of other view ops (view/reshape/permute/transpose/slice/index/etc.), but all of them can be mapped to as_strided.
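As a rough model of why as_strided subsumes the other view ops: an element of an as_strided view at coordinates (i0, i1, ...) reads storage[storage_offset + i0*stride0 + i1*stride1 + ...]. The sketch below implements only that rule in plain Python (a dict from coordinates to values, no dtypes, no bounds or contiguity checks). It is not PyTorch's implementation, and the helper name as_strided_model is hypothetical. Transpose, select, and reshape of a contiguous tensor each become a particular choice of sizes, strides, and offset, and nothing prevents two coordinates from aliasing the same storage element.

```python
from itertools import product


def as_strided_model(storage, sizes, strides, storage_offset=0):
    """Hypothetical helper modeling torch.as_strided's indexing rule:
    view[idx] == storage[storage_offset + sum(i * s for i, s in zip(idx, strides))]."""
    view = {}
    for idx in product(*(range(n) for n in sizes)):
        view[idx] = storage[storage_offset + sum(i * s for i, s in zip(idx, strides))]
    return view


storage = list(range(6))  # flat memory of a contiguous 2x3 tensor

base = as_strided_model(storage, sizes=(2, 3), strides=(3, 1))
transposed = as_strided_model(storage, sizes=(3, 2), strides=(1, 3))          # base.t(): swap sizes and strides
row1 = as_strided_model(storage, sizes=(3,), strides=(1,), storage_offset=3)  # base[1]: drop a dim, shift offset
reshaped = as_strided_model(storage, sizes=(3, 2), strides=(2, 1))            # base.view(3, 2): new sizes, same memory
windows = as_strided_model(storage, sizes=(4, 3), strides=(1, 1))             # overlapping sliding windows

assert all(transposed[(j, i)] == base[(i, j)] for i in range(2) for j in range(3))
assert all(row1[(k,)] == base[(1, k)] for k in range(3))
assert windows[(0, 1)] == windows[(1, 0)]  # two coordinates alias one storage element
```

Reading one input through several of these views in the same program (say base and transposed together) means several different sets of sizes/strides over the same memory.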
When using dimensions, it was hard to map as_strided to Halide. Maybe there is some op I was missing. A single set of sizes/strides per input can be represented, but not multiple sets, or sets that change mid-program.
This issue is the same error as #8227, but it is not fixed by #8234.
I am running https://github.com/halide/Halide/commit/340136fec6d3ebc73e7a19eba1663e9b0ba8ab2d, which includes #8234. Unlike #8227, this one only triggers when an autoscheduler is used.
Output: