JuliaDiff / ChainRules.jl

forward and reverse mode automatic differentiation primitives for Julia Base + StdLibs
Other
435 stars 89 forks source link

Pullback fails to inline with intervals #248

Open dpsanders opened 4 years ago

dpsanders commented 4 years ago

The pullback in the following MWE is not inlined when using intervals from IntervalArithmetic.jl:

function fff(x, y)  # MWE: return the cotangents of x * y w.r.t. x and y. NOTE(review): listings below call it `f`
    z, z_pullback = rrule(*, x, y)  # z (the primal) is unused; only the pullback is exercised
    # seed the pullback with one(x), so the cotangent has the same type as x (e.g. an Interval)
    z̄ = one(x)
    _, r1, r2 = z_pullback(z̄)  # first entry (for `*` itself) is discarded
    # r1/r2 may be thunks (the interval listing shows `Thunk`/`unthunk` frames); materialize them
    x̄ = unthunk(r1)
    ȳ = unthunk(r2)
    # with floats this whole body compiles down to two stores of the result tuple
    return (x̄, ȳ)
end

using IntervalArithmetic

julia> @code_native f(1..1, 2..2)
    .section    __TEXT,__text,regular,pure_instructions
; ┌ @ REPL[23]:1 within `f'
    pushq   %rbx
    subq    $160, %rsp
    movq    %rdi, %rbx
; │ @ REPL[23]:2 within `f'
    movabsq $rrule, %rax
    leaq    112(%rsp), %rdi
    callq   *%rax
    movabsq $5516288192, %rax       ## imm = 0x148CBE0C0
; │ @ REPL[23]:5 within `f'
; │┌ @ fastmath_able.jl:188 within `times_pullback'
    vmovaps (%rax), %xmm0
    vmovups %xmm0, 80(%rsp)
    vmovups 128(%rsp), %xmm1
    vmovups 144(%rsp), %xmm2
    vmovups %xmm2, 96(%rsp)
    vmovups %xmm0, 48(%rsp)
    vmovups %xmm1, 64(%rsp)
; │└
; │ @ REPL[23]:7 within `f'
; │┌ @ thunks.jl:99 within `unthunk'
; ││┌ @ thunks.jl:98 within `Thunk'
    movabsq $"#461", %rax
    leaq    16(%rsp), %rdi
    leaq    80(%rsp), %rsi
    callq   *%rax
; │└└
; │ @ REPL[23]:8 within `f'
; │┌ @ thunks.jl:99 within `unthunk'
; ││┌ @ thunks.jl:98 within `Thunk'
    movabsq $"#462", %rax
    movq    %rsp, %rdi
    leaq    48(%rsp), %rsi
    callq   *%rax
; │└└
; │ @ REPL[23]:10 within `f'
    vmovups (%rsp), %xmm0
    vmovups %xmm0, 32(%rsp)
    vmovups 16(%rsp), %ymm0
    vmovups %ymm0, (%rbx)
    movq    %rbx, %rax
    addq    $160, %rsp
    popq    %rbx
    vzeroupper
    retq
    nopw    %cs:(%rax,%rax)
    nopl    (%rax)
; └

Compare this with the clean code generated when using floats:

julia> @code_native f(1.0, 2.0)
    .section    __TEXT,__text,regular,pure_instructions
; ┌ @ REPL[23]:1 within `f'
    movq    %rdi, %rax
; │ @ REPL[23]:10 within `f'
    vmovsd  %xmm1, (%rdi)
    vmovsd  %xmm0, 8(%rdi)
    retq
    nopl    (%rax)
; └
oxinabox commented 4 years ago

I wonder if we need to open an upstream issue on Julia itself about this.

dpsanders commented 4 years ago

I tried

@inline function rrule(::typeof(*), x::Number, y::Number)
    # Primal result of the scalar product.
    Ω = x * y
    # Pullback: maps the output cotangent ΔΩ to (cotangent for `*` itself, x̄, ȳ).
    # The adjoints `y'` / `x'` conjugate complex inputs and are no-ops for real ones.
    # Marking the closure `@inline` (as well as the rule) is what lets non-isbits-friendly
    # number types such as intervals compile to straight-line code, per the discussion.
    @inline times_pullback(ΔΩ) = (NO_FIELDS, ΔΩ * y', x' * ΔΩ)
    return Ω, times_pullback
end

which seems to fix the inlining problem with intervals.

Time to Inline All The Things?

oxinabox commented 4 years ago

We can at least add it to `@scalar_rule`.

I am hesitant to add it everywhere, or to add it to the best practices, because of the visual noise. But we might want to at least mention it in the best practices as something that can be considered?