jansel opened 4 months ago
Unfortunately, I think the autoscheduler is not capable of autoscheduling that pipeline as written. The autoscheduler needs to be able to answer: given some region of Func f, what region of Func g is going to be accessed? The indirect load makes that impossible, so it panics.
Whenever you index a Func with another Func, throw a clamp around the index so that there is a well-defined bounds relationship to infer (or use unsafe_promise_clamped if you know a bound a priori and don't want to pay for a clamp at runtime). That should keep the autoscheduler happy.
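In plain Python terms (a model of the semantics, not the Halide API), the clamp matters because it bounds the accessed region regardless of what values the index Func produces:

```python
def clamp(x, lo, hi):
    # Model of Halide's clamp(x, lo, hi): the result always lies in
    # [lo, hi], so an access like f(clamp(idx(x), 0, n - 1)) has a
    # bounds relationship the autoscheduler can infer.
    return max(lo, min(x, hi))

# Indirect indexing: without the clamp, the region of `data` that gets
# read depends on the *values* in `index`, which bounds inference
# cannot see at compile time.
data = [10, 20, 30, 40]
index = [-3, 1, 99]
safe = [data[clamp(i, 0, len(data) - 1)] for i in index]
```

unsafe_promise_clamped asserts the same bounds without emitting the min/max at runtime; the burden of correctness shifts to the caller.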
Thanks, I'll switch to using clamp.
Is there a way to get Halide to generate masked loads (using the hardware mask registers on GPUs)? In some cases many of the loads will be on an unused branch of a select, and masking them out would be faster than clamping and reading an unused value. I was hoping hl.BoundaryConditions.constant_exterior would get mapped to that.
Making constant_exterior work that way is yet another thing on the todo list, ever since we added predication in the IR. It has been low priority because, for the boundary condition helpers, we already partition the loops into the iterations that might read out of bounds and the iterations that definitely don't, so the steady state can simply drop the boundary condition and its cost is already basically zero.
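A plain-Python sketch of that loop partitioning (hypothetical helper names, not Halide's implementation): the bounds check only runs on the edge iterations, and the steady state is check-free:

```python
def boundary(src, i, exterior=0):
    # constant_exterior-style boundary condition: substitute a constant
    # for any out-of-bounds read.
    return src[i] if 0 <= i < len(src) else exterior

def blur_partitioned(src):
    # out[i] = src[i-1] + src[i] + src[i+1], with a constant exterior.
    n = len(src)
    out = [0] * n
    # Prologue/epilogue: the iterations that *might* read out of bounds.
    for i in (0, n - 1):
        out[i] = boundary(src, i - 1) + boundary(src, i) + boundary(src, i + 1)
    # Steady state: provably in bounds, so no boundary check at all.
    for i in range(1, n - 1):
        out[i] = src[i - 1] + src[i] + src[i + 1]
    return out
```

Since the steady state dominates the iteration count, the amortized cost of the boundary condition is near zero, which is why mapping it to predicated loads has stayed low priority.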
Anyway, it's a bit awkward, but the general way to get an arbitrary if statement is RDom::where. So for example this:
```cpp
ImageParam in(UInt(32), 1, "in");
Var x;
Func tmp;
Expr in_bounds = x >= 0 && x < in.width();
tmp(x) = cast<uint32_t>(0);
RDom r(0, 1); // single-point RDom
r.where(in_bounds);
tmp(x) = in(unsafe_promise_clamped(x, 0, in.width() - 1)) + cast<uint32_t>(r); // dummy dependence on r
Func f;
f(x) = tmp(x);
Var xo, xi;
f.gpu_tile(x, xo, xi, 32, TailStrategy::RoundUp);
```
compiles to this PTX:
```
ld.param.u64 %rd3, [_kernel_f2_s0_v0_v2_block_id_x_param_0];
cvta.to.global.u64 %rd1, %rd3;
ld.param.u32 %r5, [_kernel_f2_s0_v0_v2_block_id_x_param_2];
ld.param.u32 %r6, [_kernel_f2_s0_v0_v2_block_id_x_param_3];
mov.u32 %r7, %ctaid.x;
mov.u32 %r9, %tid.x;
shl.b32 %r10, %r7, 5;
add.s32 %r1, %r10, %r9;
add.s32 %r11, %r1, %r5;
setp.lt.s32 %p1, %r11, 0;
setp.ge.s32 %p2, %r11, %r6;
mov.b32 %r13, 0;
or.pred %p3, %p1, %p2;
@%p3 bra $L__BB0_2;
// %bb.1: // %then_bb
ld.param.u64 %rd4, [_kernel_f2_s0_v0_v2_block_id_x_param_1];
cvta.to.global.u64 %rd5, %rd4;
ld.param.u32 %r8, [_kernel_f2_s0_v0_v2_block_id_x_param_4];
add.s32 %r12, %r1, %r8;
mul.wide.s32 %rd6, %r12, 4;
add.s64 %rd2, %rd5, %rd6;
ld.global.nc.u32 %r13, [%rd2];
$L__BB0_2: // %"2_consume_f1"
mul.wide.s32 %rd7, %r1, 4;
add.s64 %rd8, %rd1, %rd7;
st.global.u32 [%rd8], %r13;
ret;
```
and then this SASS:
```
/*0000*/ IMAD.MOV.U32 R1, RZ, RZ, c[0x0][0x28] ;
/*0010*/ @!PT SHFL.IDX PT, RZ, RZ, RZ, RZ ;
/*0020*/ S2R R0, SR_CTAID.X ;
/*0030*/ IMAD.MOV.U32 R5, RZ, RZ, 0x4 ;
/*0040*/ BMOV.32.CLEAR RZ, B0 ;
/*0050*/ BSSY B0, `(.L_x_0) ;
/*0060*/ S2R R3, SR_TID.X ;
/*0070*/ LEA R0, R0, R3, 0x5 ;
/*0080*/ MOV R3, RZ ;
/*0090*/ IADD3 R2, R0.reuse, c[0x0][0x170], RZ ;
/*00a0*/ IMAD.WIDE R4, R0, R5, c[0x0][0x160] ;
/*00b0*/ ISETP.GE.AND P0, PT, R2, c[0x0][0x174], PT ;
/*00c0*/ ISETP.LT.OR P0, PT, R2, RZ, P0 ;
/*00d0*/ @P0 BRA `(.L_x_1) ;
/*00e0*/ IADD3 R2, R0, c[0x0][0x178], RZ ;
/*00f0*/ IMAD.MOV.U32 R3, RZ, RZ, 0x4 ;
/*0100*/ IMAD.WIDE R2, R2, R3, c[0x0][0x168] ;
/*0110*/ LDG.E.CONSTANT.SYS R3, [R2] ;
.L_x_1:
/*0120*/ BSYNC B0 ;
.L_x_0:
/*0130*/ STG.E.SYS [R4], R3 ;
/*0140*/ EXIT ;
```
So it seems to want to use a branch. If I lower the CUDA capability down to 6.1, we get:
```
/*0008*/ MOV R1, c[0x0][0x20] ;
/*0010*/ { MOV R6, RZ ;
/*0018*/ S2R R4, SR_CTAID.X }
/*0028*/ S2R R2, SR_TID.X ;
/*0030*/ ISCADD R4, R4, R2, 0x5 ;
/*0038*/ IADD R0, R4, c[0x0][0x150] ;
/*0048*/ ISETP.GE.AND P0, PT, R0, c[0x0][0x154], PT ;
/*0050*/ ISETP.LT.OR P0, PT, R0, RZ, P0 ;
/*0058*/ @!P0 IADD R2, R4, c[0x0][0x158] ;
/*0068*/ @!P0 SHR R0, R2.reuse, 0x1e ;
/*0070*/ @!P0 ISCADD R2.CC, R2, c[0x0][0x148], 0x2 ;
/*0078*/ @!P0 IADD.X R3, R0, c[0x0][0x14c] ;
/*0088*/ @!P0 LDG.E.CI R6, [R2] ;
/*0090*/ SHR R0, R4.reuse, 0x1e ;
/*0098*/ ISCADD R4.CC, R4, c[0x0][0x140], 0x2 ;
/*00a8*/ IADD.X R5, R0, c[0x0][0x144] ;
/*00b0*/ STG.E [R4], R6 ;
/*00b8*/ EXIT ;
```
so I assume ptxas knows what it's doing when it uses a branch instead.
Actually I need to edit that, it's messed up. Stay tuned.
Edit: ok the code above is fixed.
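In plain Python terms (a semantic model, not Halide code), the size==1 RDom trick above amounts to a predicated load per output point:

```python
def masked_load(src, x, predicate):
    # Model of: tmp(x) = 0 as the pure definition, then a size-1 RDom
    # update guarded by r.where(predicate). The update runs only when
    # the predicate holds -- i.e. a predicated (masked) load.
    value = 0  # pure definition supplies the "masked-out" value
    if predicate(x):  # RDom::where becomes this if statement
        # unsafe_promise_clamped modeled as a clamp for safety here
        value = src[min(max(x, 0), len(src) - 1)]
    return value

src = [5, 6, 7]
in_bounds = lambda x: 0 <= x < len(src)
out = [masked_load(src, x, in_bounds) for x in range(-1, 4)]
```

Whether the backend lowers the if to a branch or to predicated instructions is then up to the code generator (and, on CUDA, to ptxas), as the PTX/SASS above shows.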
Thanks! That is a neat trick with the size==1 RDom. I'll try mapping load masks to that.
Thanks, that works, though on GPU I get warnings:
```
Warning: Unhandled intrinsic call unsafe_promise_clamped
Warning: Unhandled intrinsic call unsafe_promise_clamped
Warning: Unhandled intrinsic call unsafe_promise_clamped
...
```
Those are warnings from an autoscheduler. It doesn't know how much cost to assign to that intrinsic in its cost model. I'll fix it.
(It's benign, because the default cost is trivial.)
Using that method of masking loads causes different autoscheduler issues in aten.resize() kernels:

repro.py:
```python
import halide as hl

@hl.generator(name="kernel")
class Kernel:
    in_ptr0 = hl.InputBuffer(hl.Float(32), 1)
    out_ptr0 = hl.OutputBuffer(hl.Float(32), 1)

    def generate(g):
        in_ptr0 = g.in_ptr0
        out_ptr0 = g.out_ptr0
        xindex = hl.Var('xindex')
        tmp2 = hl.Func('tmp2')
        tmp2[xindex] = xindex < 2
        tmp3 = hl.Func('tmp3')
        tmp3_mask = hl.RDom([hl.Range(0, 1)])
        tmp3_mask.where(tmp2[xindex])
        tmp3[xindex] = hl.f32(0)
        tmp3[xindex] = in_ptr0[hl.unsafe_promise_clamped(xindex, 0, 1)] + hl.cast(hl.Float(32), tmp3_mask)
        out_ptr0[xindex] = tmp3[xindex]
        assert g.using_autoscheduler()
        in_ptr0.set_estimates([hl.Range(0, 2)])
        out_ptr0.set_estimates([hl.Range(0, 18)])

if __name__ == "__main__":
    import sys, tempfile
    with tempfile.TemporaryDirectory() as out:
        sys.argv = ['repro.py',
                    '-g', 'kernel',
                    '-o', out,
                    '-f', 'halide_kernel',
                    '-e', 'static_library,h,schedule',
                    '-p', '/home/jansel/conda/envs/pytorch/lib/libautoschedule_anderson2021.so',
                    'target=host-cuda-cuda_capability_86-user_context-strict_float-no_runtime-no_asserts',
                    'autoscheduler=Anderson2021',
                    'autoscheduler.parallelism=82']
        hl.main()
```
Output on CUDA+Anderson2021:
```
Unhandled exception: Internal Error at /home/jansel/Halide/src/autoschedulers/anderson2021/GPUMemInfo.h:70 triggered by user code at : Condition failed: total_bytes_used <= total_bytes:
total_bytes_used = -504
total_bytes = -4032
total_transactions = -126
num_transactions_per_request = 18
num_requests = -7
Traceback (most recent call last):
  File "/home/jansel/pytorch/repro.py", line 38, in <module>
    hl.main()
RuntimeError: Generator failed: -1
```
Output on CPU+Adams2019:

```
Unhandled exception: Internal Error at /home/jansel/Halide/src/autoschedulers/adams2019/LoopNest.cpp:863 triggered by user code at : Condition failed: task_p.min() <= task_p.max(): 16 1
Traceback (most recent call last):
  File "/home/jansel/pytorch/repro.py", line 38, in <module>
    hl.main()
RuntimeError: Generator failed: -1
```
It works with Li2018 and Mullapudi2016 though.
This is a lowering of

torch.ops.aten.index(a, [b, c])

(indexing one tensor with the values of two other tensors).

Repro.py

Output:

I get similar errors from all the schedulers except Mullapudi2016, which schedules this fine.
For the Triton/C++ backends we generate device asserts that the values are in-bounds and will raise an error on out-of-bounds access. I couldn't find a way to do device asserts in Halide, so I'm using

hl.BoundaryConditions.constant_exterior(in_ptr2, 0)

which isn't exactly right. Is there a better way to tell Halide to error on out-of-bounds access that can be autoscheduled?
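For clarity, in plain Python terms (a semantic model, not Halide API), the two behaviors differ like this: constant_exterior silently substitutes a value, whereas the device-assert semantics wanted here fail loudly on a bad index:

```python
def load_constant_exterior(src, i, exterior=0.0):
    # What constant_exterior does: out-of-bounds reads yield a constant.
    return src[i] if 0 <= i < len(src) else exterior

def load_with_device_assert(src, i):
    # What the Triton/C++ backends do: abort on an out-of-bounds index.
    assert 0 <= i < len(src), f"index {i} out of bounds for size {len(src)}"
    return src[i]

src = [1.0, 2.0]
```

The difference matters for indirect indexing: with constant_exterior, a corrupt index tensor produces silently wrong results instead of an error.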