mitsuba-renderer / drjit

Dr.Jit — A Just-In-Time-Compiler for Differentiable Rendering
BSD 3-Clause "New" or "Revised" License

Unsupported `scatter_reduce` operators in CUDA and LLVM backends #176

Closed: WeiPhil closed this issue 1 year ago

WeiPhil commented 1 year ago

Hi, I've come across two potential issues when performing scatter_reduce with the dr.ReduceOp.Max or dr.ReduceOp.Min operators on the CUDA backend. Here is a minimal reproducer:

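import drjit as dr

# Swap the two imports below to run the same test on the LLVM backend
from drjit.cuda import Float, UInt32
# from drjit.llvm import Float, UInt32

shape = 8
a = dr.zeros(Float, shape=shape)                        # target array, all zeros
b = dr.linspace(Float, start=1.0, stop=8.0, num=shape)  # values 1.0, 2.0, ..., 8.0

print("a", a)
print("b", b)
idx = dr.arange(UInt32, 0, shape)  # one distinct target slot per value
print("idx", idx)

# On the CUDA backend, both Max and Min fail to compile for Float targets
dr.scatter_reduce(dr.ReduceOp.Max, a, b, idx)
# dr.scatter_reduce(dr.ReduceOp.Min, a, b, idx)
print("result", a)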

With the LLVM backend, this prints the expected result:

a [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
b [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
idx [0, 1, 2, 3, 4, 5, 6, 7]
result [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
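For reference, the sequential semantics that this LLVM output matches can be modelled in plain Python. This is a minimal sketch, not Dr.Jit's implementation: `scatter_reduce_ref` and the string operator names are hypothetical stand-ins for `dr.scatter_reduce` and `dr.ReduceOp`. On a GPU, colliding writes are combined in an unspecified order, which is harmless for max/min but can change the last bits of a floating-point sum or product.

```python
import operator

# Hypothetical stand-ins for dr.ReduceOp members, mapped to binary functions.
REDUCE_FNS = {
    "add": operator.add,
    "mul": operator.mul,
    "min": min,
    "max": max,
}


def scatter_reduce_ref(op, target, value, index):
    """Sequential model: target[index[k]] = op(target[index[k]], value[k]) for each k."""
    fn = REDUCE_FNS[op]
    for v, i in zip(value, index):
        target[i] = fn(target[i], v)
    return target


# Same inputs as the reproducer above.
a = [0.0] * 8
b = [1.0 + k for k in range(8)]  # like dr.linspace(Float, 1.0, 8.0, 8)
idx = list(range(8))             # like dr.arange(UInt32, 0, 8)
print(scatter_reduce_ref("max", a, b, idx))
# [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]

# Colliding indices are the case that actually needs atomics: every value targets slot 0.
c = [0.0] * 8
print(scatter_reduce_ref("max", c, b, [0] * 8))
# [8.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```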

but with the CUDA backend I get the following error:

Critical Dr.Jit compiler failure: jit_cuda_compile(): compilation failed. Please see the PTX assembly listing and error message below:

.version 7.1
.target sm_86
.address_size 64

.entry drjit_56971614ee2ebd5608102657da246170(.param .align 8 .b8 params[32]) {
    .reg.b8   %b <9>; .reg.b16 %w<9>; .reg.b32 %r<9>;
    .reg.b64  %rd<9>; .reg.f32 %f<9>; .reg.f64 %d<9>;
    .reg.pred %p <9>;

    mov.u32 %r0, %ctaid.x;
    mov.u32 %r1, %ntid.x;
    mov.u32 %r2, %tid.x;
    mad.lo.u32 %r0, %r0, %r1, %r2;
    ld.param.u32 %r2, [params];
    setp.ge.u32 %p0, %r0, %r2;
    @%p0 bra done;

    mov.u32 %r3, %nctaid.x;
    mul.lo.u32 %r1, %r3, %r1;

body: // sm_86
    ld.param.u64 %rd4, [params+8];
    ld.param.u64 %rd0, [params+16];
    mad.wide.u32 %rd0, %r0, 4, %rd0;
    ld.global.cs.f32 %f5, [%rd0];
    ld.param.u64 %rd0, [params+24];
    mad.wide.u32 %rd0, %r0, 4, %rd0;
    ld.global.cs.u32 %r6, [%rd0];
    mov.pred %p7, 0x1;
    mad.wide.u32 %rd3, %r6, 4, %rd4;
    {
        .visible .func reduce_max_f32(.param .u64 ptr, .param .f32 value);
        call reduce_max_f32, (%rd3, %f5);
    }

    add.u32 %r0, %r0, %r1;
    setp.ge.u32 %p0, %r0, %r2;
    @!%p0 bra body;

done:
    ret;
}

.visible .func reduce_max_f32(.param .u64 ptr,
                              .param .f32 value) {
    .reg .pred %p<14>;
    .reg .f32 %q<19>;
    .reg .b32 %r<41>;
    .reg .b64 %rd<2>;

    ld.param.u64 %rd0, [ptr];
    ld.param.f32 %q3, [value];
    activemask.b32 %r1;
    match.any.sync.b64 %r2, %rd0, %r1;
    setp.eq.s32 %p1, %r2, -1;
    @%p1 bra.uni fast_path;

    brev.b32 %r10, %r2;
    bfind.shiftamt.u32 %r40, %r10;
    shf.l.wrap.b32 %r12, -2, -2, %r40;
    and.b32 %r39, %r2, %r12;
    setp.ne.s32 %p2, %r39, 0;
    vote.sync.any.pred %p3, %p2, %r1;
    @!%p3 bra maybe_scatter;
    mov.b32 %r5, %q3;

slow_path_repeat:
    brev.b32 %r14, %r39;
    bfind.shiftamt.u32 %r15, %r14;
    shfl.sync.idx.b32 %r17, %r5, %r15, 31, %r1;
    mov.b32 %q6, %r17;
    @%p2 max.f32 %q3, %q3, %q6;
    shf.l.wrap.b32 %r19, -2, -2, %r15;
    and.b32 %r39, %r39, %r19;
    setp.ne.s32 %p2, %r39, 0;
    vote.sync.any.pred %p3, %p2, %r1;
    @!%p3 bra maybe_scatter;
    bra.uni slow_path_repeat;

fast_path:
    mov.b32 %r22, %q3;
    shfl.sync.down.b32 %r26, %r22, 16, 31, %r1;
    mov.b32 %q7, %r26;
    max.f32 %q8, %q7, %q3;
    mov.b32 %r27, %q8;
    shfl.sync.down.b32 %r29, %r27, 8, 31, %r1;
    mov.b32 %q9, %r29;
    max.f32 %q10, %q8, %q9;
    mov.b32 %r30, %q10;
    shfl.sync.down.b32 %r32, %r30, 4, 31, %r1;
    mov.b32 %q11, %r32;
    max.f32 %q12, %q10, %q11;
    mov.b32 %r33, %q12;
    shfl.sync.down.b32 %r34, %r33, 2, 31, %r1;
    mov.b32 %q13, %r34;
    max.f32 %q14, %q12, %q13;
    mov.b32 %r35, %q14;
    shfl.sync.down.b32 %r37, %r35, 1, 31, %r1;
    mov.b32 %q15, %r37;
    max.f32 %q3, %q14, %q15;
    mov.u32 %r40, 0;

maybe_scatter:
    mov.u32 %r38, %laneid;
    setp.ne.s32 %p13, %r40, %r38;
    @%p13 bra done;
    red.max.f32 [%rd0], %q3;

done:
    ret;
}

ptxas application ptx input, line 107; error   : Operation .max requires .u32 or .s32 or .u64 or .s64 type for instruction 'red'
ptxas fatal   : Ptx assembly aborted due to errors
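Reading the listing, `reduce_max_f32` first combines values inside the warp so that only one lane per distinct address issues the final `red.max.f32`, which is the instruction ptxas rejects. The fast path (taken when `match.any.sync` reports that all 32 lanes share one address) is a shuffle-down tree reduction; the slow path groups lanes by address and lets the lowest lane of each group write. The Python model below is a sketch of that reading, assuming a full 32-lane warp; `shfl_down`, `warp_max_fast_path` and `warp_scatter_max` are hypothetical helpers, not Dr.Jit code.

```python
WARP_SIZE = 32


def shfl_down(vals, delta):
    """Model of shfl.sync.down with clamp 31: lane i reads lane i + delta,
    or keeps its own value when that source lane is out of range."""
    return [vals[i + delta] if i + delta < WARP_SIZE else vals[i]
            for i in range(WARP_SIZE)]


def warp_max_fast_path(vals):
    """Shuffle-down tree reduction with offsets 16, 8, 4, 2, 1 (the fast_path block)."""
    assert len(vals) == WARP_SIZE
    for delta in (16, 8, 4, 2, 1):
        other = shfl_down(vals, delta)
        vals = [max(x, y) for x, y in zip(vals, other)]
    # Only lane 0 is guaranteed to hold the full maximum; it alone issues red.max.
    return vals[0]


def warp_scatter_max(memory, addrs, vals):
    """Model of the slow path: lanes are grouped by target address (match.any.sync),
    each group is combined, and only one lane per group updates memory."""
    groups = {}
    for lane, addr in enumerate(addrs):
        groups.setdefault(addr, []).append(lane)
    for addr, lanes in groups.items():
        group_max = max(vals[lane] for lane in lanes)
        memory[addr] = max(memory[addr], group_max)  # the single red.max per address


print(warp_max_fast_path([float((7 * k) % 32) for k in range(32)]))  # 31.0
```

The point of this structure is to cut the number of memory-side atomics from one per lane to one per distinct address; it does not help when the memory-side instruction itself is unavailable for the type.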

I'm running on Windows with an NVIDIA RTX A1000 card, and my CUDA compiler version is: Cuda compilation tools, release 11.8, V11.8.89, Build cuda_11.8.r11.8/compiler.31833905_0.

It also seems that dr.ReduceOp.Mul is not supported on either backend. On the CUDA backend it fails with:

ptxas application ptx input, line 107; error   : Unknown modifier '.mul'
ptxas application ptx input, line 107; error   : Illegal operation '' for instruction 'red'
ptxas application ptx input, line 107; error   : Operation  requires  type for instruction 'red'
ptxas application ptx input, line 107; error   : Reduction operation is required for instruction 'red'
ptxas fatal   : Ptx assembly aborted due to errors

and on the LLVM backend with:

drjit_bac90354b603f82059ca1b030eeec1f7:58:14: error: expected binary operation in atomicrmw
   atomicrmw fmul ptr %ptr_0, float %sum monotonic

Are those known limitations/issues of the scatter_reduce operator?

Best, Philippe

njroussel commented 1 year ago

Hi @WeiPhil

I wasn't aware of these fine details of scatter_reduce. After digging a bit, here's what I found:

In CUDA:

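These are limitations of the PTX red instruction (source): its .min and .max operations only accept integer types (.u32, .s32, .u64, .s64), and it has no .mul operation at all, which matches the ptxas errors above.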
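Since `red.max` does accept `.u32`, one well-known workaround (used, for example, when radix-sorting floats; not necessarily what Dr.Jit does or should do) is to store each float32 as an unsigned integer key whose order matches the float order and apply an integer max to the keys. A Python sketch using `struct` to get the bit patterns; the helper names are hypothetical, NaNs are not handled, and -0.0 orders just below +0.0.

```python
import struct


def float_to_ordered_u32(x):
    """Map a float32 to a uint32 whose unsigned order matches the float order (no NaNs)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    if bits & 0x80000000:        # negative: flip all bits so larger magnitudes sort lower
        return bits ^ 0xFFFFFFFF
    return bits | 0x80000000     # non-negative: set the top bit so it sorts above negatives


def ordered_u32_to_float(key):
    """Inverse of float_to_ordered_u32."""
    if key & 0x80000000:
        bits = key & 0x7FFFFFFF
    else:
        bits = key ^ 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def scatter_max_via_u32(target, values, index):
    """Float max-scatter done entirely with integer max on encoded keys."""
    keys = [float_to_ordered_u32(t) for t in target]
    for v, i in zip(values, index):
        keys[i] = max(keys[i], float_to_ordered_u32(v))  # stands in for red.max.u32
    return [ordered_u32_to_float(k) for k in keys]


print(scatter_max_via_u32([-1e30, -1e30], [-3.5, 2.0, -0.25, 7.75], [0, 1, 0, 1]))
# [-0.25, 7.75]
```

The cost is that the buffer holds encoded keys until a decode pass runs, so this is a trade-off rather than a drop-in replacement for a native float atomic.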

In LLVM (assuming LLVM 16):

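These are mostly restricted by the LLVM IR atomicrmw instruction (source). For example, atomicrmw has no multiplication operation, which is why the atomicrmw fmul above is rejected.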

This has got me wondering why ReduceOp.Mul was added...

I'll keep this issue open until I figure out what we actually want to support. At the very least, what you are seeing now is "expected" behavior. I believe we could fully support this set of operations for both integer and floating-point types, but it would require some more work (basically, manually adding some synchronization points). This was never done, either because we have only ever needed ReduceOp.Add or because of some other limitation I'm currently unaware of.
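One common way to emulate an atomic operation the hardware lacks is a compare-and-swap retry loop (for example `atomicCAS` in CUDA or `cmpxchg` in LLVM IR). Whether this is the kind of synchronization meant above is an assumption. Below is a toy Python model: `AtomicCell` and `atomic_mul` are hypothetical, and the `threading.Lock` only stands in for the atomicity of a single hardware CAS.

```python
import threading


class AtomicCell:
    """Toy memory cell; the lock only models the atomicity of a hardware CAS."""

    def __init__(self, value):
        self._value = value
        self._lock = threading.Lock()

    def load(self):
        return self._value

    def compare_and_swap(self, expected, new):
        """Store `new` only if the cell still holds `expected`; return the value seen.
        Compares with == (real hardware compares bit patterns, so NaNs would differ)."""
        with self._lock:
            old = self._value
            if old == expected:
                self._value = new
            return old


def atomic_mul(cell, factor):
    """Retry loop: read, compute the product, publish it only if nobody raced us."""
    old = cell.load()
    while True:
        seen = cell.compare_and_swap(old, old * factor)
        if seen == old:
            return old          # success: return the previous value, like an atomic op
        old = seen              # another thread won the race; retry with its value


cell = AtomicCell(1.0)


def worker():
    for _ in range(4):
        atomic_mul(cell, 2.0)


threads = [threading.Thread(target=worker) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(cell.load())  # 4294967296.0, i.e. 2.0 ** 32
```

Under heavy contention on one address, such loops retry often, which is one reason a native reduction instruction is preferable whenever it exists.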

wjakob commented 1 year ago

ReduceOp.Mul is there because the plan was also to use this enumeration internally for horizontal reductions (exposed as drjit.prod in the upcoming nanobind rewrite). I agree that it's pretty weird for atomics.
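To illustrate why Mul is natural for horizontal reductions even without an atomic multiply: a horizontal reduction folds a whole array into one scalar, so there are no write conflicts to resolve. A minimal Python sketch where one enumeration drives such a fold; this `ReduceOp` class and `horizontal_reduce` are hypothetical mirrors of the idea, listing only the members discussed in this thread, and are not Dr.Jit's code.

```python
import operator
from enum import Enum
from functools import reduce


class ReduceOp(Enum):
    """Hypothetical mirror of dr.ReduceOp, limited to the members discussed here."""
    Add = "add"
    Mul = "mul"
    Min = "min"
    Max = "max"


# Binary operator and identity element for each reduction.
_OPS = {
    ReduceOp.Add: (operator.add, 0.0),
    ReduceOp.Mul: (operator.mul, 1.0),
    ReduceOp.Min: (min, float("inf")),
    ReduceOp.Max: (max, float("-inf")),
}


def horizontal_reduce(op, values):
    """Fold every element of `values` into a single scalar using `op`."""
    fn, identity = _OPS[op]
    return reduce(fn, values, identity)


print(horizontal_reduce(ReduceOp.Mul, [1.0, 2.0, 3.0, 4.0]))  # 24.0
```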

I will explain these limitations in the documentation. Can the issue be closed?

wjakob commented 1 year ago

-> https://github.com/mitsuba-renderer/drjit-core/commit/9e864740814c3dbc40ace630bc7d61813aa76124

WeiPhil commented 1 year ago

Sounds good, thank you!