Closed oxinabox closed 3 years ago
Solution that @willtebbutt and I came up with the other night: make all rules take an extra argument for the configuration, which would be a struct:
struct Config{F, R}
ad_frule::F
ad_rrule::R
end
Where ad_frule and ad_rrule are functions conforming to the frule/rrule API, but that invoke a forward-mode or reverse-mode AD. Or they can be nothing if either one is not provided.
We might want to specify that they have to be actual functions so that we can actually dispatch on them being present.
e.g. for rrule:
function rrule(::typeof(map), config::Config{<:Any, <:Function}, f, xs)
    y = map(f, xs...)
    function map_pullback(ȳ)
        ntuple(length(xs) + 2) do full_i
            full_i == 1 && return NO_FIELDS
            full_i == 2 && return DoesNotExist()
            i = full_i - 2
            @thunk map(ȳ, xs...) do ȳi, xis...
                _, pullback = config.ad_rrule(f, xis...)
                ∂xis = pullback(ȳi)
                extern(∂xis[i + 1])  # +1 to skip ∂self
            end
        end
    end
    return y, map_pullback
end
Since we dispatched on the Config having an ad_rrule function, it will hit the generic nothing rule-not-found fallback if it doesn't have one. If it does, then we can assume the ad_rrule will itself check for actual rrules.
One worry is the case where one is in some non-AD scenario, but where one knows all the things should have rules. I think for that case the user can set the Config to use the current checked_rrule / checked_frule, which errors if the rule is not found.
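For reference, a minimal sketch of what such a checked rule lookup could look like. Note that `_rrule` here is a stand-in for the real ChainRules rule lookup (which returns `nothing` when no rule is defined), and `checked_rrule` is the hypothetical name used in this discussion, not an actual exported API:

```julia
# Stand-in for the real rule lookup: returns `nothing` when no rule is defined.
_rrule(::typeof(sin), x) = (sin(x), ȳ -> (ȳ * cos(x),))
_rrule(f, args...) = nothing

# Checked variant: errors instead of silently returning `nothing`.
function checked_rrule(f, args...)
    r = _rrule(f, args...)
    r === nothing && error("rrule not defined for $f")
    return r
end

y, pb = checked_rrule(sin, 0.0)  # works: a rule for `sin` exists
# checked_rrule(cos, 0.0)        # would throw, since no rule for `cos` is defined
```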
@jrevels's old idea that was in the source code for ages, and which I have meant to transfer, is:
In some weird ideal sense, the fallback for e.g. frule should actually be "get the derivative via forward-mode AD". This is necessary to enable mixed-mode rules, where e.g. frule is used within an rrule definition. For example, broadcasted functions may not themselves be forward-mode primitives, but are often forward-mode differentiable. ChainRulesCore, by design, is decoupled from any specific AD implementation. How, then, do we know which AD to fall back to when there isn't a primitive defined? Well, if you're a greedy AD implementation, you can just overload frule and/or rrule to use your AD directly. However, this won't play nice with other AD packages doing the same thing, and thus could cause load-order-dependent problems for downstream users. It turns out Cassette solves this problem nicely by allowing AD authors to overload the fallbacks w.r.t. their own context. Example using ForwardDiff:
using ChainRulesCore, ForwardDiff, Cassette

Cassette.@context MyChainRuleCtx

# ForwardDiff, itself, can call `my_frule` instead of
# `frule` to utilize the ForwardDiff-injected ChainRulesCore
# infrastructure
my_frule(args...) = Cassette.recurse(MyChainRuleCtx(), frule, args...)

function Cassette.overdub(::MyChainRuleCtx, ::typeof(frule), f, x::Number)
    r = frule(f, x)
    if isa(r, Nothing)
        fx, df = (f(x), (_, Δx) -> ForwardDiff.derivative(f, x) * Δx)
    else
        fx, df = r
    end
    return fx, df
end
Which could work. It would basically fix any use of checked_rrule and checked_frule in existing code, by rewriting them to never error, since frule would never return nothing.
I know for purposes of allowing 2nd derivatives etc. @YingboMa is already overdubbing frule in ForwardDiff2.
Which either makes this easier, because it's just adding an extra thing to the existing overdub,
or harder, because it ends up adding an extra layer of overdub, and bad things happen to performance when you nest Cassette.
A slight variation on the Config
option discussed above.
Firstly, add a couple of extra types:
abstract type AbstractAD{T} end
struct ForwardsModeAD{T} <: AbstractAD{T}
pushforward::T
end
struct ReverseModeAD{T} <: AbstractAD{T}
pullback::T
end
These simply wrap an AD, e.g. Zygote.pullback, to produce a thing that ChainRules knows about. ChainRules will then assume an API for the function.
Implement a fallback definition of frule and rrule:
frule(::AbstractAD, tangents, values...) = frule(tangents, values...)
rrule(::AbstractAD, values...) = rrule(values...)
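As a rough sketch of the second option (all names here mirror the proposal above and are assumptions, not a real API; `toy_pullback` stands in for something like Zygote.pullback), a rule for map could dispatch on the wrapper and defer per-element differentiation to whichever AD it wraps:

```julia
# Sketch only: names mirror the proposal in this thread, not an actual API.
abstract type AbstractAD{T} end
struct ReverseModeAD{T} <: AbstractAD{T}
    pullback::T
end

# Stand-in for a real AD's pullback (e.g. Zygote.pullback), covering just `sin`:
toy_pullback(::typeof(sin), x) = (sin(x), ȳ -> ȳ * cos(x))

# A rule for `map` that defers per-element differentiation to the wrapped AD:
function map_rrule(ad::ReverseModeAD, f, xs)
    ys_and_pbs = map(x -> ad.pullback(f, x), xs)
    ys = map(first, ys_and_pbs)
    pbs = map(last, ys_and_pbs)
    return ys, ȳs -> map((pb, ȳ) -> pb(ȳ), pbs, ȳs)
end

ys, pb = map_rrule(ReverseModeAD(toy_pullback), sin, [0.0, 0.0])
```

Because the AD is passed in as a value, the same rule body works with any AD that provides a conforming pullback.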
This gives rule-authors two options: an frule or rrule that completely ignores any AD information, or an frule or rrule that exploits whichever AD is passed in. This also gives AD package authors various options, e.g. for map-ing and broadcast-ing unary functions of a scalar.
I played around with a similar idea for broadcasting in Yota:
function ∇broadcasted(dy, ::typeof(Broadcast.broadcasted), f::F, args...) where F
df = get_deriv_function(Tuple{F, map(eltype, args)...})
return NO_FIELDS, df.(dy, f, args...)
end
Here get_deriv_function() is a hardcoded version of ad_rrule, so basically we retrieve the rrule* for the args elements and broadcast it to all arguments. It kinda works, but since rrule returns a tuple, df.(dy, f, args...) returns an array of tuples and not a tuple of arrays. For CPU arrays it may not be a big deal, but computing per-element derivatives and then combining them back on GPU will definitely destroy the performance.
Any idea how to deal with the array-of-tuples vs tuple-of-arrays issue?
*: strictly speaking, it's not rrule, but a function with a similar signature and semantics.
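To make the mismatch concrete, here is a tiny CPU-only sketch (`rule` is a hypothetical per-element derivative function, not a real rrule):

```julia
# Hypothetical per-element derivative rule for `*`, returning a tuple per element:
rule(dy, x, y) = (dy * y, dy * x)

dy, xs, ys = 1.0, [1.0, 2.0], [3.0, 4.0]
# Broadcasting a tuple-returning rule yields one tuple per element...
aot = rule.(dy, xs, ys)                           # array of tuples
# ...whereas the pullback API wants one array per argument:
toa = (map(t -> t[1], aot), map(t -> t[2], aot))  # naive unzip into a tuple of arrays
```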
Good question.
Zygote does this stuff with StaticGetter and unzip; idk how well optimized it is for GPU (it is much better than naive unzipping via zip on CPU):
https://github.com/FluxML/Zygote.jl/blob/d5be4d5ca80e79278d714eaac15ca71904a262e3/src/lib/array.jl#L177-L185
StaticGetter still hits getindex(), but the following seems to work:
julia> A = cu([(Zero(), 2, 3), (Zero(), 5, 6), (Zero(), 8, 9)])
3-element CuArray{Tuple{Zero, Int64, Int64}, 1}:
(Zero(), 2, 3)
(Zero(), 5, 6)
(Zero(), 8, 9)
julia> map(x -> x[1], A)
3-element CuArray{Zero, 1}:
Zero()
Zero()
Zero()
julia> map(x -> x[2], A)
3-element CuArray{Int64, 1}:
2
5
8
julia> map(x -> x[3], A)
3-element CuArray{Int64, 1}:
3
6
9
I'll try to do it for rrule(), assuming ad_rrule is provided.
So far no luck with pullback-based systems:
julia> f = *
* (generic function with 379 methods)
julia> args = map(cu, (rand(2), rand(2)))
(Float32[0.37061554, 0.97347444], Float32[0.96509105, 0.7939103])
julia> rrule.(f, args...)
ERROR: GPU broadcast resulted in non-concrete element type Union{}.
This probably means that the function you are broadcasting contains an error or type instability.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] copy
@ ~/.julia/packages/GPUArrays/bjw3g/src/host/broadcast.jl:44 [inlined]
[3] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, typeof(rrule), Tuple{Base.RefValue{typeof(*)}, CuArray{Float32, 1}, CuArray{Float32, 1}}})
@ Base.Broadcast ./broadcast.jl:883
[4] top-level scope
@ REPL[3]:1
As far as I understand, CUDA.jl doesn't play well with any kind of closures; e.g. here's a simpler example without rrule:
julia> x = 1.0
1.0
julia> foo(y) = x + y
foo (generic function with 1 method)
julia> foo.(rand(2))
2-element Vector{Float64}:
1.150095833562403
1.1280587660314911
julia> foo.(cu(rand(2)))
ERROR: GPU broadcast resulted in non-concrete element type Any.
This probably means that the function you are broadcasting contains an error or type instability.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] copy
@ ~/.julia/packages/GPUArrays/bjw3g/src/host/broadcast.jl:44 [inlined]
[3] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, typeof(foo), Tuple{CuArray{Float32, 1}}})
@ Base.Broadcast ./broadcast.jl:883
[4] top-level scope
@ REPL[18]:1
It might be possible to rewrite rrule with something like Cassette to replace all calls to f(args...) with broadcast(f, args...) (or the equivalent for other higher-order functions), but it doesn't sound very robust.
Does the second example work if x is a constant? I guess this should fix the type instability.
Indeed, making x a constant in global scope fixes the issue in the example. For rrule it still doesn't work though:
julia> const f = *
* (generic function with 402 methods)
julia> const args = map(cu, (rand(2), rand(2)))
(Float32[0.08670729, 0.5492601], Float32[0.24855424, 0.8392036])
julia> rrule.(f, args...)
ERROR: GPU broadcast resulted in non-concrete element type Union{}.
This probably means that the function you are broadcasting contains an error or type instability.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] copy
@ ~/.julia/packages/GPUArrays/bjw3g/src/host/broadcast.jl:44 [inlined]
[3] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, typeof(rrule), Tuple{Base.RefValue{typeof(*)}, CuArray{Float32, 1}, CuArray{Float32, 1}}})
@ Base.Broadcast ./broadcast.jl:883
[4] top-level scope
@ REPL[3]:1
julia> @code_warntype rrule.(f, args...)
Variables
#self#::Core.Const(var"##dotfunction#489#95"())
x1::Core.Const(*)
x2::Tuple{CuArray{Float32, 1}, CuArray{Float32, 1}}
Body::Union{}
1 ─ %1 = Core.tuple(Main.rrule, x1)::Core.Const((rrule, *))
│ %2 = Core._apply_iterate(Base.iterate, Base.broadcasted, %1, x2)::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, typeof(rrule), Tuple{Base.RefValue{typeof(*)}, CuArray{Float32, 1}, CuArray{Float32, 1}}}
│ Base.materialize(%2)
└── Core.Const(:(return %3))
I just checked your CUDA example, and everything is inferred correctly for me (even without const). I used:
(jl_OUkHOF) pkg> st
Status `/tmp/jl_OUkHOF/Project.toml`
[052768ef] CUDA v3.1.0
[082447d4] ChainRules v0.7.63
julia> versioninfo()
Julia Version 1.6.1
Commit 6aaedecc44 (2021-04-23 05:59 UTC)
Platform Info:
OS: Linux (x86_64-pc-linux-gnu)
CPU: Intel(R) Core(TM) i7-6850K CPU @ 3.60GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-11.0.1 (ORCJIT, broadwell)
Environment:
JULIA_NUM_THREADS = 12
Thanks for checking it! Indeed, updating to the latest versions of the packages fixed that specific example, ~~but unfortunately not more complex ones~~ and even more complex examples, e.g.:
using CUDA
foo(x::Number) = (2x, dy -> dy * x) # a simple rrule-like function
bar(x, dy) = ((y, pb) = foo(x); pb(dy))
A = cu(rand(1024, 64))
bar.(A, 1f0)
# ok
But that's not really my concern; what I'm trying to check is how wrapping functions into closures affects performance on GPU. For example, when I write:
quux.(A)
and quux() is a plain old function, I'm pretty much sure CUDA.jl will be able to generate an efficient kernel from it and apply this kernel to each element of A. However, in the example above foo.(A) returns an array of (tuples with) closures. A closure is a CPU object wrapping a scalar (a single element of a GPU array), and it doesn't sound very GPU-friendly. Even if we run benchmarks and on simple examples they show good performance, we should be really careful not to accidentally kill CUDA's optimizations on more complex examples.
Some concrete experiments with rrule, perhaps not the most optimal:
using ChainRules
using CUDA
using BenchmarkTools
args = map(cu, (rand(1024, 64), rand(1024, 64)))
r = rrule.(*, args...)
pbs = map(x -> x[2], r)
@btime map((pb, x) -> pb(x), $pbs, 1f0)
# ==> 18.085 μs (41 allocations: 976 bytes)
plain_deriv(dy, x, y) = (Zero(), dy * y, dy * x)
@btime plain_deriv.(1f0, args...)
# ==> 4.968 μs (26 allocations: 592 bytes)
I must admit the pullback-based example also runs in 5μs if I use a proper dy = cu(ones(1024, 64)) instead of 1f0, yet the behavior above is quite unintuitive to me.
But maybe I'm over-complicating things and in practice everything will work just fine.
Here's an implementation of rrule for broadcasted() which works with CPU and GPU arrays as long as rrule.(f, args...) works:
# from Zygote:
# https://github.com/FluxML/Zygote.jl/blob/d5be4d5ca80e79278d714eaac15ca71904a262e3/src/lib/array.jl#L177-L185
struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]
@generated function _unzip(tuples, ::Val{N}) where {N}
Expr(:tuple, (:(map($(StaticGetter{i}()), tuples)) for i ∈ 1:N)...)
end
function unzip(tuples)
N = length(first(tuples))
_unzip(tuples, Val(N))
end
function rrule(::typeof(Broadcast.broadcasted), f::F, args...) where F
ys, pbs = unzip(rrule.(f, args...))
function pullback(Δ)
dxs = map((pb, Δ) -> pb(Δ), pbs, Δ)
return NO_FIELDS, unzip(dxs)...
end
return ys, pullback
end
I didn't notice any significant performance degradation compared to the non-closure-based version, but rrule.() fails on some examples, e.g.:
rrule.(^, cu(rand(2)), 2f0)
Since rrule.() is a placeholder for ad_rrule() (or whatever we end up with), and ad_rrule() may behave differently, I just stopped here and haven't investigated the error.
This was originally discussed in https://github.com/JuliaDiff/ChainRules.jl/issues/12#issuecomment-483058007 and in a few other places.
Basically, often when defining a chainrule (frule or rrule) it would be nice to be able to say "give me this chainrule for some of the function, and if there is not one predefined, use AD to get it" as part of your rule definition. Right now this is not possible, except by hard-coding an AD (e.g. Zygote) in. Whereas if we had a function that an AD system could basically overload, then we could do that.
It would also provide a common API for all ADs that support it.
This would help e.g. with higher-order functions like map, broadcast etc.: https://github.com/JuliaDiff/ChainRules.jl/issues/122
There is some fiddlyness involved, around making sure it is both overloadable multiple times, so the user can choose which AD, and that it compiles away, but I think we can sort it all out.
@jessebett reminded me of this today