JuliaGaussianProcesses / AbstractGPs.jl

Abstract types and methods for Gaussian Processes.
https://juliagaussianprocesses.github.io/AbstractGPs.jl/dev

Zygote errors with parameterized mean functions and multidimensional input #344

Closed simsurace closed 1 month ago

simsurace commented 1 year ago

I get errors when differentiating logpdf or other scalar functions with a parameterized mean function and multidimensional input:

using AbstractGPs
using Zygote

pars = [1., 0.]

function build_model(pars)
    a, b = pars
    return GP(x -> a * first(x) + b, SEKernel())
end

rand_data(n::Integer) = rand(n), randn(n)
rand_data_2d(n::Integer) = RowVecs(rand(n, 2)), randn(n)

function test_logpdf(pars)
    f = build_model(pars)
    x, y = rand_data(10)
    return logpdf(f(x, 1e-3), y)
end

test_logpdf(pars)
Zygote.gradient(test_logpdf, pars) # works

function test_logpdf2(pars)
    f = build_model(pars)
    x, y = rand_data_2d(10)
    return logpdf(f(x, 1e-3), y)
end

test_logpdf2(pars)
Zygote.gradient(test_logpdf2, pars)
# ERROR: MethodError: no method matching +(::NamedTuple{(:X,), Tuple{LinearAlgebra.Transpose{Float64, Matrix{Float64}}}}, ::Vector{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}})

function test_mean(pars)
    f = build_model(pars)
    x, _ = rand_data_2d(10)
    return sum(mean(f(x, 1e-3)))
end

test_mean(pars)
Zygote.gradient(test_mean, pars) # ERROR: Pullback on AbstractVector{<:AbstractVector}.

function test_post_mean(pars)
    f = build_model(pars)
    x, y = rand_data_2d(10)
    fp = posterior(f(x, 1e-3), y)
    return sum(mean(fp(x, 1e-3)))
end

test_post_mean(pars)
Zygote.gradient(test_post_mean, pars) 
# ERROR: MethodError: no method matching +(::NamedTuple{(:X,), Tuple{LinearAlgebra.Transpose{Float64, Matrix{Float64}}}}, ::Vector{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}})

Is there a simple fix? The error for test_mean suggests overloading a kernelmatrix method, but that does not seem to be the issue, since the problem here concerns the mean rather than the kernel. Why does the existing rrule for RowVecs not suffice?

willtebbutt commented 1 year ago

Hmmm, this issue has come up repeatedly recently. It's also a problem in this Stheno.jl issue, I suspect for the same reasons.

To be honest, the simplest solution is going to be to implement AbstractGPs._map_meanfunction for your custom mean function and a RowVecs / ColVecs input, so that you can be sure that it's differentiable. So, something like

function AbstractGPs._map_meanfunction(f::CustomMean{typeof(your_mean_function)}, x::RowVecs)
    # return a Vector containing the mean evaluated at each input in x
end

etc. @simsurace could you let me know if this solves the problem?

I can't see this issue surrounding the mean function getting fixed any time soon (because it's AD-related), so I'm wondering whether we should change our approach to documenting it, e.g. making it clear that you might well need to implement _map_meanfunction if you're using a custom mean function.

simsurace commented 1 year ago

Thanks for the suggestion. Maybe I misunderstood, but this did not solve the issue:

struct LinearMean{T}
    a::T
    b::T
end
(f::LinearMean)(x) = f.a * first(x) + f.b

using AbstractGPs: CustomMean
function AbstractGPs._map_meanfunction(f::CustomMean{<:LinearMean}, x::RowVecs)
    @info "Calling specialized function"
    return [f.f.a * first(xi) + f.f.b for xi in x]
end

function build_model(pars)
    a, b = pars
    return GP(CustomMean(LinearMean(a, b)), SEKernel())
end

willtebbutt commented 1 year ago

Ah, sorry, I meant something more like

function AbstractGPs._map_meanfunction(f::CustomMean{<:LinearMean}, x::RowVecs)
    @info "Calling specialized function"
    # RowVecs stores one input per row, so first(xi) corresponds to column 1 of x.X
    return f.f.a * x.X[:, 1] .+ f.f.b
end

so that you're interacting with the underlying matrix.
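A quick check of which slice of the underlying matrix corresponds to first(xi), as a sketch that assumes RowVecs and ColVecs are re-exported by AbstractGPs (the snippets above already rely on this for RowVecs). For RowVecs each row is one input, so collecting first(xi) gives the first column; for ColVecs each column is one input, so it gives the first row.

```julia
using AbstractGPs  # RowVecs / ColVecs come from KernelFunctions via AbstractGPs

X = rand(5, 2)

# RowVecs: each row of X is one input, so first(xi) == X[i, 1]
xr = RowVecs(X)
@assert [first(xi) for xi in xr] == X[:, 1]

# ColVecs: each column of Xc is one input, so first(xi) == Xc[1, i]
Xc = permutedims(X)  # 2×5 matrix
xc = ColVecs(Xc)
@assert [first(xi) for xi in xc] == Xc[1, :]
```

In both cases the vectorised slice of x.X gives the same values as iterating over the wrapper element by element.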

simsurace commented 1 year ago

Oh, I get it. Thanks, this seems to do the trick. Actually, wrapping LinearMean in CustomMean seems overly complicated. I could make LinearMean <: MeanFunction and then define _map_meanfunction accordingly, right?

simsurace commented 1 year ago

P.S. CustomMean currently does not seem to be exported or documented, for that matter.

willtebbutt commented 1 year ago

That should indeed work!
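A sketch of the LinearMean <: MeanFunction variant, written against the current API, in which the overloadable function is mean_vector (later in the thread it is noted that _map_meanfunction was removed). It assumes the abstract type is reachable as AbstractGPs.MeanFunction; the set of methods (RowVecs, ColVecs, plain real vectors) and the name test_logpdf_2d are illustrative choices, not a list of what AbstractGPs requires. Run it in a fresh session, since it redefines LinearMean.

```julia
using AbstractGPs, Zygote

# Linear mean in the first input dimension, defined directly as a
# MeanFunction subtype instead of being wrapped in CustomMean.
struct LinearMean{T<:Real} <: AbstractGPs.MeanFunction
    a::T
    b::T
end

# RowVecs: one input per row, so the first input dimension is column 1 of x.X.
AbstractGPs.mean_vector(m::LinearMean, x::RowVecs) = m.a .* x.X[:, 1] .+ m.b

# ColVecs: one input per column, so the first input dimension is row 1 of x.X.
AbstractGPs.mean_vector(m::LinearMean, x::ColVecs) = m.a .* x.X[1, :] .+ m.b

# Plain vector of scalar (1D) inputs.
AbstractGPs.mean_vector(m::LinearMean, x::AbstractVector{<:Real}) = m.a .* x .+ m.b

# Hypothetical test function: logpdf with 2D RowVecs input.
function test_logpdf_2d(pars)
    a, b = pars
    f = GP(LinearMean(a, b), SEKernel())
    x, y = RowVecs(rand(10, 2)), randn(10)
    return logpdf(f(x, 1e-3), y)
end

# Gradient with respect to (a, b).
Zygote.gradient(test_logpdf_2d, [1.0, 0.0])
```

Because every method works on x.X directly, Zygote never has to differentiate through iteration over the RowVecs / ColVecs wrapper, which is what the earlier suggestion in this thread aims for.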

simsurace commented 1 year ago

I ended up with a general struct FunctionOfTime{Tf} <: MeanFunction with overloads that map its field over slices. This works. Thanks for the tips!
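The FunctionOfTime code itself is not shown, so the following is a hypothetical reconstruction: a wrapper around a function of a single input vector, with mean_vector overloads (the current API name) that map the wrapped function over eachrow / eachcol of the underlying matrix instead of over the RowVecs / ColVecs wrapper. Its behaviour under Zygote is an assumption here; if mapping over slices causes trouble, a vectorised expression on x.X, as suggested earlier in the thread, is the fallback.

```julia
using AbstractGPs

# Hypothetical reconstruction: wraps an arbitrary function `f` that maps a
# single input (a vector, or a scalar for 1D data) to a real number.
struct FunctionOfTime{Tf} <: AbstractGPs.MeanFunction
    f::Tf
end

# Map the wrapped function over slices of the underlying matrix.
AbstractGPs.mean_vector(m::FunctionOfTime, x::RowVecs) = map(m.f, eachrow(x.X))
AbstractGPs.mean_vector(m::FunctionOfTime, x::ColVecs) = map(m.f, eachcol(x.X))

# 1D inputs: map directly over the scalars.
AbstractGPs.mean_vector(m::FunctionOfTime, x::AbstractVector{<:Real}) = map(m.f, x)
```

For example, GP(FunctionOfTime(xi -> 2 * first(xi) + 1), SEKernel()) would give a linear mean in the first input dimension for any of the three input types.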

torfjelde commented 1 year ago

I also ran into this issue recently, and because you end up hitting the definition in KernelFunctions, debugging is somewhat confusing 😕

Probably worth an entry in the docs, and maybe a change to the error message in KernelFunctions?

willtebbutt commented 1 year ago

I agree that the docs should probably be improved here.

willtebbutt commented 1 year ago

We should probably add a note about when you need to implement mean_vector yourself for CustomMean here and here, and provide an example.

kjrathore commented 2 months ago

using AbstractGPs, Zygote

pars = [1., 0.]

rand_data(n::Integer) = rand(n), randn(n)
rand_data_2d(n::Integer) = RowVecs(rand(n, 2)), randn(n)

struct LinearMean{T}
    a::T
    b::T
end
(f::LinearMean)(x) = f.a * first(x) + f.b

using AbstractGPs: CustomMean
function AbstractGPs._map_meanfunction(m::CustomMean{<:LinearMean}, x::RowVecs)
    @info "Calling specialized function"
    return vec(sum(m.f.(x.X); dims=2))
end

function build_model(pars)
    a, b = pars
    return GP(CustomMean(LinearMean(a, b)), SEKernel())
end

function test_logpdf(pars)
    f = build_model(pars)
    x, y = rand_data(10)
    return logpdf(f(x, 1e-3), y)
end

test_logpdf(pars)
Zygote.gradient(test_logpdf, pars)

Hi @simsurace, I used your code to check that it works with Zygote.

However, I am getting an error. Could you please explain how you solved this issue?

simsurace commented 2 months ago

Hi @kjrathore, thanks for reaching out. _map_meanfunction has since been removed. There is now a public mean_vector that can be overloaded. This should work on the current release of AbstractGPs:

using AbstractGPs, Zygote

pars = [1., 0.]

rand_data(n::Integer) = rand(n), randn(n)
rand_data_2d(n::Integer) = RowVecs(rand(n, 2)), randn(n)

struct LinearMean{T}
    a::T
    b::T
end
(f::LinearMean)(x) = f.a * first(x) + f.b

using AbstractGPs: CustomMean
function AbstractGPs.mean_vector(m::CustomMean{<:LinearMean}, x::RowVecs)
    @info "Calling specialized function"
    # RowVecs stores one input per row, so first(xi) is column 1 of x.X
    return m.f.a .* x.X[:, 1] .+ m.f.b
end

function build_model(pars)
    a, b = pars
    return GP(CustomMean(LinearMean(a, b)), SEKernel())
end

function test_logpdf(pars)
    f = build_model(pars)
    x, y = rand_data_2d(10)  # 2D RowVecs input, so the overload above is used
    return logpdf(f(x, 1e-3), y)
end

test_logpdf(pars)
Zygote.gradient(test_logpdf, pars)

kjrathore commented 2 months ago

Thanks @simsurace! A slight update to this code: I needed to add import AbstractGPs: mean_vector.

Note: "using" and "import" are not the same in Julia. While using brings all exported names of a module into the current namespace, import lets you extend a function without prefixing it with its module name (https://stackoverflow.com/questions/42888911/function-base-must-be-explicitly-imported-to-be-extended).
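To illustrate the distinction, here is a sketch with a hypothetical module Lib standing in for AbstractGPs. A method definition whose name is qualified with the module (Lib.describe(...) = ..., the same form as function AbstractGPs.mean_vector(...) in the previous comment) extends the existing function without any import; an unqualified definition extends it only after the name has been imported.

```julia
# Hypothetical module standing in for a package such as AbstractGPs.
module Lib
describe(x::Integer) = "integer"
end

# Option 1: qualify the name with its module; no import is needed.
Lib.describe(x::AbstractString) = "string"

# Option 2: import the name first, then extend it without the prefix.
import .Lib: describe
describe(x::Symbol) = "symbol"
```

Both new methods end up on the same function, Lib.describe, so either style works for adding a mean_vector method.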