Unable to reproduce the AD optimization in the docs due to type error #248

Closed: Jarrod-Angove closed this issue 2 months ago

Jarrod-Angove commented 2 months ago

I'm quite new to this package, so please forgive me if this is a mistake on my end, but I've been having some difficulties getting Optim AD methods working with Stheno. Even after directly copying the BFGS example from the docs, I end up with the following error:

ERROR: MethodError: no method matching AbstractGPs.FiniteGP(::Stheno.DerivedGP{Tuple{typeof(+), Stheno.DerivedGP{…}, Stheno.DerivedGP{…}}}, ::Float64)

Closest candidates are:
  AbstractGPs.FiniteGP(::AbstractGPs.AbstractGP, ::AbstractVector, ::AbstractVector{<:Real})
   @ AbstractGPs ~/.julia/packages/AbstractGPs/XejGR/src/finite_gp_projection.jl:13
  AbstractGPs.FiniteGP(::AbstractGPs.AbstractGP, ::AbstractVector, ::Real)
   @ AbstractGPs ~/.julia/packages/AbstractGPs/XejGR/src/finite_gp_projection.jl:19
  AbstractGPs.FiniteGP(::Tf, ::Tx, ::TΣ) where {Tf<:AbstractGPs.AbstractGP, Tx<:(AbstractVector), TΣ}
   @ AbstractGPs ~/.julia/packages/AbstractGPs/XejGR/src/finite_gp_projection.jl:8
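For comparison, every candidate above takes an input vector (plus an observation-noise term), while the failing call received a single Float64 in that position. The following is a minimal sketch of building a FiniteGP the way AbstractGPs expects. It assumes AbstractGPs is installed, and the kernel, inputs and noise value are illustrative choices rather than values taken from the issue:

```julia
using AbstractGPs, Random

# A GP prior with a squared-exponential kernel (re-exported from KernelFunctions).
f = GP(SEKernel())

# Inputs must be a vector of points, not a single Float64.
x = collect(range(-5.0, 5.0; length=20))

# Calling an AbstractGP on (inputs, noise variance) builds a FiniteGP,
# i.e. the FiniteGP(::AbstractGP, ::AbstractVector, ::Real) candidate.
fx = f(x, 0.02)

# Draw synthetic observations and evaluate the log marginal likelihood.
y = rand(MersenneTwister(1), fx)
lml = logpdf(fx, y)
```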

This occurs if I copy and paste the tutorial from the docs directly into the REPL, or if I paste everything as a single function:

using Stheno, Optim, ParameterHandling, Zygote

function copy_stheno_test()
    l1 = 0.4
    s1 = 0.2
    l2 = 5.0
    s2 = 1.0

    # Generate synthetic data from the sum of two GPs.
    g = @gppp let
        f1 = s1 * stretch(GP(Matern52Kernel()), 1 / l1)
        f2 = s2 * stretch(GP(SEKernel()), 1 / l2)
        f3 = f1 + f2
    end

    x = GPPPInput(:f3, collect(range(-5.0, 5.0; length=100)));
    σ²_n = 0.02;
    fx = g(x, σ²_n);
    y = rand(fx);

    θ = (
        # Short length-scale and small variance.
        l1 = positive(0.4),
        s1 = positive(0.2),

        # Long length-scale and larger variance.
        l2 = positive(5.0),
        s2 = positive(1.0),

        # Observation noise variance -- we'll be learning this as well. Constrained to be
        # at least 1e-3.
        s_noise = positive(0.1, exp, 1e-3),
    )
    θ_flat_init, unflatten = flatten(θ);
    unpack = ParameterHandling.value ∘ unflatten;

    function build_model(θ::NamedTuple)
        return @gppp let
            f1 = θ.s1 * stretch(GP(SEKernel()), 1 / θ.l1)
            f2 = θ.s2 * stretch(GP(SEKernel()), 1 / θ.l2)
            f3 = f1 + f2
        end
    end

    # Negative log marginal likelihood of the observed data.
    function nlml(θ::NamedTuple)
        f = build_model(θ)
        return -logpdf(f(x, θ.s_noise + 1e-6), y)
    end

    # The gradient closure returns a new vector, hence inplace=false.
    results = Optim.optimize(
        nlml ∘ unpack,
        θ->gradient(nlml ∘ unpack, θ)[1],
        θ_flat_init + randn(length(θ_flat_init)),
        BFGS();
        inplace=false,
    )
    return results
end

This issue does not occur if I don't provide a gradient to Optim, since I believe it then falls back to a finite difference method to calculate the gradient. Similarly, using NelderMead() allows it to run without issue.
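The fallback behaviour described above can be checked on a toy problem. This sketch uses the Rosenbrock function as a stand-in for nlml ∘ unpack, which is an assumption made for illustration only, so it does not reproduce the Stheno error. It shows three ways Optim can run without a user-supplied gradient: finite-difference gradients, the gradient-free Nelder-Mead method, and Optim's own ForwardDiff-based `autodiff = :forward` option.

```julia
using Optim

# Toy objective standing in for nlml ∘ unpack (hypothetical, for illustration).
rosenbrock(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2
x0 = [-1.2, 1.0]

# No gradient supplied: BFGS approximates it with finite differences.
res_fd = optimize(rosenbrock, x0, BFGS())

# Gradient-free alternative.
res_nm = optimize(rosenbrock, x0, NelderMead())

# Ask Optim to compute gradients with ForwardDiff instead.
res_fwd = optimize(rosenbrock, x0, BFGS(); autodiff = :forward)

println(Optim.minimizer(res_fd), " ", Optim.minimizer(res_nm), " ", Optim.minimizer(res_fwd))
```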

Here is the stack trace:

  [1] macro expansion
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(::Zygote.Context{false}, ::Type{AbstractGPs.FiniteGP}, ::Stheno.DerivedGP{Tuple{…}}, ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:87
  [3] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
  [4] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
  [5] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
  [6] AbstractGP
    @ ~/.julia/packages/AbstractGPs/XejGR/src/finite_gp_projection.jl:32 [inlined]
  [7] _pullback(ctx::Zygote.Context{false}, f::Stheno.DerivedGP{Tuple{…}}, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [8] #rrule_via_ad#54
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:260 [inlined]
  [9] rrule_via_ad
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:248 [inlined]
 [10] #723
    @ ./none:0 [inlined]
 [11] iterate
    @ ./generator.jl:47 [inlined]
 [12] collect(itr::Base.Generator{Vector{Float64}, ChainRules.var"#723#728"{Zygote.ZygoteRuleConfig{…}, Stheno.DerivedGP{…}}})
    @ Base ./array.jl:834
 [13] rrule(config::Zygote.ZygoteRuleConfig{…}, ::typeof(sum), f::Stheno.DerivedGP{…}, xs::Vector{…}; dims::Function)
    @ ChainRules ~/.julia/packages/ChainRules/hShjJ/src/rulesets/Base/mapreduce.jl:102
 [14] rrule
    @ ~/.julia/packages/ChainRules/hShjJ/src/rulesets/Base/mapreduce.jl:76 [inlined]
 [15] rrule(config::Zygote.ZygoteRuleConfig{…}, ::typeof(mean), f::Stheno.DerivedGP{…}, x::Vector{…}; dims::Function)
    @ ChainRules ~/.julia/packages/ChainRules/hShjJ/src/rulesets/Statistics/statistics.jl:28
 [16] rrule
    @ ~/.julia/packages/ChainRules/hShjJ/src/rulesets/Statistics/statistics.jl:21 [inlined]
 [17] chain_rrule
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:223 [inlined]
 [18] macro expansion
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0 [inlined]
 [19] _pullback(::Zygote.Context{false}, ::typeof(mean), ::Stheno.DerivedGP{Tuple{…}}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:87
 [20] mean_and_cov
    @ ~/.julia/packages/AbstractGPs/XejGR/src/abstract_gp.jl:48 [inlined]
 [21] _pullback(::Zygote.Context{false}, ::typeof(mean_and_cov), ::Stheno.DerivedGP{Tuple{…}}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [22] mean_and_cov
    @ ~/.julia/packages/Stheno/ZSwgx/src/gaussian_process_probabilistic_programme.jl:74 [inlined]
 [23] _pullback(::Zygote.Context{…}, ::typeof(mean_and_cov), ::Stheno.GaussianProcessProbabilisticProgramme{…}, ::GPPPInput{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [24] mean_and_cov
    @ ~/.julia/packages/AbstractGPs/XejGR/src/finite_gp_projection.jl:134 [inlined]
 [25] _pullback(ctx::Zygote.Context{…}, f::typeof(mean_and_cov), args::AbstractGPs.FiniteGP{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [26] logpdf
    @ ~/.julia/packages/AbstractGPs/XejGR/src/finite_gp_projection.jl:307 [inlined]
 [27] _pullback(::Zygote.Context{…}, ::typeof(logpdf), ::AbstractGPs.FiniteGP{…}, ::Vector{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [28] nlml
    @ ./REPL[31]:3 [inlined]
 [29] _pullback(ctx::Zygote.Context{…}, f::typeof(nlml), args::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [30] call_composed
    @ ./operators.jl:1044 [inlined]
 [31] #_#103
    @ ./operators.jl:1041 [inlined]
 [32] _pullback(::Zygote.Context{…}, ::Base.var"##_#103", ::@Kwargs{}, ::ComposedFunction{…}, ::Vector{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [33] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [34] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [35] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [36] ComposedFunction
    @ ./operators.jl:1041 [inlined]
 [37] _pullback(ctx::Zygote.Context{false}, f::ComposedFunction{typeof(nlml), ComposedFunction{…}}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [38] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:90
 [39] pullback
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:88 [inlined]
 [40] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:147
 [41] (::var"#7#8")(θ::Vector{Float64})
    @ Main ./REPL[36]:3
 [42] (::NLSolversBase.var"#gg!#2"{var"#7#8"})(G::Vector{Float64}, x::Vector{Float64})
    @ NLSolversBase ~/.julia/packages/NLSolversBase/kavn7/src/objective_types/inplace_factory.jl:21
 [43] (::NLSolversBase.var"#fg!#8"{…})(gx::Vector{…}, x::Vector{…})
    @ NLSolversBase ~/.julia/packages/NLSolversBase/kavn7/src/objective_types/abstract.jl:13
 [44] value_gradient!!(obj::OnceDifferentiable{Float64, Vector{Float64}, Vector{Float64}}, x::Vector{Float64})
    @ NLSolversBase ~/.julia/packages/NLSolversBase/kavn7/src/interface.jl:82
 [45] initial_state(method::BFGS{…}, options::Optim.Options{…}, d::OnceDifferentiable{…}, initial_x::Vector{…})
    @ Optim ~/.julia/packages/Optim/ZhuZN/src/multivariate/solvers/first_order/bfgs.jl:94
 [46] optimize
    @ ~/.julia/packages/Optim/ZhuZN/src/multivariate/optimize/optimize.jl:36 [inlined]
 [47] optimize(f::Function, g::Function, initial_x::Vector{…}, method::BFGS{…}, options::Optim.Options{…}; inplace::Bool, autodiff::Symbol)
    @ Optim ~/.julia/packages/Optim/ZhuZN/src/multivariate/optimize/interface.jl:156
 [48] top-level scope
    @ REPL[36]:1

willtebbutt commented 2 months ago

Hmmm I'm really not sure what Zygote is doing here -- it has an annoying habit of breaking when you don't expect it to...

I would suggest trying with Tapir.jl. Could you please see if the following works for you?

using Stheno, Tapir, Random, Optim, ParameterHandling

l1 = 0.4
s1 = 0.2
l2 = 5.0
s2 = 1.0

g = @gppp let
    f1 = s1 * stretch(GP(Matern52Kernel()), 1 / l1)
    f2 = s2 * stretch(GP(SEKernel()), 1 / l2)
    f3 = f1 + f2
end

x = GPPPInput(:f3, collect(range(-5.0, 5.0; length=100)));
σ²_n = 0.02;
fx = g(x, σ²_n);
y = rand(fx);

θ = (
    # Short length-scale and small variance.
    l1 = positive(0.4),
    s1 = positive(0.2),

    # Long length-scale and larger variance.
    l2 = positive(5.0),
    s2 = positive(1.0),

    # Observation noise variance -- we'll be learning this as well. Constrained to be
    # at least 1e-3.
    s_noise = positive(0.1, exp, 1e-3),
)
θ_flat_init, unflatten = flatten(θ);

function build_model(θ::NamedTuple)
    return @gppp let
        f1 = θ.s1 * stretch(GP(SEKernel()), 1 / θ.l1)
        f2 = θ.s2 * stretch(GP(SEKernel()), 1 / θ.l2)
        f3 = f1 + f2
    end
end

function nlml(θ::NamedTuple)
    f = build_model(θ)
    return -logpdf(f(x, θ.s_noise + 1e-6), y)
end

# Define objective function, check it runs, and compute a gradient to check that works.
obj(x) = nlml(ParameterHandling.value(unflatten(x)))

rule = Tapir.build_rrule(obj, θ_flat_init)
Tapir.value_and_gradient!!(rule, obj, θ_flat_init)

# Run optimisation.
results = Optim.optimize(
    obj,
    θ->Tapir.value_and_gradient!!(rule, obj, θ)[2][2],
    θ_flat_init + randn(length(θ_flat_init)),
    BFGS();
    inplace=false,
)

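A note on the gradient argument. By default Optim expects an in-place gradient `g!(G, x)` that writes into `G`. Closures like the one passed above return a new vector, so they need `inplace=false`. (As far as I understand Tapir's API at this version, `value_and_gradient!!(rule, f, x)` returns `(value, (tangent_of_f, tangent_of_x))`, which is why `[2][2]` selects the gradient with respect to θ.) The sketch below shows both gradient forms on the Rosenbrock function, used as a hypothetical stand-in for `obj`, with the gradient written out by hand:

```julia
using Optim

rosenbrock(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2

# Out-of-place gradient: returns a fresh vector, like θ -> gradient(...)[1].
function rosenbrock_grad(x)
    return [
        -2.0 * (1.0 - x[1]) - 400.0 * x[1] * (x[2] - x[1]^2),
        200.0 * (x[2] - x[1]^2),
    ]
end

# In-place gradient: writes into G, the form Optim assumes by default.
function rosenbrock_grad!(G, x)
    G .= rosenbrock_grad(x)
    return G
end

x0 = [-1.2, 1.0]
res_oop = optimize(rosenbrock, rosenbrock_grad, x0, BFGS(); inplace = false)
res_ip = optimize(rosenbrock, rosenbrock_grad!, x0, BFGS())
```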
Jarrod-Angove commented 2 months ago

Thanks for the quick response @willtebbutt !

I've run the suggested code but the Tapir.build_rrule function is failing:

julia> rule = Tapir.build_rrule(obj, θ_flat_init)
ERROR: MethodError: no method matching tangent_field_type(::Type{ParameterHandling.var"#unflatten_to_NamedTuple#15"{…}}, ::Int64)
The applicable method may be too new: running in world age 31913, while current world is 34727.

Closest candidates are:
  tangent_field_type(::Type{P}, ::Int64) where P (method too new to be called from this world context.)
   @ Tapir ~/.julia/packages/Tapir/7eB9t/src/tangents.jl:282

Edit: I thought it might be an issue with the ParameterHandling package, so I changed θ to a plain vector and tried Zygote, Tapir, and ForwardDiff again:

θ = [

Zygote still fails with the same MethodError. Tapir fails similarly, but this time build_rrule succeeds and value_and_gradient!! fails:

ERROR: MethodError: no method matching tangent_field_type(::Type{Stheno.GaussianProcessProbabilisticProgramme{@NamedTuple{…}}}, ::Int64)
The applicable method may be too new: running in world age 31913, while current world is 34966.

Interestingly, without ParameterHandling, ForwardDiff is able to compute the gradient:

julia> ForwardDiff.gradient(obj, θ)
5-element Vector{Float64}:
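Because ForwardDiff works when θ is a plain vector, one option is to handle the positivity constraints by hand. The following is a minimal sketch of that approach. It assumes a hypothetical `obj` that takes log-parameters, and it uses AbstractGPs directly instead of Stheno to keep it short. Here `exp` does the job that ParameterHandling's `positive` would otherwise do:

```julia
using AbstractGPs, ForwardDiff, Random

# Synthetic data (illustrative values, not from the issue).
x = collect(range(-5.0, 5.0; length=50))
y = rand(MersenneTwister(1), GP(SEKernel())(x, 0.02))

# Hypothetical objective over a plain vector of log-parameters:
# θ = [log variance, log length-scale, log noise variance].
function obj(θ::AbstractVector)
    σ², ℓ, σ²_n = exp.(θ)  # exp keeps all three parameters positive
    f = GP(σ² * with_lengthscale(SEKernel(), ℓ))
    return -logpdf(f(x, σ²_n), y)
end

θ0 = [0.0, 0.0, log(0.1)]
grad = ForwardDiff.gradient(obj, θ0)
```

The same gradient could then be handed to Optim as `θ -> ForwardDiff.gradient(obj, θ)` together with `inplace=false`.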

willtebbutt commented 2 months ago

Hmm, could you show me the output of Pkg.status()? It might be that you need a clean install. Also, what version of Julia are you on?
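One way to get a clean install without touching an existing project is to use a temporary environment. This sketch assumes network access, and the package list is an assumption based on the snippets in this thread:

```julia
using Pkg

# Work in a throwaway environment so an existing Manifest cannot interfere.
Pkg.activate(; temp = true)
Pkg.add(["Stheno", "Optim", "ParameterHandling", "Zygote"])

# Show the exact versions resolved in the clean environment.
Pkg.status()
```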

Jarrod-Angove commented 2 months ago

Yep... It looks like there is something wrong with my project. I ran the code again in a clean install and it appears to work fine. I have no idea what could be causing this, but I'll try rebuilding everything piece by piece until I find the conflict.

For the sake of completeness: my Julia version is 1.10.4, and here is my Pkg.status():

Project DilPredict v0.1.0
Status `~/Documents/grad_school/thesis/DilPredict/Project.toml`
⌃ [99985d1d] AbstractGPs v0.5.9
  [8bb1440f] DelimitedFiles v1.9.1
  [39dd38d3] Dierckx v0.5.3
  [f6369f11] ForwardDiff v0.10.36
  [033835bb] JLD2 v0.4.48
  [ec8451be] KernelFunctions v0.10.63
  [429524aa] Optim v1.9.4
  [2412ca09] ParameterHandling v0.5.0
  [91a5bcdd] Plots v1.40.4
  [c46f51b8] ProfileView v1.7.2
  [8188c328] Stheno v0.8.2
  [07d77754] Tapir v0.2.20
  [e88e6eb3] Zygote v0.6.70
  [37e2e46d] LinearAlgebra
  [de0858da] Printf
  [9a3f8284] Random
  [10745b16] Statistics v1.10.0
  [fa267f1f] TOML v1.0.3

Sorry for the trouble!

Edit: For some reason Pkg had been fetching a super outdated version of AbstractGPs (v0.5.9, marked with ⌃ in the status above)... All I needed to do was update :')
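For reference, ⌃ in Pkg.status() marks a package that has a newer version available which can be installed, whereas ⌅ marks one held back by compat constraints. The sketch below finds and updates such packages. It assumes the affected project is the active environment. `Pkg.status(outdated = true)` requires Julia 1.8 or later, which 1.10.4 satisfies:

```julia
using Pkg

# List direct dependencies that have newer versions available.
Pkg.status(; outdated = true)

# Upgrade AbstractGPs as far as the project's compat bounds allow.
Pkg.update("AbstractGPs")

# Confirm the version now in use.
Pkg.status("AbstractGPs")
```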

willtebbutt commented 2 months ago

No trouble at all -- happy to have been able to help!