Right. So what's happening is that on the reverse pass, the gradient w.r.t. the output of `_map(κ.transform, x)` found here is a `Vector{Vector{Float64}}`, which in turn hits the adjoint for the `RowVecs` constructor found here.
When I first wrote the `ColVecs` stuff in Stheno, I actively wanted to avoid ever accidentally hitting the slow path that involves pulling out individual elements of a `ColVecs`, so I made the adjoint error if this ever happened (i.e. if it got an actual vector-of-vectors as a derivative).
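For context, a minimal sketch of the kind of constructor adjoint being described (this is an illustration, not the actual Stheno/KernelFunctions code; the exact cotangent types handled are assumptions):

```julia
using Zygote: @adjoint
using KernelFunctions: RowVecs

# Hypothetical sketch: the fast paths accept the natural matrix / NamedTuple
# cotangents, while an actual vector-of-vectors cotangent deliberately errors
# so that the slow element-wise path is never taken silently.
@adjoint function RowVecs(X::AbstractMatrix)
    back(Δ::NamedTuple) = (Δ.X,)
    back(Δ::AbstractMatrix) = (Δ,)
    function back(Δ::AbstractVector{<:AbstractVector{<:Real}})
        return error("accidentally hit the slow path: the cotangent is a vector of vectors")
    end
    return RowVecs(X), back
end
```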
This isn't an issue in Stheno, which suggests that we're missing / have opted not to implement something here. I believe the issue is that we've not implemented a specialised version of `kerneldiagmatrix` for `SimpleKernel`s. Since all `SimpleKernel`s rely on a `metric`, for which `colwise` exists, we can do this as something like
```julia
# Diagonal of kernelmatrix(κ, x): apply kappa to the metric between each input and itself.
function kerneldiagmatrix(κ::SimpleKernel, x::AbstractVector)
    return map(d -> kappa(κ, d), colwise(metric(κ), x, x))
end

# Diagonal of kernelmatrix(κ, x, y): apply kappa to the metric column-wise between x and y.
function kerneldiagmatrix(κ::SimpleKernel, x::AbstractVector, y::AbstractVector)
    return map(d -> kappa(κ, d), colwise(metric(κ), x, y))
end
```
with additional error checking. This should both get rid of this issue and improve the performance of `kerneldiagmatrix` for `SimpleKernel`s.
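For concreteness, a hedged sketch of what that error checking could look like on the two-argument method (the `DimensionMismatch` length check is an assumption about what "additional error checking" would mean):

```julia
# Sketch only: same names as the snippet above, plus a shape check so that
# mismatched inputs fail loudly instead of silently producing a short diagonal.
function kerneldiagmatrix(κ::SimpleKernel, x::AbstractVector, y::AbstractVector)
    length(x) == length(y) || throw(DimensionMismatch(
        "x and y must contain the same number of inputs",
    ))
    return map(d -> kappa(κ, d), colwise(metric(κ), x, y))
end
```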
Does this seem reasonable, @theogf?
It's also a bit surprising that this wasn't picked up by our tests, as it suggests that we're not testing some pretty basic kernel combinations properly.
If your solution works, I am all for it! Did you try it out?
Zygote currently fails to differentiate through `kerneldiagmatrix` when given a `RowVecs` or a `ColVecs`.

MWE:
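(The original MWE is not reproduced in this excerpt. A hypothetical sketch of the kind of call that fails, with the kernel choice and input sizes as arbitrary assumptions:)

```julia
using KernelFunctions, Zygote

k = SqExponentialKernel()
X = RowVecs(randn(10, 3))  # 10 observations with 3 features each

# Differentiating a reduction of kerneldiagmatrix w.r.t. the RowVecs input
# is the kind of call that errors on the reverse pass.
Zygote.gradient(x -> sum(kerneldiagmatrix(k, x)), X)
```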