SciML / DifferenceEquations.jl

Solving difference equations with DifferenceEquations.jl and the SciML ecosystem.
MIT License
32 stars 6 forks source link

Benchmark custom code for the logpdf of a unit normal vs. the chainrules callbacks #26

Closed jlperla closed 2 years ago

jlperla commented 2 years ago

The purpose of this is to see how much faster it is to use a specialization for unit-normal observational noise with a diagonal covariance matrix vs. callbacks into DistributionsAD. If the specialization is faster, the entire rrule can be written without callbacks. Depends on #24 and #25.

The code to verify the adjoint of the unit normal is

using Distributions, DistributionsAD, LinearAlgebra, Zygote, ChainRulesTestUtils, FiniteDiff, Test
# Small 2-element test point.
x = [0.1, 0.2]
# Reference log-density: zero-mean MvNormal with identity covariance, via Distributions/DistributionsAD.
f(x) = logpdf(MvNormal(zeros(length(x)), I), x)
f(x)
gradient(f, x)
# NOTE(review): this test_rrule check did not work as written. A later comment in this thread
# attributes the analogous failure (with a FillArrays mean) to type instability; unconfirmed here.
#test_rrule(Zygote.ZygoteRuleConfig(), f, x; rrule_f = rrule_via_ad)
# Check Zygote's gradient against a finite-difference gradient.
@test gradient(f, x)[1] ≈ FiniteDiff.finite_difference_gradient(f, x)

# Manual version of the primal from
# https://github.com/TuringLang/DistributionsAD.jl/blob/master/src/multivariate.jl#L157,
# dropping the mean and variance, i.e. the N(0, I) log-density
#     -(n * log(2π) + ‖x‖²) / 2,   n = length(x).
# `sum(abs2, x)` computes ‖x‖² without materializing a temporary `abs2.(x)` array.
f_manual(x) = -(length(x) * log(2π) + sum(abs2, x)) / 2
# No custom rrule exists yet in this snippet, so Zygote differentiates f_manual's body directly.
gradient(f_manual, x)
f_manual(x)
# Compare against finite differences of the Distributions-based `f`. This also checks that
# the manual primal and the library version agree in gradient at x.
@test gradient(f_manual, x)[1] ≈ FiniteDiff.finite_difference_gradient(f, x)
jlperla commented 2 years ago

This is a trivial rrule. The full code for testing is

using Distributions, DistributionsAD, LinearAlgebra, Zygote, ChainRulesTestUtils, FiniteDiff, ChainRulesCore, Test
# FillArrays must be loaded explicitly. `FillArrays.Zeros` below refers to the module by
# name, and the `using` line for this snippet does not bring it into scope.
using FillArrays
x = [0.1, 0.2]
# Reference log-density with a lazy (FillArrays) zero mean and identity covariance.
f(x) = logpdf(MvNormal(FillArrays.Zeros{Int}(length(x)), I), x)
f(x)
gradient(f, x)
#test_rrule(Zygote.ZygoteRuleConfig(), f, x; rrule_f = rrule_via_ad) # failing because of type stability.
# Check Zygote's gradient against a finite-difference gradient.
@test gradient(f, x)[1] ≈ FiniteDiff.finite_difference_gradient(f, x)

# Manual version of the primal from
# https://github.com/TuringLang/DistributionsAD.jl/blob/master/src/multivariate.jl#L157,
# dropping the mean and variance, i.e. the N(0, I) log-density
#     -(n * log(2π) + ‖x‖²) / 2,   n = length(x).
# `sum(abs2, x)` computes ‖x‖² without allocating a temporary `abs2.(x)` array.
f_manual(x) = -(length(x) * log(2π) + sum(abs2, x)) / 2
# Hand-written reverse rule for f_manual. Since ∂f_manual/∂x = -x, the pullback scales -x
# by the incoming cotangent. The function itself gets NoTangent().
function ChainRulesCore.rrule(::typeof(f_manual), x)
    f_manual_pullback(Δy) = (NoTangent(), -Δy * x)
    return f_manual(x), f_manual_pullback
end

# With the rrule above defined, Zygote uses it instead of differentiating f_manual's body.
gradient(f_manual, x)
# The commented-out test_rrule call for `f` above fails; this one for the hand-written rule
# is expected to pass (original note: "type stable").
test_rrule(Zygote.ZygoteRuleConfig(), f_manual, x; rrule_f = rrule_via_ad) # type stable

# Make sure the hand-written gradient matches finite differences of the
# DistributionsAD-based `f` defined above.
@test gradient(f_manual, x)[1] ≈ FiniteDiff.finite_difference_gradient(f, x)
jlperla commented 2 years ago

From David Widmann: `f(x) = logpdf(MvNormal(FillArrays.Eye{Int}(length(x))), x)` might be even better if we want to use the AD directly.

jlperla commented 2 years ago
using Distributions, DistributionsAD, LinearAlgebra, Zygote, ChainRulesCore, FillArrays, Test, BenchmarkTools, StaticArrays
# N(0, I) log-density written out by hand: -(n * log(2π) + ‖x‖²) / 2.
# The operations and their order match the benchmarked one-liner exactly.
function f_manual(x)
    n = length(x)
    squared_norm = sum(abs2.(x))
    return -(n * log(2π) + squared_norm) / 2
end
# Reverse rule for f_manual. The gradient of -(n log(2π) + ‖x‖²)/2 is -x, so the pullback
# returns -f̄ * x for x and NoTangent() for the function itself.
function ChainRulesCore.rrule(::typeof(f_manual), x)
    primal = f_manual(x)
    f_manual_pullback(f̄) = (NoTangent(), -f̄ * x)
    return primal, f_manual_pullback
end

# Distributions-based variants of the standard-normal log-density, benchmarked against f_manual.

# single-argument MvNormal with a lazy FillArrays identity covariance
function f(x)
    dist = MvNormal(FillArrays.Eye{Int}(length(x)))
    return logpdf(dist, x)
end

# lazy FillArrays zero mean with `I` as the covariance
function f_2(x)
    dist = MvNormal(FillArrays.Zeros{Int}(length(x)), I)
    return logpdf(dist, x)
end

# dense zero mean (same container type as x) with `I` as the covariance
function f_3(x)
    dist = MvNormal(zero(x), I)
    return logpdf(dist, x)
end

# dense zero mean with a dense vector of ones as the second argument
# NOTE(review): older Distributions API where a vector reads as standard deviations; confirm.
function f_4(x)
    dist = MvNormal(zero(x), ones(eltype(x), length(x)))
    return logpdf(dist, x)
end

# Benchmark the Zygote gradient of each variant on a 100-element dense vector.
x = rand(100)
f(x)  # evaluate once before benchmarking (presumably a sanity check / warm-up)
# `$x` interpolates the global into the benchmark so it is not timed as a global lookup.
@btime gradient(f, $x);
@btime gradient(f_2, $x);
@btime gradient(f_3, $x);
@btime gradient(f_4, $x);
@btime gradient(f_manual, $x);  # goes through the hand-written rrule above

# The same five benchmarks on a 2-element StaticArrays vector.
x_static = @SArray [0.1, 0.2]
@btime gradient(f, $x_static);
@btime gradient(f_2, $x_static);
@btime gradient(f_3, $x_static);
@btime gradient(f_4, $x_static);
@btime gradient(f_manual, $x_static);

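I get the following timings. The first five lines are for `x = rand(100)` with `f`, `f_2`, `f_3`, `f_4` and `f_manual`, in that order. The last five are the same functions on the 2-element `x_static`: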

  15.100 μs (156 allocations: 16.06 KiB)

  6.975 μs (79 allocations: 8.28 KiB)

  8.367 μs (87 allocations: 9.78 KiB)

  14.400 μs (148 allocations: 18.75 KiB)

  239.516 ns (2 allocations: 1.75 KiB)

  14.300 μs (161 allocations: 6.72 KiB)

  5.883 μs (76 allocations: 3.30 KiB)

  5.850 μs (76 allocations: 3.39 KiB)

  12.500 μs (150 allocations: 5.97 KiB)

  95.469 ns (0 allocations: 0 bytes)
jlperla commented 2 years ago

For the accumulation, the custom version is on the order of 200× faster (assuming I did it right). This suggests the speed issues are not in the accumulation code itself but in the overhead of the Zygote calls. The code could be wrong, but it is good enough for now.

using Zygote, Distributions, DistributionsAD, LinearAlgebra, ChainRulesCore
# Thin wrapper around `logpdf(d, x)`. It gives the custom rrule below its own function
# name to attach to, without defining rules on `logpdf` itself.
f(d, x) = logpdf(d, x)

# 2-D zero-mean normal whose covariance is Diagonal([0.1, 0.2]), plus a test point.
d = MvNormal(zeros(2), Diagonal([0.1, 0.2]))
x = [0.1, 0.2]
f(d, x)

# Reverse rule for `f(d, x)` that calls back into the AD backend for `logpdf`.
# The backend's forward pass runs once and yields both the primal value and the pullback.
# Previously logpdf was evaluated once for `y` and then again inside the pullback via
# `rrule_via_ad`, which doubled the forward work on every backward pass.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(f), d, x)
    y, pb_logpdf = rrule_via_ad(config, logpdf, d, x)

    function pb(ȳ)
        # The first entry is the tangent for `logpdf` itself, which `f` does not expose.
        _, d̄, x̄ = pb_logpdf(ȳ)
        return NoTangent(), d̄, x̄
    end
    return y, pb
end

# Zygote uses the custom rrule above for `f(d, x)`; the result holds one gradient per argument (d, x).
gradient(f, d, x)

###############################################
using Distributions, DistributionsAD, LinearAlgebra, Zygote, ChainRulesCore, FillArrays, Test, BenchmarkTools, StaticArrays
# Standard-normal log-density by hand. The operation order is kept identical to the
# benchmarked one-liner so the timings stay comparable.
function f_manual(x)
    constant_term = length(x) * log(2π)
    return -(constant_term + sum(abs2.(x))) / 2
end
# Custom reverse rule: the derivative of -(n log(2π) + ‖x‖²)/2 with respect to x is -x.
function ChainRulesCore.rrule(::typeof(f_manual), x)
    value = f_manual(x)
    function f_manual_pullback(ȳ)
        return NoTangent(), -ȳ * x
    end
    return value, f_manual_pullback
end

# Library-based versions of the N(0, I) log-density to compare against f_manual:
#   f   - single-argument MvNormal with a lazy `FillArrays.Eye` covariance
#   f_2 - lazy `FillArrays.Zeros` mean with `I`
#   f_3 - dense `zero(x)` mean with `I`
#   f_4 - dense zero mean with a dense vector of ones as the second argument
#         (NOTE(review): older Distributions API; a vector reads as standard deviations; confirm)
function f(x)
    n = length(x)
    return logpdf(MvNormal(FillArrays.Eye{Int}(n)), x)
end

function f_2(x)
    n = length(x)
    return logpdf(MvNormal(FillArrays.Zeros{Int}(n), I), x)
end

function f_3(x)
    return logpdf(MvNormal(zero(x), I), x)
end

function f_4(x)
    n = length(x)
    return logpdf(MvNormal(zero(x), ones(eltype(x), n)), x)
end

# Benchmark the Zygote gradient of each variant on a 100-element dense vector.
x = rand(100)
f(x)  # evaluate once before benchmarking (presumably a sanity check / warm-up)
# `$x` interpolates the global into the benchmark so it is not timed as a global lookup.
@btime gradient(f, $x);
@btime gradient(f_2, $x);
@btime gradient(f_3, $x);
@btime gradient(f_4, $x);
@btime gradient(f_manual, $x);  # goes through the hand-written rrule above

# The same five benchmarks on a 2-element StaticArrays vector.
x_static = @SArray [0.1, 0.2]
@btime gradient(f, $x_static);
@btime gradient(f_2, $x_static);
@btime gradient(f_3, $x_static);
@btime gradient(f_4, $x_static);
@btime gradient(f_manual, $x_static);

######### manual accumulation of loglikelihood

# Sum of standard-normal log-densities of the rows of `v`. Each row is one observation
# whose dimension is size(v, 2). Distributions' `logpdf` is called row by row, and Zygote
# differentiates through the loop.
function g(v)
    dim = size(v, 2)
    dist = MvNormal(FillArrays.Eye{Int}(dim))
    total = 0.0
    for row_idx in axes(v, 1)
        total += logpdf(dist, view(v, row_idx, :))
    end
    return total
end

g(rand(100, 10))

# Same row-wise accumulation as `g`, but each row's log-density comes from `f_manual`.
# Zygote therefore differentiates the loop and uses f_manual's hand-written rrule per row.
# (Removed the unused local `N = size(v, 2)`.)
function g_manual_f(v)
    loglik = 0.0
    for i in axes(v, 1)
        loglik += f_manual(@view v[i, :])
    end
    return loglik
end
g_manual_f(rand(100, 10))

# Same computation as `g_manual_f`. The separate name lets the whole accumulation carry its
# own hand-written rrule (defined below), so Zygote never traces the loop.
# (Removed the unused local `N = size(v, 2)`.)
function g_manual(v)
    loglik = 0.0
    for i in axes(v, 1)
        loglik += f_manual(@view v[i, :])
    end
    return loglik
end
g_manual(rand(100, 10))

# Hand-written reverse rule for the whole accumulation. Each row vᵢ contributes
# -(N log(2π) + ‖vᵢ‖²)/2, so ∂g_manual/∂v = -v and the pullback is simply -Δ * v.
# An equivalent explicit loop filling v̄ slice by slice was also tried, to rule out
# vectorization as the source of the speedup; it made little difference.
function ChainRulesCore.rrule(::typeof(g_manual), v)
    loglik = g_manual(v)
    g_manual_pullback(Δ) = (NoTangent(), -Δ * v)
    return loglik, g_manual_pullback
end

# Same objective as `g`. It is kept under a separate name so it can carry its own rrule
# that calls back into the AD backend row by row (defined below).
function g_callbacks(v)
    dist = MvNormal(FillArrays.Eye{Int}(size(v, 2)))
    acc = 0.0
    for r in axes(v, 1)
        acc += logpdf(dist, view(v, r, :))
    end
    return acc
end

g_callbacks(rand(100, 10))

# Reverse rule for g_callbacks that computes each row's gradient by calling back into the
# AD backend for `logpdf`. Every row enters the sum additively, so each row's cotangent is
# f̄ itself. The tangent with respect to `dist` is discarded because `dist` does not depend on v.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(g_callbacks), v)
    y = g_callbacks(v)
    dist = MvNormal(FillArrays.Eye{Int}(size(v, 2)))

    function g_callbacks_pullback(f̄)
        v̄ = similar(v)
        for r in axes(v, 1)
            _, row_pullback = rrule_via_ad(config, logpdf, dist, view(v, r, :))
            _, _, row_bar = row_pullback(f̄)
            v̄[r, :] .= row_bar
        end
        return NoTangent(), v̄
    end
    return y, g_callbacks_pullback
end

# (Removed: a verbatim second copy of the `g_callbacks` definition, its
# `g_callbacks(rand(100, 10))` smoke call, and its `rrule`. The copies were identical to the
# definitions just above, so re-evaluating them only redefined the same methods.)

# 100 observations (rows), each of dimension 50.
v = rand(100, 50)
# Zygote through Distributions' logpdf, row by row.
println("using the likelihood directly\n")
@btime gradient(g, $v)
# Zygote through the loop, using f_manual's hand-written rrule per row.
println("Using custom likelihood\n")
@btime gradient(g_manual_f, $v)
# A single hand-written rrule for the whole accumulation.
println("Using custom accumulation as well\n")
@btime gradient(g_manual, $v)
# A hand-written accumulation whose pullback calls back into AD for each row.
println("Using custom accumulation with callbacks\n")
@btime gradient(g_callbacks, $v)

# Observed: the custom accumulation (g_manual) is roughly 200x faster than the direct version (g).
# The callback version runs at roughly the same speed as calling the likelihood directly.
jlperla commented 2 years ago

Closing, since the question of whether a specialized unit normal is useful in the iteration has been answered.