torfjelde opened this issue 3 years ago
Beautiful writeup! I don't think special-casing broadcasting for the user-facing API is the right approach here. `map`, `mapreduce` and other higher-order functions are all supported by CUDA.jl for any CUDA-compatible code, so I think we should also support that. I think a good approach for now (no Cassette-like magic) would be to define different distributions for CUDA. Then inside Turing, we can use contexts to allow users to switch out some distributions or functions for others at compile time, e.g.:
```julia
if _ctx isa GPUContext
    dist = GPUNormal(...)
else
    dist = Normal(...)
end
```
This requires more awareness from the user of which functions and structs are GPU compatible and which are not, but this might be the most general and performant solution we have right now. For distributions that we own, such as `filldist` and `arraydist`, we can try to define `logpdf` dispatch rules so that things "just work", but that's definitely hackish territory, so we shouldn't push this too far imo.
Realistically, we may only need to define a handful of distributions that people use for this to be immediately useful. Then the rest is documentation.
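A minimal sketch of the context idea using dispatch instead of an explicit `isa` check. All names here (`AbstractContext`, `DefaultContext`, `GPUContext`, `CPUNormal`, `GPUNormal` and the helper `normal`) are hypothetical stand-ins defined for illustration, not DynamicPPL's or Distributions.jl's types. When the concrete type of the context is known to the compiler, the `isa` branch above is typically folded away during type inference as well, so both versions make the choice at compile time; dispatch mainly makes it easier for other packages to add their own contexts.

```julia
abstract type AbstractContext end
struct DefaultContext <: AbstractContext end
struct GPUContext <: AbstractContext end

# Hypothetical CPU and GPU variants of the same distribution.
struct CPUNormal{T<:Real}
    μ::T
    σ::T
end
struct GPUNormal{T<:Real}
    μ::T
    σ::T
end

# The context type alone decides which struct is built.
normal(::AbstractContext, μ, σ) = CPUNormal(promote(μ, σ)...)
normal(::GPUContext, μ, σ) = GPUNormal(promote(μ, σ)...)

d_cpu = normal(DefaultContext(), 0.0, 1)   # a CPUNormal{Float64}
d_gpu = normal(GPUContext(), 0f0, 1f0)     # a GPUNormal{Float32}
```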
> * A lot of the functions used within the Distributions.jl ecosystem are _not_ pure Julia under the hood, but there are often pure-Julia versions for more generic number types so that one, for example, can just AD through them without issues, e.g. [JuliaStats/StatsFuns.jl@`8dfda2c`/src/distrs/gamma.jl#L19](https://github.com/JuliaStats/StatsFuns.jl/blob/8dfda2c0ee33d5f85eca5c039d31d85c90f363f2/src/distrs/gamma.jl#L19). BUT this doesn't help compat with CUDA.jl because the element types of a `CUDA.CuArray` aren't special, i.e. they're just `Float32`. And so the function we dispatch on when broadcasting over a `CUDA.CuArray` will be some function _outside_ of the Julia ecosystem, and so things start blowing up.
This seems incorrect: StatsFuns only defines calls into Rmath for functions with arguments of type `Union{Float64,Int}` (see https://github.com/JuliaStats/StatsFuns.jl/blob/master/src/rmath.jl).
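The dispatch pattern described here can be mimicked in a few lines. In this sketch `impl_for` is a hypothetical function, not part of StatsFuns: one method is restricted to `Union{Float64,Int}` (standing in for the Rmath-backed definitions) and a generic `Real` method stands in for the pure-Julia fallback. A `Float32` argument, the element type of the `CuArray`s used in this thread, therefore lands on the generic method.

```julia
# Stand-in for the Rmath-backed methods: only Float64 and Int arguments.
impl_for(x::Union{Float64,Int}) = :rmath
# Stand-in for the generic pure-Julia fallback.
impl_for(x::Real) = :pure_julia

impl_for(1.0)    # :rmath
impl_for(1.0f0)  # :pure_julia
```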
I'm a bit worried that all the special cases and implementations will lead to diverging implementations, as e.g. already happened with the SpecialFunctions customizations in https://github.com/JuliaGPU/CuArrays.jl/pull/321. If the automatic method substitution works at some point it might also become tricky to spot all these superfluous manual substitutions (but probably that's still too far away to be worried about).
> I don't think special-casing broadcasting for the user-facing API is the right approach here.

This is a good point :+1: The `Vectorized` thingy seems quite hacky in comparison.
> we can use contexts to allow users to switch out some distributions or functions for others at compile time

This is a good idea for actually making it possible to use in Turing.jl. The other part is going to be making Bijectors.jl compatible.
> This seems incorrect

You're 100% right. I realize now I only ran into the error when using `Float64` rather than `Float32`. The issue of not "recursively" doing `cufunc` is still an issue though.
> I'm a bit worried that all the special cases and implementations will lead to diverging implementations

Agreed with the first part here, but I think we could add these "manual" overloads in a `CUDAExtensions.jl` package or something. Then when we have automatic overloads, we just stop using that package. The only change we need internally in CUDA.jl is to make `CUDA.cufunc(f) = f`, which doesn't change anything in CUDA.jl but allows us to write stuff like:
```julia
@cufunc SpecialFunctions.loggamma(x) = CUDA.lgamma(x)
```
and
```julia
@cufunc_register StatsFuns.normlogpdf
@cufunc_def StatsFuns.normlogpdf(z::Number) = -(abs2(z) + log2π)/2
@cufunc_def function StatsFuns.normlogpdf(μ::Real, σ::Real, x::Number)
    if iszero(σ)
        if x == μ
            z = zval(μ, one(σ), x)
        else
            z = zval(μ, σ, x)
            σ = one(σ)
        end
    else
        z = zval(μ, σ, x)
    end
    normlogpdf(z) - log(σ)
end
```
which is a literal copy-paste from StatsFuns.jl. We could literally just clone the repo, add a bunch of these macro calls, and we'd have functional SpecialFunctions.jl and StatsFuns.jl on the GPU.
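To make the `@cufunc_def` idea concrete without a GPU, here is a minimal CPU-only sketch under stated assumptions: `devfunc` stands in for `CUDA.cufunc`, `devlog` is a hypothetical replacement for `log` that counts its calls, and `@devfunc_def` is a hypothetical macro (not the one in CUDAExtensions.jl) that copies a definition, routes every call in its body through `devfunc`, and registers the copy via `devfunc(::typeof(f))`. Only the short `f(args...) = body` form with a plain function name is handled.

```julia
# Stand-in for `CUDA.cufunc`: every function maps to itself by default.
devfunc(f) = f

# Hypothetical "device" replacement for `log`, counting how often it is used.
const DEV_LOG_CALLS = Ref(0)
devlog(x) = (DEV_LOG_CALLS[] += 1; log(x))
devfunc(::typeof(log)) = devlog

# Recursively rewrite every call `g(args...)` into `devfunc(g)(args...)`.
wrapcalls(ex) = ex
function wrapcalls(ex::Expr)
    args = map(wrapcalls, ex.args)
    ex.head === :call || return Expr(ex.head, args...)
    return Expr(:call, Expr(:call, :devfunc, ex.args[1]), args[2:end]...)
end

# `@devfunc_def f(args...) = body` defines `dev_f` with the rewritten body
# and registers it through `devfunc(::typeof(f)) = dev_f`.
macro devfunc_def(def)
    (def isa Expr && def.head === :(=)) || error("expected `f(args...) = body`")
    sig, body = def.args
    name = sig.args[1]
    devname = Symbol(:dev_, name)
    return esc(quote
        $(Expr(:call, devname, sig.args[2:end]...)) = $(wrapcalls(body))
        devfunc(::typeof($name)) = $devname
    end)
end

# The original function must exist, just like `StatsFuns.normlogpdf` does.
mylogpdf(μ, σ, x) = -abs2((x - μ) / σ) / 2 - log(σ) - log(2π) / 2
@devfunc_def mylogpdf(μ, σ, x) = -abs2((x - μ) / σ) / 2 - log(σ) - log(2π) / 2

devfunc(mylogpdf)(0.0, 2.0, 0.0)   # same value as mylogpdf, but goes through devlog
```

Because the rewriting happens inside the copied body, the nested `log` calls are substituted as well, which is the part that plain broadcasting of the outer function does not reach.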
One thing that I haven't been able to do yet is to extract the diff rule from DiffRules.jl and define one for the corresponding `cufunc`. I was hoping I could just extract the diff rule by doing something like
```julia
diffrule_cu = replace_device_all(DiffRules.diffrule(mod, f, args...))
quote
    DiffRules.@define_diffrule CUDAExtensions.$(cuname)(args...) = $(QuoteNode(diffrule_cu))
end
```
BUT the issue is that `@define_diffrule` expects the RHS to be an expression with `args` interpolated (i.e. `$a` for `a` in `args`). I can't seem to figure out how to generate such an expression on the RHS. As an example:
```julia
julia> nargs = 1;

julia> args = ntuple(i -> gensym(Symbol(:x, i)), nargs)
(Symbol("##x1#513"),)

julia> diffrule = DiffRules.diffrule(:Base, :sin, args...)
:(cos(var"##x1#513"))

julia> diffrule_cu = CUDAExtensions.replace_device_all(diffrule)
:((CUDA.cufunc(cos))(var"##x1#513"))
```
Then I'd finally like to generate an expression that looks like
```julia
DiffRules.@define_diffrule CUDAExtensions.cusin(var"##x1#513") = :((CUDA.cufunc(cos))($var"##x1#513"))
```
(notice the `$` on the RHS). The naming and all that stuff is fine, but the issue is replacing `var"##x1#513"` in `diffrule_cu` with `$(var"##x1#513")`. Any ideas?
> Any ideas?

Use `eval`. First walk the expression from DiffRules, replacing every function call `f(...)` with `CUDA.cufunc(f)(...)`, then create an expression that uses `DiffRules.@define_diffrule` and `eval` it.
> Realistically, we may only need to define a handful of distributions that people use for this to be immediately useful. Then the rest is documentation.

Do you think it's reasonable to extract the key distribution implementation code from DistributionsAD.jl into something like DistributionKernels.jl, in which we make sure the low-level primitives are also CUDA friendly?
> I'm a bit worried that all the special cases and implementations will lead to diverging implementations, as e.g. already happened with the SpecialFunctions customizations in JuliaGPU/CuArrays.jl#321. If the automatic method substitution works at some point it might also become tricky to spot all these superfluous manual substitutions (but probably that's still too far away to be worried about).

> which is a literal copy-paste from StatsFuns.jl. We could literally just clone the repo, add a bunch of these macro calls, and we'd have functional SpecialFunctions.jl and StatsFuns.jl on the GPU.

I guess David's worry still applies, as we need to clone the repo and sync the codebase manually (or via a GitHub Action).
> We could literally just clone the repo, add a bunch of these macro calls, and we'd have functional SpecialFunctions.jl and StatsFuns.jl on the GPU.

I am with @mohamed82008 on this one; I assume it would be better to just fix/implement whatever is really needed. StatsFuns and SpecialFunctions contain many things that are not needed (e.g., `normlogpdf` is never used anywhere in Distributions or DistributionsAD). Maybe restricting ourselves to a small subset of commonly used functions would also make it more manageable to keep up with changes in these repos.
> Do you think it's reasonable to extract the key distribution implementation code from DistributionsAD.jl into something like DistributionKernels.jl, in which we make sure the low-level primitives are also CUDA friendly?

I don't think you can make it both CPU and GPU friendly. You need different structs targeting the CPU and GPU, where the only difference is which version of `log` and such they use. But yes, I suggest having this in a separate package.
Regarding DistributionsAD: I opened a PR to StatsFuns (https://github.com/JuliaStats/StatsFuns.jl/pull/106) to figure out how the StatsFuns type piracy can be resolved. That's only a tiny part of DistributionsAD, but at least it's (almost) one file that could be removed.
> Use eval. First walk the expression from DiffRules replacing every function call f(...) with CUDA.cufunc(f)(...) then create an expression that uses DiffRules.@define_diffrule and eval it.

Hmm, not sure I see what you mean. How do you get the `$` in the expression on the RHS of the expression that you apply `@define_diffrule` to?
> I guess David's worry is still here as we need to clone the repo and manually (or via GitHub Action) sync the codebase.

> I assume it would be better to just fix/implement whatever is really needed.

So I agree this would be better for maintenance, but it will be bloody annoying figuring out what we actually need. I also want to emphasize that this is only temporary until GPUCompiler.jl has proper method substitution.
Also, my motivation here is that we're trying to see what it would take to run this COVID-19 model on the GPU, and I'm first just implementing the log joint by hand. I have a working implementation where I've explicitly written out the `logpdf`s of the different distributions, but even that requires `cufunc`'ing some methods, e.g. `loggamma`. And so if we're going to experiment a bit with different distributions, we at least need an easy way of defining these methods, e.g. `@cufunc StatsFuns.gammalogpdf`, rather than manually going through the definition, replacing everything and defining the diff rules. And if these are all put in a package called CUDAExtensions.jl rather than in just some notebook I have lying about, then suddenly we have something somewhat useful that people can take advantage of, and it at least allows us to play a bit with putting Turing.jl models on the GPU, as defining distributions on the GPU then becomes waaay easier.
> I don't think you can make it both CPU and GPU friendly. You need different structs targeting the CPU and GPU where the only difference is which version of log and such they use.

Or you could do the `Vectorized` thing I mentioned above to deal with `UnivariateDistribution` when broadcasted, and then we can specialize implementations of `logpdf`, etc. for `MultivariateDistribution` and `MatrixDistribution` with `CuArray` parameters and samples. I.e. the main case that might require new structs is `UnivariateDistribution`.
@torfjelde is this what you want?
```julia
nargs = 1
args = ntuple(i -> gensym(Symbol(:x, i)), nargs)
diffrule = DiffRules.diffrule(:Base, :sin, args...)
diffrule_cu = CUDAExtensions.replace_device_all(diffrule)
@eval begin
    DiffRules.@define_diffrule CUDAExtensions.cusin($args...,) = $diffrule_cu
end
```
> - Wait until the work on method-substitution is done.

What about

- Overdub model execution (or at least `observe` and `assume`) with Cassette and replace functions with their GPU counterparts as a temporary workaround?

Maybe this could be a way to avoid having to manually copy-paste and annotate function definitions from StatsFuns, SpecialFunctions etc. One could use `cufunc` for determining the GPU-compatible function (as a fallback at least). If one wants to be more restrictive, one could also keep a whitelist of functions that are replaced.
> is this what you want?

So this is exactly what I did, but the issue is that `DiffRules.@define_diffrule` expects the RHS to contain the `args` as "interpolated" symbols, e.g.

```julia
DiffRules.@define_diffrule CUDAExtensions.cusin(x) = :((CUDA.cufunc(cos))($x))
```

Notice the `$x` on the right-hand side. So your code produces
```julia
julia> import CUDAExtensions, DiffRules, MacroTools

julia> args = ntuple(i -> Symbol(:x, i), 1)
(:x1,)

julia> diffrule = DiffRules.diffrule(:Base, :sin, args...)
:(cos(x1))

julia> diffrule_cu = CUDAExtensions.replace_device_all(diffrule)
:((CUDA.cufunc(cos))(x1))

julia> ex = :(DiffRules.@define_diffrule CUDAExtensions.cusin($(args...)) = $diffrule_cu);

julia> MacroTools.prettify(ex)
:(DiffRules.@define_diffrule CUDAExtensions.cusin(x1) = (CUDA.cufunc(cos))(x1))
```
Issues with the above:

- `x1` appears on both sides (this isn't clear from the above example, but you can see it if you use `gensym` instead).
- There is no `$` in front of the argument on the RHS, as required by `@define_diffrule`.

I essentially want it to be:

```julia
:(DiffRules.@define_diffrule CUDAExtensions.cusin(x1) = :((CUDA.cufunc(cos))($x1)))
```
So the first thing I try is to add `QuoteNode` to the RHS:
```julia
julia> ex = :(DiffRules.@define_diffrule CUDAExtensions.cusin($(args...)) = $(QuoteNode(diffrule_cu)));

julia> MacroTools.prettify(ex)
:(DiffRules.@define_diffrule CUDAExtensions.cusin(mouse) = $(QuoteNode(:((CUDA.cufunc(cos))(var"##x1#348")))))
```
So at least now the RHS is quoted, but there's no reference to the argument on the LHS. I then try to add in the `$`:
```julia
julia> rhs = MacroTools.postwalk(diffrule_cu) do e
           if e in args
               Expr(:$, e)
           else
               e
           end
       end
:((CUDA.cufunc(cos))($(Expr(:$, :x1))))
```
Combining it all together, we get:

```julia
julia> rhs = MacroTools.postwalk(:(QuoteNode($diffrule_cu))) do e
           if e in args
               Expr(:$, e)
           else
               e
           end
       end
:(QuoteNode((CUDA.cufunc(cos))($(Expr(:$, Symbol("##x1#348"))))))

julia> @eval begin
           DiffRules.@define_diffrule CUDAExtensions.cusin($(args...)) = $rhs
       end
ERROR: syntax: "$" expression outside quote around REPL[46]:2
Stacktrace:
 [1] top-level scope at none:1
 [2] eval(::Module, ::Any) at ./boot.jl:331
 [3] top-level scope at REPL[46]:1
```
And herein lies my problem! Not sure if I'm doing something wrong or if it's just not possible to do.
EDIT: Actually, I figured out how to do it! I need to use `Meta.quot` rather than `QuoteNode` (which protects against interpolation)!
The following works:
```julia
julia> args = ntuple(i -> gensym(Symbol(:x, i)), 1)
(Symbol("##x1#348"),)

julia> diffrule = DiffRules.diffrule(:Base, :sin, args...)
:(cos(var"##x1#348"))

julia> diffrule_cu = CUDAExtensions.replace_device_all(diffrule)
:((CUDA.cufunc(cos))(var"##x1#348"))

julia> rhs = MacroTools.postwalk(Meta.quot(diffrule_cu)) do e
           if e in args
               Expr(:$, e)
           else
               e
           end
       end
:($(Expr(:quote, :((CUDA.cufunc(cos))($(Expr(:$, Symbol("##x1#348"))))))))

julia> MacroTools.prettify(:(DiffRules.@define_diffrule CUDAExtensions.cusin($(args...)) = $rhs))
:(DiffRules.@define_diffrule CUDAExtensions.cusin(mouse) = $(Expr(:quote, :((CUDA.cufunc(cos))($(Expr(:$, :mouse)))))))

julia> eval(:(DiffRules.@define_diffrule CUDAExtensions.cusin($(args...)) = $rhs))
(:CUDAExtensions, :cusin, 1)

julia> DiffRules.diffrule(:CUDAExtensions, :cusin, :x)
:((CUDA.cufunc(cos))(x))
```
:tada:
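The `Meta.quot` versus `QuoteNode` difference can be reproduced without DiffRules. In this sketch, `interpolate_args` and `make_rule` are hypothetical helpers that only mimic the shape of rule `@define_diffrule` expects, not its internals: placeholder symbols are replaced by `Expr(:$, sym)`, the result is wrapped with `Meta.quot`, and evaluating that quote inside a function turns each `$sym` into whatever symbol the function is called with. Wrapping with `QuoteNode` instead returns the `$` nodes untouched.

```julia
# Replace each placeholder symbol in `syms` by an interpolation node `$sym`.
interpolate_args(ex, syms) = ex
interpolate_args(ex::Symbol, syms) = ex in syms ? Expr(:$, ex) : ex
interpolate_args(ex::Expr, syms) =
    Expr(ex.head, (interpolate_args(a, syms) for a in ex.args)...)

# Build `(args...) -> <quoted body>` and return a callable that, given new
# argument symbols, returns `body` rewritten in terms of those symbols.
function make_rule(args, body)
    quoted = Meta.quot(interpolate_args(body, args))  # Expr(:quote, ...): `$` is honoured
    f = eval(Expr(:->, Expr(:tuple, args...), Expr(:block, quoted)))
    # `f` was defined while this code was already running, so call it via
    # `invokelatest` to avoid world-age errors.
    return (syms...) -> Base.invokelatest(f, syms...)
end

args = (gensym(:x1),)
body = :(cufunc(cos)($(args[1])))
rule = make_rule(args, body)
rule(:y)   # :(cufunc(cos)(y))

# With QuoteNode, the `$` node is returned verbatim instead of being interpolated.
literal = eval(QuoteNode(interpolate_args(body, args)))
```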
> > - Wait until the work on method-substitution is done.
>
> What about
>
> - Overdub model execution (or at least `observe` and `assume`) with Cassette and replace functions with their GPU counterparts as a temporary workaround?
>
> Maybe this could be a way to avoid having to manually copy-paste and annotate function definitions from StatsFuns, SpecialFunctions etc.
>
> One could use `cufunc` for determining the GPU-compatible function (as a fallback at least). If one wants to be more restrictive, one could also keep a whitelist of functions that are replaced.
I actually wrote this as an alternative and then removed it, haha. The reason is that such an approach has already been heavily considered, it seems: https://github.com/JuliaGPU/CUDAnative.jl/pull/334. They have since moved on to this: https://github.com/JuliaGPU/GPUCompiler.jl/pull/122, i.e. waiting for them to come up with a proper solution to all of this. (I got this from this comment: https://github.com/JuliaGPU/CuArrays.jl/pull/321#issuecomment-753955394)
Makes sense that CUDA/GPUCompiler uses a different approach if it is better supported by the compiler. But the Cassette approach seems simple enough to rewrite your (or Turing's) functions on a smaller level as a workaround until the method substitution is integrated in CUDA. Apparently method substitution is planned already for the next version, so I understand that they are not too eager to integrate GPU-copies of SpecialFunctions and StatsFuns.
Having almost the same code in two packages will definitely be more difficult to maintain and I am also a bit worried that numerical issues are not tracked and fixed as carefully as in the original packages and fixes and other changes are not propagated. For instance, your comment in the draft of the SpecialFunctions PR should rather be addressed in the original package than the PR. Maybe it would be easier to keep such fixes temporary if they are not integrated in CUDA but live in a separate package even if it is type piracy. Something like GPUDistributionsAD :stuck_out_tongue:
Ugh, more issues. If we have

```julia
@cufunc SpecialFunctions.loggamma(x) = CUDA.lgamma(x)
@cufunc SpecialFunctions.logbeta(x, y) = SpecialFunctions.loggamma(x) + SpecialFunctions.loggamma(y) - SpecialFunctions.loggamma(x + y)
```

in my `CUDAExtensions` module, we get
```julia
julia> using CUDAExtensions

julia> import SpecialFunctions

julia> CUDA.cufunc(SpecialFunctions.loggamma) # verify that we have a function to replace
culoggamma (generic function with 2 methods)

julia> a = exp.(CUDA.randn(10));

julia> SpecialFunctions.loggamma.(a)
10-element CuArray{Float32,1}:
  0.13996097
  0.2567964
 -0.10765605
 -0.018245425
  0.76336765
  0.3486092
  0.12353684
  1.3368652
  0.42304638
 -0.114002444

julia> import ForwardDiff

julia> ForwardDiff.jacobian(a) do a
           SpecialFunctions.loggamma.(a)
       end
10×10 CuArray{Float64,2}:
 0.58695 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
 -0.0 -1.21057 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0
 -0.0 -0.0 -0.172639 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0
 0.0 0.0 0.0 0.393549 0.0 0.0 0.0 0.0 0.0 0.0
 0.0 0.0 0.0 0.0 0.951943 0.0 0.0 0.0 0.0 0.0
 0.0 0.0 0.0 0.0 0.0 0.745529 0.0 0.0 0.0 0.0
 0.0 0.0 0.0 0.0 0.0 0.0 0.571245 0.0 0.0 0.0
 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -4.43018 -0.0 -0.0
 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -1.59937 -0.0
 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.125148

julia> CUDA.cufunc(SpecialFunctions.logbeta) # verify that we have a function to replace
culogbeta (generic function with 15 methods)

julia> SpecialFunctions.logbeta.(a, a)
FATAL ERROR: Symbol "__nv_lgammaf"not found
signal (6): Aborted
in expression starting at REPL[9]:1
gsignal at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
abort at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
...
unknown function (ip: 0x4015d4)
Allocations: 75486245 (Pool: 75461987; Big: 24258); GC: 64
Aborted
```
It seems like this happens because the additional calls to `cufunc` mean that the compiler can't do type inference for the entire broadcast operation:
```julia
julia> using CUDAExtensions, SpecialFunctions

julia> a = exp.(CUDA.randn(10));

julia> f(a) = SpecialFunctions.loggamma.(a)
f (generic function with 1 method)

julia> @code_warntype f(a)
Variables
  #self#::Core.Compiler.Const(f, false)
  a::CuArray{Float32,1}

Body::CuArray{Float32,1}
1 ─ %1 = SpecialFunctions.loggamma::Core.Compiler.Const(SpecialFunctions.loggamma, false)
│ %2 = Base.broadcasted(%1, a)::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Nothing,typeof(CUDAExtensions.culoggamma),Tuple{CuArray{Float32,1}}}
│ %3 = Base.materialize(%2)::CuArray{Float32,1}
└── return %3

julia> g(a) = SpecialFunctions.logbeta.(a, a)
g (generic function with 1 method)

julia> @code_warntype g(a)
Variables
  #self#::Core.Compiler.Const(g, false)
  a::CuArray{Float32,1}

Body::CuArray{_A,1} where _A
1 ─ %1 = SpecialFunctions.logbeta::Core.Compiler.Const(SpecialFunctions.logbeta, false)
│ %2 = Base.broadcasted(%1, a, a)::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Nothing,typeof(CUDAExtensions.culogbeta),Tuple{CuArray{Float32,1},CuArray{Float32,1}}}
│ %3 = Base.materialize(%2)::CuArray{_A,1} where _A
└── return %3
```
Whereas if I circumvent the additional calls to `cufunc`:
```julia
julia> logbeta(a, b) = CUDA.lgamma(a) + CUDA.lgamma(b) - CUDA.lgamma(a + b)
logbeta (generic function with 1 method)

julia> h(a) = logbeta.(a, a)
h (generic function with 1 method)

julia> @code_warntype h(a)
Variables
  #self#::Core.Compiler.Const(h, false)
  a::CuArray{Float32,1}

Body::CuArray{Float32,1}
1 ─ %1 = Base.broadcasted(Main.logbeta, a, a)::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Nothing,typeof(logbeta),Tuple{CuArray{Float32,1},CuArray{Float32,1}}}
│ %2 = Base.materialize(%1)::CuArray{Float32,1}
└── return %2

julia> h.(a)
10-element CuArray{Float32,1}:
 -2.9278803
  1.8021662
 -3.8931153
 -1.909847
  0.8930918
 -1.0019481
  2.219266
  0.27762663
  1.3281798
  2.3685691
```
> Makes sense that CUDA/GPUCompiler uses a different approach if it is better supported by the compiler. But the Cassette approach seems simple enough to rewrite your (or Turing's) functions on a smaller level as a workaround until the method substitution is integrated in CUDA.

Almost certainly the issue I found above for the `cufunc` method will be present when using Cassette too, i.e. type inference will fail with Cassette :confused:
> Having almost the same code in two packages will definitely be more difficult to maintain and I am also a bit worried that numerical issues are not tracked and fixed as carefully as in the original packages and fixes and other changes are not propagated. For instance, your comment in the draft of the SpecialFunctions PR should rather be addressed in the original package than the PR. Maybe it would be easier to keep such fixes temporary if they are not integrated in CUDA but live in a separate package even if it is type piracy. Something like GPUDistributionsAD
100% agree. I don't think the suggestion of having separate packages is a viable solution long-term.
Just for completeness, the code I've been working with is here: https://github.com/torfjelde/CUDAExtensions.jl
Made an issue related to the type-inference bug from above: https://github.com/torfjelde/CUDAExtensions.jl/issues/1
Another challenge we didn't highlight here is generating random numbers in the GPU kernel. This is not trivial because we need to maintain a different RNG for each thread and we need a GPU-compatible RNG. There is however an example of using a GPU-friendly RNG in https://github.com/JuliaFolds/FoldsCUDA.jl/blob/master/examples/monte_carlo_pi.jl.
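One common way to avoid per-thread RNG state altogether is a counter-based generator, where the k-th draw of thread `tid` is a pure hash of `(seed, tid, k)`. The sketch below runs on the CPU and only illustrates the idea: `splitmix64` and `uniform01` are defined here, this is not necessarily what the linked FoldsCUDA.jl example does, and this simple combination of hashes has not been statistically vetted (real code would rather use an established counter-based generator such as Philox).

```julia
# SplitMix64 output function: a well-known 64-bit integer mixer.
function splitmix64(x::UInt64)
    z = x + 0x9e3779b97f4a7c15
    z = (z ⊻ (z >> 30)) * 0xbf58476d1ce4e5b9
    z = (z ⊻ (z >> 27)) * 0x94d049bb133111eb
    return z ⊻ (z >> 31)
end

# k-th uniform draw in [0, 1) for thread `tid`: a pure function of its inputs,
# so no generator state has to be stored, shared or advanced per thread.
function uniform01(seed::UInt64, tid::Integer, k::Integer)
    h = splitmix64(seed ⊻ splitmix64(UInt64(tid) ⊻ splitmix64(UInt64(k))))
    return (h >> 11) * 2.0^-53   # keep the top 53 bits -> Float64 in [0, 1)
end

seed = 0x0123456789abcdef
draws = [uniform01(seed, tid, 0) for tid in 1:4]   # one draw per "thread"
```

Because each draw depends only on its indices, a kernel could compute it from its own thread index without any shared state.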
Yeah, I've been following that :) The only issue is that it seems like on CUDA 3.0 and Julia 1.6, neither the hacky approach I've taken nor the method introduced there works, since their method-override approach is internal to CUDA.jl, AFAIK?
Yes, it seems it only becomes available for other packages in Julia 1.7.
One major difficulty of using CUDA.jl + anything Bayesian is that you immediately need to define all the (at the very least, univariate) distributions all over again. But for starters, we don't need all of the functionality of a `Distribution` to do something like Bayesian inference, e.g. for AdvancedHMC.jl we really only need `logpdf` and its adjoint, as the only call to `rand` is going to be for the momentum (which can be sampled directly on the GPU using `CUDA.randn`).

But even trying to redefine the `logpdf` of a `Distribution` to work on the GPU is often non-trivial.

Issue #1
Before going into the issue, it's important to know the following:
- `Base.log` and `CUDA.log` are actually different methods.
- CUDA.jl converts `f.(args...)` to `CUDA.cufunc(f).(args...)` whenever any of the `args` is a `CUDA.CuArray`. This is done by overloading `Broadcast.broadcasted`. E.g. `CUDA.cufunc(::typeof(log)) = CUDA.log`.
- If `f` does not have a `cufunc` already defined and you do `f.(args...)`, you'll, if you're lucky, get an error, but sometimes the entire Julia session will crash.

So what do we do?
- … `cufunc(::typeof(f)) = ...`, which will allow you to broadcast over … (… `digamma` from this PR by @xukai92: https://github.com/JuliaGPU/CuArrays.jl/pull/321#issuecomment-753998893) … `loggamma` can be replaced by `CUDA.lgamma`.
- … `f`.
- … `f` on GPU, Zygote.jl uses ForwardDiff.jl to obtain the adjoints for broadcasting, and so we gotta define these rules using `DiffRules` and evaluate using `ForwardDiff`.
- … `DiffRules` definition is for `f`, and you copy-paste, replacing methods with `cufunc` methods.
- … `cufunc` definitions and their corresponding `@define_diffrule`, so we're good to go, right? Now we can just call `StatsFuns.gammalogpdf.(α, θ, x)`, right?
- ~~… element types of a `CUDA.CuArray` aren't special, i.e. they're just `Float32`. And so the function we dispatch on when broadcasting over a `CUDA.CuArray` will be some function outside of the Julia ecosystem, and so things start blowing up.~~ EDIT: this is only an issue for eltype `Float64`, not `Float32`, as pointed out by @devmotion!
- … `Float32` and so on to use the pure-Julia implementation, e.g. … `broadcasted` won't be properly nested, so `cufunc` will only be called on `StatsFuns.gammalogpdf`, not on the methods used within! So, we instead do … which is really not fun.
- … `cufunc` for the leaves of the method hierarchy! This sucks.

Potential solutions
- … `CUDA.@cufunc` macro that I've implemented (https://github.com/JuliaGPU/CUDA.jl/blob/c011ffc0971ab1089f9d56dd338ef4b31e24ecc7/src/broadcast.jl#L101-L112), which has the following additional features:
  - Replace every `f` in the body with `cufunc(f)`, with the default impl of `cufunc(f) = f`. I.e. do nothing to almost all methods, but replace those which have a `cufunc` impl.
  - `@cufunc SpecialFunctions.gamma(x) = ...` is converted into `cugamma(x) = ...; cufunc(::typeof(SpecialFunctions.gamma)) = cugamma`.
  - If `f` is present in `DiffRules.diffrules()`, then we extract the corresponding diffrule and replace all functions `g` within the diffrule with `cufunc(g)`. I.e. IF there is a scalar rule for `f`, then we make it CUDA compatible (assuming the methods in the rule have a `cufunc` implementation); otherwise we leave it to ForwardDiff.jl to figure it out.

Personally, I'm in favour of solution (1).
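The broadcast substitution described in Issue #1, and why it does not reach calls nested inside the broadcasted function, can be reproduced on the CPU. In this sketch `DevVector` is a hypothetical array type standing in for `CuArray`, `devfunc` stands in for `CUDA.cufunc`, and `devlog` is a hypothetical replacement for `log` that counts its calls; the `broadcasted` overload mirrors the kind of overload described above but is not CUDA.jl's actual code.

```julia
using Base.Broadcast: ArrayStyle, Broadcasted

# Hypothetical CPU array type standing in for `CuArray`.
struct DevVector{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(v::DevVector) = size(v.data)
Base.IndexStyle(::Type{<:DevVector}) = IndexLinear()
Base.getindex(v::DevVector, i::Int) = v.data[i]
Base.BroadcastStyle(::Type{<:DevVector}) = ArrayStyle{DevVector}()

# Stand-in for `CUDA.cufunc`: identity unless a replacement is registered.
devfunc(f) = f
const DEVLOG_CALLS = Ref(0)
devlog(x) = (DEVLOG_CALLS[] += 1; log(x))   # hypothetical "device" log
devfunc(::typeof(log)) = devlog

# Whenever a DevVector takes part in a broadcast, swap `f` for `devfunc(f)`.
# Only the function that is broadcast directly passes through here.
Base.Broadcast.broadcasted(::ArrayStyle{DevVector}, f, args...) =
    Broadcasted{ArrayStyle{DevVector}}(devfunc(f), args)

neglog(x) = -log(x)               # calls `log` internally: NOT substituted
neglog_dev(x) = -devfunc(log)(x)  # substitution written out at the leaf

v = DevVector([1.0, 2.0])
```

Here `log.(v)` goes through `devlog`, `neglog.(v)` does not, and only writing `devfunc(log)` explicitly at the leaf (as in `neglog_dev`) restores the substitution, which is what the body rewriting in solution (1) would automate.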
Issue #2
Now, there's also an additional "annoyance" even after solving the above issue. We cannot do something like
```julia
logpdf.(Gamma.(α, θ), x)
```

because this will first do `map(Gamma, ...)` before calling `logpdf`. There's the possibility that this could have been inlined, completely removing the call to `Gamma` once it's sufficiently lowered, but GPUCompiler.jl will complain before it reaches that stage (as this is not always guaranteed + I believe it will try to fuse all the broadcasts together into a single operation for efficiency). Therefore we either need to:

- … `gammalogpdf.(α, θ, x)`.
- … `Vectorize(D, args)`, e.g. `Vectorize(Gamma, (α, θ))`, which has a `logpdf` that lazily calls the underlying method, e.g. `logpdf(v::Vectorize{Gamma}, x) = gammalogpdf.(v.args..., x)`. Equipped with this, we can speed up implementation quite a bit by potentially doing something like:
  - … `broadcasted` so that if we're using the `CUDA.CuArrayStyle` and `f <: UnivariateDistribution`, we can `materialize` the `args` earlier and then wrap them in `Vectorize`, i.e. `Vectorize(f, args)`.
  - … `Distributions.@__delegate_statsfuns` or whatever to more easily define `logpdf(v::Vectorize{D}, x)` for different `D`.

Worth mentioning that this requires a small redefinition of this method in Zygote (https://github.com/FluxML/Zygote.jl/blob/2b17256e79b2eca9a6512207284219d279398fc9/src/lib/broadcast.jl#L225-L228), though it should definitely be possible to make it work even though we're overloading `broadcasted`.
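A minimal sketch of the `Vectorize` idea on plain arrays, assuming hypothetical names throughout: `Vectorize`, `MyNormal`, `normlogpdf` and `logpdf` are defined here and are not the Distributions.jl or StatsFuns.jl functions with the same or similar names. The wrapper stores the distribution type and its parameter arrays, and its `logpdf` broadcasts a scalar kernel over them, so no per-element distribution object is ever constructed.

```julia
# Hypothetical lazy wrapper: the distribution type `D` plus its parameter arrays.
struct Vectorize{D, A<:Tuple}
    args::A
end
Vectorize(::Type{D}, args::Tuple) where {D} = Vectorize{D, typeof(args)}(args)

# Hypothetical univariate distribution, used here only as the type tag `D`.
struct MyNormal{T<:Real}
    μ::T
    σ::T
end

# Scalar pure-Julia log-density of a normal distribution.
normlogpdf(μ, σ, x) = -abs2((x - μ) / σ) / 2 - log(σ) - log(2π) / 2

# One fused broadcast over the stored parameters; no MyNormal objects are built.
logpdf(v::Vectorize{MyNormal}, x) = normlogpdf.(v.args..., x)

v = Vectorize(MyNormal, ([0.0, 1.0], [1.0, 2.0]))   # (means, standard deviations)
lp = logpdf(v, [0.0, 1.0])
```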