Closed — jlperla closed this issue 2 years ago.
This is a trivial rrule. The full code for testing is
# Load FillArrays explicitly: `FillArrays.Zeros` below throws an UndefVarError
# unless the module name is brought into scope (DistributionsAD does not
# re-export it).
using Distributions, DistributionsAD, LinearAlgebra, Zygote, ChainRulesTestUtils, FiniteDiff, ChainRulesCore, FillArrays, Test
x = [0.1, 0.2]
# Standard-normal logpdf: lazy zero mean + identity covariance.
f(x) = logpdf(MvNormal(FillArrays.Zeros{Int}(length(x)), I), x)
f(x)
gradient(f, x)
#test_rrule(Zygote.ZygoteRuleConfig(), f, x; rrule_f = rrule_via_ad) # failing because of type stability.
# Check the Zygote gradient against a finite-difference reference.
@test gradient(f, x)[1] ≈ FiniteDiff.finite_difference_gradient(f, x)
# Now a manual version of the primal from https://github.com/TuringLang/DistributionsAD.jl/blob/master/src/multivariate.jl#L157 dropping the mean and variance
# Manual standard-normal logpdf: -(n·log(2π) + Σᵢ xᵢ²)/2, with no mean/cov objects.
f_manual(x) = -0.5 * (length(x) * log(2π) + sum(abs2, x))
# Hand-written rrule for `f_manual`.
# The gradient of -(n·log(2π) + Σxᵢ²)/2 with respect to x is -x, so the
# pullback scales -x by the incoming output cotangent.
function ChainRulesCore.rrule(::typeof(f_manual), x)
    y = f_manual(x)
    f_manual_pullback(ȳ) = (NoTangent(), -ȳ * x)
    return y, f_manual_pullback
end
# Zygote picks up the hand-written rrule for f_manual here.
gradient(f_manual, x)
test_rrule(Zygote.ZygoteRuleConfig(), f_manual, x; rrule_f = rrule_via_ad) # type stable
# Make sure the manual version agrees with the DistributionsAD-based `f`
# defined above by comparing against finite differences of `f`.
@test gradient(f_manual, x)[1] ≈ FiniteDiff.finite_difference_gradient(f, x)
From David Widmann: f(x) = logpdf(MvNormal(FillArrays.Eye{Int}(length(x))), x)
might even be better if we want to use the AD directly.
# Benchmark: compare several equivalent formulations of the standard-normal
# logpdf under Zygote, against the hand-written rrule.
using Distributions, DistributionsAD, LinearAlgebra, Zygote, ChainRulesCore, FillArrays, Test, BenchmarkTools, StaticArrays
# Manual standard-normal logpdf and its rrule (same definitions as above).
f_manual(x) = -(length(x) * log(2π) + sum(abs2.(x))) / 2
function ChainRulesCore.rrule(::typeof(f_manual), x)
    y = f_manual(x)
    function pb(f̄)
        # ∇f_manual(x) = -x, scaled by the output cotangent.
        x̄ = -f̄ * x
        return NoTangent(), x̄
    end
    return y, pb
end
# Four equivalent logpdfs, differing only in how the (zero) mean and
# (identity) covariance are represented.
f(x) = logpdf(MvNormal(FillArrays.Eye{Int}(length(x))), x)
f_2(x) = logpdf(MvNormal(FillArrays.Zeros{Int}(length(x)), I), x)
f_3(x) = logpdf(MvNormal(zero(x), I), x)
f_4(x) = logpdf(MvNormal(zero(x), ones(eltype(x), length(x))), x)
x = rand(100)
f(x)
# Gradient timings with a heap-allocated vector…
@btime gradient(f, $x);
@btime gradient(f_2, $x);
@btime gradient(f_3, $x);
@btime gradient(f_4, $x);
@btime gradient(f_manual, $x);
# …and with a small stack-allocated static array.
x_static = @SArray [0.1, 0.2]
@btime gradient(f, $x_static);
@btime gradient(f_2, $x_static);
@btime gradient(f_3, $x_static);
@btime gradient(f_4, $x_static);
@btime gradient(f_manual, $x_static);
I get
15.100 μs (156 allocations: 16.06 KiB)
6.975 μs (79 allocations: 8.28 KiB)
8.367 μs (87 allocations: 9.78 KiB)
14.400 μs (148 allocations: 18.75 KiB)
239.516 ns (2 allocations: 1.75 KiB)
14.300 μs (161 allocations: 6.72 KiB)
5.883 μs (76 allocations: 3.30 KiB)
5.850 μs (76 allocations: 3.39 KiB)
12.500 μs (150 allocations: 5.97 KiB)
95.469 ns (0 allocations: 0 bytes)
For an accumulation, the differences are on the order of 200x faster (assuming I did it right), which suggests the speed issues are not in the accumulation code but rather in the overhead of the Zygote calls. Could be wrong, but good enough for now.
# Demonstrate delegating an rrule's pullback back to the AD via rrule_via_ad.
using Zygote, Distributions, DistributionsAD, LinearAlgebra, ChainRulesCore
function f(d, x)
    return logpdf(d, x)
end
d = MvNormal(zeros(2), Diagonal([0.1, 0.2]))
x = [0.1, 0.2]
f(d, x)
# rrule that simply calls back into the configured reverse-mode AD.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(f), d, x)
    y = f(d, x)
    function pb(ȳ)
        # NOTE(review): rrule_via_ad re-runs the primal logpdf inside the
        # pullback; acceptable here since this is only measuring overhead.
        _, pb_logpdf = rrule_via_ad(config, logpdf, d, x)
        _ignore, d̄, x̄ = pb_logpdf(ȳ)
        return NoTangent(), d̄, x̄
    end
    return y, pb
end
gradient(f, d, x)
###############################################
# NOTE(review): this snippet is a repeat of the benchmark code earlier in the
# thread (the issue pasted it twice); every definition is identical and simply
# overwrites the previous one.
using Distributions, DistributionsAD, LinearAlgebra, Zygote, ChainRulesCore, FillArrays, Test, BenchmarkTools, StaticArrays
f_manual(x) = -(length(x) * log(2π) + sum(abs2.(x))) / 2
function ChainRulesCore.rrule(::typeof(f_manual), x)
    y = f_manual(x)
    function pb(f̄)
        # ∇f_manual(x) = -x
        x̄ = -f̄ * x
        return NoTangent(), x̄
    end
    return y, pb
end
f(x) = logpdf(MvNormal(FillArrays.Eye{Int}(length(x))), x)
f_2(x) = logpdf(MvNormal(FillArrays.Zeros{Int}(length(x)), I), x)
f_3(x) = logpdf(MvNormal(zero(x), I), x)
f_4(x) = logpdf(MvNormal(zero(x), ones(eltype(x), length(x))), x)
x = rand(100)
f(x)
@btime gradient(f, $x);
@btime gradient(f_2, $x);
@btime gradient(f_3, $x);
@btime gradient(f_4, $x);
@btime gradient(f_manual, $x);
x_static = @SArray [0.1, 0.2]
@btime gradient(f, $x_static);
@btime gradient(f_2, $x_static);
@btime gradient(f_3, $x_static);
@btime gradient(f_4, $x_static);
@btime gradient(f_manual, $x_static);
######### manual accumulation of loglikelihood
# Accumulate the standard-normal logpdf over the rows of `v`, constructing the
# MvNormal once outside the loop.
function g(v)
    d = MvNormal(FillArrays.Eye{Int}(size(v, 2)))
    total = 0.0
    for i in 1:size(v, 1)
        total += logpdf(d, view(v, i, :))
    end
    return total
end
g(rand(100, 10))
# Accumulate the hand-written scalar logpdf `f_manual` over the rows of `v`.
# (Fix: dropped the unused local `N = size(v, 2)` from the original.)
function g_manual_f(v)
    loglik = 0.0
    for i in 1:size(v, 1)
        loglik += f_manual(@view v[i, :])
    end
    return loglik
end
g_manual_f(rand(100, 10))
# Identical accumulation to g_manual_f, duplicated under a separate name so it
# can be given its own custom rrule below.
# (Fix: dropped the unused local `N = size(v, 2)` from the original.)
function g_manual(v)
    loglik = 0.0
    for i in 1:size(v, 1)
        loglik += f_manual(@view v[i, :])
    end
    return loglik
end
g_manual(rand(100, 10))
# rrule for the whole accumulation in one shot.
function ChainRulesCore.rrule(::typeof(g_manual), v)
    y = g_manual(v)
    function pb(f̄)
        # Each row r contributes -(length(r)·log(2π) + Σrᵢ²)/2, whose gradient
        # w.r.t. r is -r; stacking the per-row gradients gives ∇g_manual(v) = -v,
        # so the pullback is f̄ · (-v).
        v̄ = -f̄ * v
        # An explicit per-column loop (v̄[:, i] .= -f̄ * v[:, i]) was also tried
        # to rule out this line as the source of the timing difference; it made
        # no meaningful difference.
        return NoTangent(), v̄
    end
    return y, pb
end
# Same primal accumulation as `g`; paired below with an rrule that calls back
# into the AD via rrule_via_ad.
function g_callbacks(v)
    d = MvNormal(FillArrays.Eye{Int}(size(v, 2)))
    total = 0.0
    for i in 1:size(v, 1)
        total += logpdf(d, view(v, i, :))
    end
    return total
end
g_callbacks(rand(100, 10))
# rrule that accumulates row-wise pullbacks obtained from the outer AD.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(g_callbacks), v)
    N = size(v, 2)
    y = g_callbacks(v)
    d = MvNormal(FillArrays.Eye{Int}(N))
    function pb(f̄)
        v̄ = similar(v)
        for i in 1:size(v, 1)
            v_i = @view v[i, :]
            # NOTE(review): rrule_via_ad re-runs the primal logpdf per row
            # inside the pullback — part of the callback overhead measured.
            _, pb_logpdf = rrule_via_ad(config, logpdf, d, v_i)
            # d̄ is discarded: this rrule only reports a tangent for v.
            _ignore, d̄, x̄ = pb_logpdf(f̄)
            # The loglik is a plain sum of per-row terms, so each row's
            # cotangent is f̄ pushed through that row's logpdf pullback.
            v̄[i, :] .= x̄
        end
        return NoTangent(), v̄
    end
    return y, pb
end
# Duplicate paste of g_callbacks from earlier in the thread; the definition is
# identical and simply overwrites the previous method.
function g_callbacks(v)
    dist = MvNormal(FillArrays.Eye{Int}(size(v, 2)))
    acc = 0.0
    for i in 1:size(v, 1)
        acc += logpdf(dist, view(v, i, :))
    end
    return acc
end
g_callbacks(rand(100, 10))
# NOTE(review): duplicate paste of the g_callbacks rrule above; identical
# definition, so it simply overwrites the previous method.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(g_callbacks), v)
    N = size(v, 2)
    y = g_callbacks(v)
    d = MvNormal(FillArrays.Eye{Int}(N))
    function pb(f̄)
        v̄ = similar(v)
        for i in 1:size(v, 1)
            v_i = @view v[i, :]
            # Row-wise pullback of logpdf obtained from the outer AD.
            _, pb_logpdf = rrule_via_ad(config, logpdf, d, v_i)
            # d̄ is discarded — only the tangent for v is reported.
            _ignore, d̄, x̄ = pb_logpdf(f̄)
            v̄[i, :] .= x̄
        end
        return NoTangent(), v̄
    end
    return y, pb
end
# Head-to-head gradient timings of the four accumulation strategies on a
# 100×50 matrix of observations.
v = rand(100, 50)
println("using the likelihood directly\n")
@btime gradient(g, $v)
println("Using custom likelihood\n")
@btime gradient(g_manual_f, $v)
println("Using custom accumulation as well\n")
@btime gradient(g_manual, $v)
println("Using custom accumulation with callbacks\n")
@btime gradient(g_callbacks, $v)
# I am getting about >200x faster for the custom accumulation than the basics.
# The callbacks are roughly the same speed as the direct likelihood calls.
Closing since the answer on the utility of a specialized unit normal in the iteration is resolved.
The purpose of this is to see how much faster it is to do a specialization for the unit normal observational noise with a diagonal covariance matrix vs. callbacks into the DistributionsAD. If this is faster, it means the entire rrule can be done without callbacks. Depends on #24 and #25
The code to verify the adjoint of the unit normal is given above.