JuliaDiff / ChainRulesCore.jl

An AD-backend-agnostic system for defining custom forward- and reverse-mode rules. This is the lightweight core that lets you define rules for the functions in your packages, without depending on any particular AD system.

rrule_via_ad // frule_via_ad Calling back into AD from ChainRules #68

Closed oxinabox closed 3 years ago

oxinabox commented 4 years ago

This was originally discussed in https://github.com/JuliaDiff/ChainRules.jl/issues/12#issuecomment-483058007 and in a few other places.

Basically, when defining a chainrule (frule or rrule) it would often be nice to be able to say, as part of your rule definition, "give me the chainrule for this inner function, and if there is not one predefined, use AD to get it". Right now this is not possible, except by hard-coding an AD (e.g. Zygote) in.

Whereas if we had a function that an AD system could overload, then we could do that.

It would also provide a common API for all ADs that support it.

This would help e.g. with higher-order functions like map, broadcast, etc.; see https://github.com/JuliaDiff/ChainRules.jl/issues/122

There is some fiddliness involved, around making sure it is overloadable multiple times (so the user can choose which AD) and that it compiles away, but I think we can sort it all out.


@jessebett reminded me of this today

oxinabox commented 4 years ago

Solution that @willtebbutt and I came up with the other night: make all rules take an extra argument for the configuration, which would be a struct:

struct Config{F, R}
    ad_frule::F
    ad_rrule::R
end

Where ad_frule and ad_rrule are functions conforming to the frule/rrule API, but that invoke a forward-mode or reverse-mode AD. Either can be nothing if that mode is not provided.

We might want to specify that they have to be actual functions, so that we can dispatch on them being present, e.g. for rrule:

function rrule(::typeof(map), config::Config{<:Any, <:Function}, f, xs...)
    y = map(f, xs...)
    function map_pullback(ȳ)
        ntuple(length(xs)+2) do full_i
            full_i == 1 && return NO_FIELDS       # cotangent w.r.t. map itself
            full_i == 2 && return DoesNotExist()  # cotangent w.r.t. f
            i = full_i - 2
            @thunk map(ȳ, xs...) do ȳi, xis...
                _, pullback = config.ad_rrule(f, xis...)
                ∂xis = pullback(ȳi)
                extern(∂xis[i+1])  # +1 to skip ∂self
            end
        end
    end
    return y, map_pullback
end

Since we dispatched on Config having an ad_rrule function, this will hit the generic nothing rule-not-found fallback if it doesn't have one. If it does, then we can assume the ad_rrule will itself check for actual rrules.

One worry is the case where one is in some non-AD scenario, but where one knows all the things should have rules. I think for that case the user can set the Config to use the current checked_rrule / checked_frule, which error if the rule is not found.
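For illustration, a minimal sketch of how an AD package might construct such a Config. SomeAD.pullback below is a hypothetical stand-in for e.g. Zygote.pullback; the adapter just reshapes its output to the rrule convention:

# Sketch only: `SomeAD.pullback` is hypothetical, assumed to return `(y, back)`
# where `back(ȳ)` gives the cotangents of the arguments.
function some_ad_rrule(f, args...)
    y, back = SomeAD.pullback(f, args...)
    # adapt to the rrule convention, which also returns ∂self for `f`
    ad_pullback(ȳ) = (NO_FIELDS, back(ȳ)...)
    return y, ad_pullback
end

config = Config(nothing, some_ad_rrule)  # reverse mode only, no forward-mode AD
# in a non-AD scenario one could instead use Config(checked_frule, checked_rrule)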

oxinabox commented 4 years ago

@jrevels's old idea, which was in the source code for ages and which I have meant to transfer, is:

In some weird ideal sense, the fallback for e.g. frule should actually be "get the derivative via forward-mode AD". This is necessary to enable mixed-mode rules, where e.g. frule is used within a rrule definition. For example, broadcasted functions may not themselves be forward-mode primitives, but are often forward-mode differentiable. ChainRulesCore, by design, is decoupled from any specific AD implementation. How, then, do we know which AD to fall back to when there isn't a primitive defined? Well, if you're a greedy AD implementation, you can just overload frule and/or rrule to use your AD directly. However, this won't play nice with other AD packages doing the same thing, and thus could cause load-order-dependent problems for downstream users. It turns out, Cassette solves this problem nicely by allowing AD authors to overload the fallbacks w.r.t. their own context. Example using ForwardDiff:

using ChainRulesCore, ForwardDiff, Cassette
Cassette.@context MyChainRuleCtx
# ForwardDiff, itself, can call `my_frule` instead of
# `frule` to utilize the ForwardDiff-injected ChainRulesCore
# infrastructure
my_frule(args...) = Cassette.recurse(MyChainRuleCtx(), frule, args...)
function Cassette.overdub(::MyChainRuleCtx, ::typeof(frule), f, x::Number)
    r = frule(f, x)
    if isa(r, Nothing)
        fx, df = (f(x), (_, Δx) -> ForwardDiff.derivative(f, x) * Δx)
    else
        fx, df = r
    end
    return fx, df
end

Which could work. It would basically replace any uses of checked_rrule and checked_frule in existing code, rewriting them to never error, since frule would never return nothing.
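(For reference, checked_frule is roughly a helper of the following shape; this is a sketch of the idea, not the exact code:)

# sketch: forward to frule, but error instead of returning nothing
function checked_frule(args...)
    r = frule(args...)
    r === nothing && error("no frule defined for ", first(args))
    return r
end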

I know that, for purposes of allowing 2nd derivatives etc., @YingboMa is already overdubbing frule in ForwardDiff2. Which either makes this easier, because it's just adding an extra thing to an existing overdub, or harder, because it ends up adding an extra layer of overdub, and bad things happen to performance when you nest Cassette.

willtebbutt commented 4 years ago

A slight variation on the Config option discussed above.

Firstly, add a couple of extra types:

abstract type AbstractAD{T} end

struct ForwardsModeAD{T} <: AbstractAD{T}
    pushforward::T
end

struct ReverseModeAD{T} <: AbstractAD{T}
    pullback::T
end

These simply wrap an AD, e.g. Zygote.pullback, to produce a thing that ChainRules knows about. ChainRules will then assume an API for the wrapped function.

Secondly, implement fallback definitions of frule and rrule:

frule(::AbstractAD, tangents, values...) = frule(tangents, values...)

rrule(::AbstractAD, values...) = rrule(values...)

This gives rule-authors two options.

  1. Implement a new method of frule or rrule that completely ignores any AD information.
  2. Implement a new method of frule or rrule that exploits whichever AD is passed in (see the sketch below).
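As a rough illustration of option 2, a sketch of a rule for map that defers per-element differentiation to the wrapped AD. It assumes ad.pullback follows a Zygote.pullback-style API, returning (y, back) with back(ȳ) giving the argument cotangents:

# sketch: a rule for `map` that calls back into the wrapped reverse-mode AD
function rrule(ad::ReverseModeAD, ::typeof(map), f, xs)
    y = map(f, xs)
    function map_pullback(ȳ)
        ∂xs = map(xs, ȳ) do x, ȳi
            _, back = ad.pullback(f, x)
            first(back(ȳi))   # cotangent w.r.t. x
        end
        return NO_FIELDS, DoesNotExist(), ∂xs
    end
    return y, map_pullback
end

An AD package would then supply the wrapper, e.g. something like ReverseModeAD(Zygote.pullback).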

This gives AD package authors various options:

dfdx commented 3 years ago

I played around with a similar idea for broadcasting in Yota:

function ∇broadcasted(dy, ::typeof(Broadcast.broadcasted), f::F, args...) where F
    df = get_deriv_function(Tuple{F, map(eltype, args)...})
    return NO_FIELDS, df.(dy, f, args...)
end

Here get_deriv_function() is a hardcoded version of ad_rrule, so basically we retrieve the rrule* for the element types of args and broadcast it over all arguments. It kinda works, but since the rrule returns a tuple, df.(dy, f, args...) returns an array of tuples and not a tuple of arrays. For CPU arrays this may not be a big deal, but computing per-element derivatives and then combining them back on GPU will definitely destroy the performance.

Any idea how to deal with the array-of-tuples vs tuple-of-arrays issue?


*- strictly speaking, it's not rrule, but a function with a similar signature and semantics.
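
To make the mismatch concrete, a toy example (not Yota code; deriv stands in for df):

# each call returns a tuple of per-argument derivatives,
# so broadcasting gives an array of tuples
deriv(dy, x, y) = (dy * y, dy * x)
out = deriv.(1.0, [1.0, 2.0], [3.0, 4.0])
# out isa Vector{Tuple{Float64, Float64}}, whereas a pullback
# needs a tuple of arrays: ([3.0, 4.0], [1.0, 2.0])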

oxinabox commented 3 years ago

Good question. Zygote does this stuff with StaticGetter and unzip; I don't know how well optimized it is for GPU, but it is much better than naive unzipping via zip on CPU: https://github.com/FluxML/Zygote.jl/blob/d5be4d5ca80e79278d714eaac15ca71904a262e3/src/lib/array.jl#L177-L185

dfdx commented 3 years ago

StaticGetter still hits getindex(), but the following seems to work:

julia> A = cu([(Zero(), 2, 3), (Zero(), 5, 6), (Zero(), 8, 9)])
3-element CuArray{Tuple{Zero, Int64, Int64}, 1}:
 (Zero(), 2, 3)
 (Zero(), 5, 6)
 (Zero(), 8, 9)

julia> map(x -> x[1], A)
3-element CuArray{Zero, 1}:
 Zero()
 Zero()
 Zero()

julia> map(x -> x[2], A)
3-element CuArray{Int64, 1}:
 2
 5
 8

julia> map(x -> x[3], A)
3-element CuArray{Int64, 1}:
 3
 6
 9

I'll try to do it for rrule() assuming ad_rrule is provided.

dfdx commented 3 years ago

So far no luck with pullback-based systems:

julia> f = *
* (generic function with 379 methods)

julia> args = map(cu, (rand(2), rand(2)))
(Float32[0.37061554, 0.97347444], Float32[0.96509105, 0.7939103])

julia> rrule.(f, args...)
ERROR: GPU broadcast resulted in non-concrete element type Union{}.
This probably means that the function you are broadcasting contains an error or type instability.
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:33
 [2] copy
   @ ~/.julia/packages/GPUArrays/bjw3g/src/host/broadcast.jl:44 [inlined]
 [3] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, typeof(rrule), Tuple{Base.RefValue{typeof(*)}, CuArray{Float32, 1}, CuArray{Float32, 1}}})
   @ Base.Broadcast ./broadcast.jl:883
 [4] top-level scope
   @ REPL[3]:1

As far as I understand, CUDA.jl doesn't play well with any kind of closures, e.g. here's a simpler example without rrule:

julia> x = 1.0
1.0

julia> foo(y) = x + y
foo (generic function with 1 method)

julia> foo.(rand(2))
2-element Vector{Float64}:
 1.150095833562403
 1.1280587660314911

julia> foo.(cu(rand(2)))
ERROR: GPU broadcast resulted in non-concrete element type Any.
This probably means that the function you are broadcasting contains an error or type instability.
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:33
 [2] copy
   @ ~/.julia/packages/GPUArrays/bjw3g/src/host/broadcast.jl:44 [inlined]
 [3] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, typeof(foo), Tuple{CuArray{Float32, 1}}})
   @ Base.Broadcast ./broadcast.jl:883
 [4] top-level scope
   @ REPL[18]:1

It might be possible to rewrite rrule with something like Cassette to replace all calls to f(args...) with broadcast(f, args...) (or equivalent for other higher order functions), but it doesn't sound very robust.

devmotion commented 3 years ago

Does the second example work if x is a constant? I guess this should fix the type instability.

dfdx commented 3 years ago

Indeed, making x a constant in global scope fixes the issue in the example. For rrule it still doesn't work though:

julia> const f = *
* (generic function with 402 methods)

julia> const args = map(cu, (rand(2), rand(2)))
(Float32[0.08670729, 0.5492601], Float32[0.24855424, 0.8392036])

julia> rrule.(f, args...)
ERROR: GPU broadcast resulted in non-concrete element type Union{}.
This probably means that the function you are broadcasting contains an error or type instability.
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:33
 [2] copy
   @ ~/.julia/packages/GPUArrays/bjw3g/src/host/broadcast.jl:44 [inlined]
 [3] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, typeof(rrule), Tuple{Base.RefValue{typeof(*)}, CuArray{Float32, 1}, CuArray{Float32, 1}}})
   @ Base.Broadcast ./broadcast.jl:883
 [4] top-level scope
   @ REPL[3]:1

julia> @code_warntype rrule.(f, args...)
Variables
  #self#::Core.Const(var"##dotfunction#489#95"())
  x1::Core.Const(*)
  x2::Tuple{CuArray{Float32, 1}, CuArray{Float32, 1}}

Body::Union{}
1 ─ %1 = Core.tuple(Main.rrule, x1)::Core.Const((rrule, *))
│   %2 = Core._apply_iterate(Base.iterate, Base.broadcasted, %1, x2)::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, typeof(rrule), Tuple{Base.RefValue{typeof(*)}, CuArray{Float32, 1}, CuArray{Float32, 1}}}
│        Base.materialize(%2)
└──      Core.Const(:(return %3))

devmotion commented 3 years ago

I just checked your CUDA example, and everything is inferred correctly for me (even without const). I used

(jl_OUkHOF) pkg> st
      Status `/tmp/jl_OUkHOF/Project.toml`
  [052768ef] CUDA v3.1.0
  [082447d4] ChainRules v0.7.63

julia> versioninfo()
Julia Version 1.6.1
Commit 6aaedecc44 (2021-04-23 05:59 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-6850K CPU @ 3.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, broadwell)
Environment:
  JULIA_NUM_THREADS = 12

dfdx commented 3 years ago

Thanks for checking it! Indeed, updating to the latest versions of the packages fixed that specific example, ~~but unfortunately not more complex ones~~ and even more complex examples, e.g.:

using CUDA

foo(x::Number) = (2x, dy -> dy * x)   # a simple rrule-like function
bar(x, dy) = ((y, pb) = foo(x); pb(dy))
A = cu(rand(1024, 64))

bar.(A, 1f0)
# ok

But that's not really my concern; what I'm trying to check is how wrapping functions into closures affects performance on GPU. For example, when I write:

quux.(A)

and quux() is a plain old function, I'm pretty much sure CUDA.jl will be able to generate an efficient kernel from it and apply this kernel to each element of A. However, in the example above foo.(A) returns an array of (tuples with) closures. A closure is a CPU object wrapping a scalar (a single element of a GPU array), and that doesn't sound very GPU-friendly. Even if we run benchmarks and they show good performance on simple examples, we should be really careful not to accidentally kill CUDA's optimizations on more complex examples.


Some concrete experiments with rrule, perhaps not the most optimal:

using ChainRules
using CUDA
using BenchmarkTools

args = map(cu, (rand(1024, 64), rand(1024, 64)))
r = rrule.(*, args...)
pbs = map(x -> x[2], r)
@btime map((pb, x) -> pb(x), $pbs, 1f0)
# ==> 18.085 μs (41 allocations: 976 bytes)

plain_deriv(dy, x, y) = (Zero(), dy * y, dy * x)
@btime plain_deriv.(1f0, args...)
# ==> 4.968 μs (26 allocations: 592 bytes)

I must admit the pullback-based example also runs in 5μs if I use a proper dy = cu(ones(1024, 64)) instead of 1f0, yet the behavior above is quite unintuitive to me.

But maybe I'm over-complicating things and in practice everything will work just fine.

dfdx commented 3 years ago

Here's an implementation of rrule for broadcasted() which works with CPU and GPU arrays as long as rrule.(f, args...) works:

# from Zygote:
# https://github.com/FluxML/Zygote.jl/blob/d5be4d5ca80e79278d714eaac15ca71904a262e3/src/lib/array.jl#L177-L185
struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]

@generated function _unzip(tuples, ::Val{N}) where {N}
  Expr(:tuple, (:(map($(StaticGetter{i}()), tuples)) for i ∈ 1:N)...)
end

function unzip(tuples)
  N = length(first(tuples))
  _unzip(tuples, Val(N))
end

function rrule(::typeof(Broadcast.broadcasted), f::F, args...) where F
    # broadcast the element-wise rrule, then split the (y, pullback) pairs
    # into an array of ys and an array of pullbacks
    ys, pbs = unzip(rrule.(f, args...))
    function pullback(Δ)
        # apply each element's pullback to the matching cotangent element
        dxs = map((pb, Δi) -> pb(Δi), pbs, Δ)
        return NO_FIELDS, unzip(dxs)...
    end
    return ys, pullback
end
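
For illustration, a small usage sketch (assuming the definitions above and that rrule.(f, args...) works for the chosen f):

xs = rand(Float32, 4)                          # the same works with cu(...) arrays
ys, pb = rrule(Broadcast.broadcasted, sin, xs)
∂bc, ∂f, ∂xs = pb(ones(Float32, length(xs)))   # cotangents for broadcasted, f, and xs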

I didn't notice any significant performance degradation compared to the non-closure-based version, but rrule.() fails on some examples, e.g.:

 rrule.(^, cu(rand(2)), 2f0)

Since rrule.() is a placeholder for ad_rrule() (or whatever we end up with) and ad_rrule() may behave differently, I just stopped here and haven't investigated the error.