bankofcanada / ModelBaseEcon.jl

BSD 3-Clause "New" or "Revised" License

Avoid specializing all of ForwardDiff on every equation #37

Closed. KristofferC closed this 1 year ago.

KristofferC commented 1 year ago

ForwardDiff quite aggressively specializes most of its functions on the concrete type of the input function. This gives a slight runtime performance improvement, but it also means that a significant chunk of code has to be compiled for every call to ForwardDiff with a new function.

Previously, for every equation in a model, we called ForwardDiff.gradient with the Julia function corresponding to that equation. This compiled a separate copy of the ForwardDiff functions for each of these Julia functions.
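A minimal sketch of that previous pattern, assuming two toy residual functions resid_1 and resid_2 (hypothetical stand-ins for model equations, not ModelBaseEcon code). Each Julia function has its own concrete type, and ForwardDiff.gradient specializes on that type, so each call compiles its own copies of the internal gradient methods.

```julia
using ForwardDiff

# Toy residual functions (hypothetical stand-ins for model equations).
resid_1(x) = x[1]^2 + x[2]
resid_2(x) = sin(x[1]) * x[2]

# Every Julia function has its own concrete (singleton) type, so
# typeof(resid_1) and typeof(resid_2) are different types.

x0 = [1.0, 2.0]

# ForwardDiff specializes on the function type, so each call compiles its
# own copies of the internal gradient methods. With the default config, the
# tag type (and hence the GradientConfig type) also depends on typeof(f).
g1 = ForwardDiff.gradient(resid_1, x0)
g2 = ForwardDiff.gradient(resid_2, x0)
```

With many equations, this per-function compilation adds up, which is the latency the PR targets.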

Looking at the specializations generated by a model, we see:

GC = ForwardDiff.GradientConfig{FRBUS_VAR.MyTag, Float64, 4, Vector{ForwardDiff.Dual{FRBUS_VAR.MyTag, Float64, 4}}}
MethodInstance for ForwardDiff.vector_mode_dual_eval!(::FRBUS_VAR.EquationEvaluator{:resid_515}, ::GC, ::Vector{Float64})
MethodInstance for ForwardDiff.vector_mode_gradient!(::DiffResults.MutableDiffResult{1, Float64, Tuple{Vector{Float64}}}, ::FRBUS_VAR.EquationEvaluator{:resid_515}, ::Vector{Float64}, ::GC)
MethodInstance for ForwardDiff.vector_mode_dual_eval!(::FRBUS_VAR.EquationEvaluator{:resid_516}, ::GC, ::Vector{Float64})
MethodInstance for ForwardDiff.vector_mode_gradient!(::DiffResults.MutableDiffResult{1, Float64, Tuple{Vector{Float64}}}, ::FRBUS_VAR.EquationEvaluator{:resid_516}, ::Vector{Float64}, ::GC)

These methods are identical except for the equation type (EquationEvaluator{:resid_515} vs. EquationEvaluator{:resid_516}) they were compiled for.

In this PR, we instead "hide" the concrete function for every equation behind a common wrapper function. This means that only one specialization of the ForwardDiff functions gets compiled.
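A minimal sketch of the wrapper idea, assuming a hypothetical EquationWrapper type (the name and layout are illustrative, not the PR's actual code). Its single field is declared with the abstract type Function, so every wrapped equation has the same concrete type and ForwardDiff's internals are specialized only once; the call to the wrapped function is then resolved by dynamic dispatch.

```julia
using ForwardDiff

# Hypothetical wrapper (not the PR's actual code). Because the field is
# abstractly typed, EquationWrapper(resid_1) and EquationWrapper(resid_2)
# have the same concrete type.
struct EquationWrapper
    f::Function
end

# Make the wrapper callable; w.f(x) is resolved by dynamic dispatch.
(w::EquationWrapper)(x) = w.f(x)

# Toy residual functions (hypothetical stand-ins for model equations).
resid_1(x) = x[1]^2 + x[2]
resid_2(x) = sin(x[1]) * x[2]

x0 = [1.0, 2.0]

# Both calls reuse the same ForwardDiff specializations, compiled once for
# EquationWrapper rather than once per residual function.
grads = [ForwardDiff.gradient(EquationWrapper(f), x0) for f in (resid_1, resid_2)]
```

The trade-off in this sketch is one dynamic call per evaluation in exchange for compiling ForwardDiff's internals only once.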

Using the following benchmark script:

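# Make the local models directory importable
unique!(push!(LOAD_PATH, realpath("./models")))
using ModelBaseEcon
using Random # See https://github.com/JuliaLang/julia/pull/48810

# Time to load the model package
@time using FRBUS_VAR

m = FRBUS_VAR.model
# One row per time offset (lags, current period, leads), one column per variable
nrows = 1 + m.maxlag + m.maxlead
ncols = length(m.allvars)
pt = zeros(nrows, ncols);
# First call: the timing includes compilation (latency)
@time @eval eval_RJ(pt, m);

using BenchmarkTools
# Steady-state runtime
@btime eval_RJ(pt, m);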

This PR has the following changes:

So there seems to be a runtime performance regression of about 10% in the eval_RJ call, but the latency is drastically reduced.