simsurace closed this issue 1 month ago
Hmmm, this issue has come up repeatedly recently. It is also a problem in this Stheno.jl issue, I suspect for the same reasons.
To be honest, the simplest solution is going to be to implement AbstractGPs._map_meanfunction for your custom mean function and a RowVecs / ColVecs input, so that you can be sure that it's differentiable. So, something like
function AbstractGPs._map_meanfunction(f::CustomMean{typeof(your_mean_function)}, x::RowVecs)
    # return the vector of mean values here, computed from the underlying matrix x.X
end
etc. @simsurace could you let me know if this solves the problem?
I can't see this issue surrounding the mean function getting fixed any time soon (because it's AD-related), so I'm wondering whether we should change our approach to documenting it, e.g. making it clear that you might well need to implement _map_meanfunction if you're using a custom mean function.
Thanks for the suggestion. Maybe I misunderstood, but this did not solve the issue:
struct LinearMean{T}
    a::T
    b::T
end
(f::LinearMean)(x) = f.a * first(x) + f.b
using AbstractGPs: CustomMean
function AbstractGPs._map_meanfunction(f::CustomMean{<:LinearMean}, x::RowVecs)
    @info "Calling specialized function"
    return [f.f.a * first(xi) + f.f.b for xi in x]
end
function build_model(pars)
    a, b = pars
    return GP(CustomMean(LinearMean(a, b)), SEKernel())
end
Ah, sorry, I meant something more like
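function AbstractGPs._map_meanfunction(f::CustomMean{<:LinearMean}, x::RowVecs)
    @info "Calling specialized function"
    # RowVecs treats each row of x.X as one input, so the first feature of every input is the first column
    return f.f.a .* x.X[:, 1] .+ f.f.b
end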
so that you're interacting with the underlying matrix.
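For illustration, here is a minimal sketch in plain Julia (no AbstractGPs needed) of why working on the underlying matrix gives the same values as evaluating the mean input by input. The matrix X and the coefficients a and b are made-up values, and rows of X are treated as the individual inputs, which is the RowVecs convention.

```julia
# Made-up data: 3 inputs (rows), each with 2 features (columns).
X = [0.1 0.2; 0.3 0.4; 0.5 0.6]
a, b = 2.0, 1.0

# Input-by-input evaluation, as in the comprehension version.
rowwise = [a * first(row) + b for row in eachrow(X)]

# The same values from a single broadcast over the first column of the matrix.
sliced = a .* X[:, 1] .+ b

@assert rowwise ≈ sliced
```

The broadcast version only uses array slicing and elementwise arithmetic, which is the kind of code AD tools tend to handle well.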
Oh, I get it. Thanks, this seems to do the trick. Actually, wrapping LinearMean in CustomMean seems overly complicated. I could make LinearMean <: MeanFunction and then define _map_meanfunction accordingly, right?
P.S. CustomMean currently does not seem to be exported or documented, for that matter.
Oh, I get it. Thanks, this seems to do the trick. Actually, wrapping LinearMean in CustomMean seems overly complicated. I could make LinearMean <: MeanFunction and then define _map_meanfunction accordingly, right?
That should indeed work!
I ended up with a general struct FunctionOfTime{Tf} <: MeanFunction with overloads that map its field over slices. This works. Thanks for the tips!
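A rough sketch of that approach, written against the mean_vector function mentioned later in this thread rather than the since-removed _map_meanfunction. The LinearMean type below is illustrative (it is not the FunctionOfTime struct), and the sketch assumes that AbstractGPs.MeanFunction is the abstract supertype and that RowVecs and ColVecs store their matrix in the field X.

```julia
using AbstractGPs

# Illustrative mean m(x) = a * x[1] + b. Subtyping MeanFunction avoids the CustomMean wrapper.
struct LinearMean{T} <: AbstractGPs.MeanFunction
    a::T
    b::T
end

# Scalar inputs: each element of x is itself the first (and only) feature.
AbstractGPs.mean_vector(m::LinearMean, x::AbstractVector{<:Real}) = m.a .* x .+ m.b

# RowVecs: inputs are rows, so the first feature is the first column.
AbstractGPs.mean_vector(m::LinearMean, x::RowVecs) = m.a .* x.X[:, 1] .+ m.b

# ColVecs: inputs are columns, so the first feature is the first row.
AbstractGPs.mean_vector(m::LinearMean, x::ColVecs) = m.a .* x.X[1, :] .+ m.b

# The mean function can then be passed to GP directly.
f = GP(LinearMean(1.0, 0.0), SEKernel())
fx = f(RowVecs(rand(10, 2)), 1e-3)
```

Each method works on whole slices of the matrix instead of iterating over inputs, which is the pattern suggested earlier in the thread.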
I also ran into this issue recently, and because you end up hitting the definition in KernelFunctions, debugging is somewhat confusing.
Probably worth an entry in the docs + maybe changing the error in KernelFunctions?
I agree that the docs should probably be improved here
using AbstractGPs, Zygote
pars = [1., 0.]
rand_data(n::Integer) = rand(n), randn(n)
rand_data_2d(n::Integer) = RowVecs(rand(n, 2)), randn(n)
struct LinearMean{T}
    a::T
    b::T
end
(f::LinearMean)(x) = f.a * first(x) + f.b
using AbstractGPs: CustomMean
function AbstractGPs._map_meanfunction(m::CustomMean{<:LinearMean}, x::RowVecs)
    @info "Calling specialized function"
    return vec(sum(m.f.(x.X); dims=2))
end
function build_model(pars)
    a, b = pars
    return GP(CustomMean(LinearMean(a, b)), SEKernel())
end
function test_logpdf(pars)
    f = build_model(pars)
    x, y = rand_data(10)
    return logpdf(f(x, 1e-3), y)
end
test_logpdf(pars)
Zygote.gradient(test_logpdf, pars)
Hi @simsurace, I used your code to check whether it works with Zygote.
However, I am getting an error. Could you please tell me how you solved this issue?
Hi @kjrathore, thanks for reaching out. _map_meanfunction has been removed since then. There is now a public mean_vector that can be overloaded. This should work on the current release of AbstractGPs:
using AbstractGPs, Zygote
pars = [1., 0.]
rand_data(n::Integer) = rand(n), randn(n)
rand_data_2d(n::Integer) = RowVecs(rand(n, 2)), randn(n)
struct LinearMean{T}
    a::T
    b::T
end
(f::LinearMean)(x) = f.a * first(x) + f.b
using AbstractGPs: CustomMean
function AbstractGPs.mean_vector(m::CustomMean{<:LinearMean}, x::RowVecs)
    @info "Calling specialized function"
    # Same values as applying m.f to each row: a * (first feature) + b
    return m.f.a .* x.X[:, 1] .+ m.f.b
end
function build_model(pars)
    a, b = pars
    return GP(CustomMean(LinearMean(a, b)), SEKernel())
end
function test_logpdf(pars)
    f = build_model(pars)
    x, y = rand_data(10)
    return logpdf(f(x, 1e-3), y)
end
test_logpdf(pars)
Zygote.gradient(test_logpdf, pars)
Thanks @simsurace !
Slight update to this code. If you extend mean_vector without the AbstractGPs. prefix, you need to do
import AbstractGPs: mean_vector
Note: "using" and "import" in Julia are not the same. While using brings all exported names from a module into the current namespace, a function has to be brought in with import (or written with its module prefix, as in AbstractGPs.mean_vector) before you can add methods to it (https://stackoverflow.com/questions/42888911/function-base-must-be-explicitly-imported-to-be-extended)
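To make the distinction concrete, here is a small self-contained sketch. The module Toy and the types ConstantOne and ConstantTwo are hypothetical stand-ins invented for this example; only the using/import behaviour is the point.

```julia
# Hypothetical stand-in for a package that exports a function to be extended.
module Toy
export mean_vector
abstract type MeanFunction end
# Fallback method for mean functions that do not define their own.
mean_vector(m::MeanFunction, x::AbstractVector) = error("not implemented")
end

using .Toy                 # brings the exported names into scope for calling
import .Toy: mean_vector   # additionally allows adding methods to the bare name

struct ConstantOne <: Toy.MeanFunction end
struct ConstantTwo <: Toy.MeanFunction end

# Option 1: qualify the function with its module; `using` alone is enough for this.
Toy.mean_vector(::ConstantOne, x::AbstractVector) = ones(length(x))

# Option 2: extend the bare name; this relies on the explicit `import` above.
mean_vector(::ConstantTwo, x::AbstractVector) = fill(2.0, length(x))

@assert mean_vector(ConstantOne(), [0.1, 0.2]) == [1.0, 1.0]
@assert mean_vector(ConstantTwo(), [0.1, 0.2]) == [2.0, 2.0]
```

Without the import line, the Option 2 definition would not add a method to Toy.mean_vector, so code inside the package would never see it.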
When trying to differentiate logpdf or other scalar functions with a parameterized mean function and multidimensional input, there are errors. Is there a simple fix? The error for test_mean gives a suggestion to overload a kernelmatrix method, but that does not seem to be the issue since we are talking about the mean here. Why does the existing rrule for RowVecs not suffice?