JuliaGaussianProcesses / KernelFunctions.jl

Julia package for kernel functions for machine learning
https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/
MIT License

Make pullback error for ColVecs and RowVecs a bit more informative #523

Closed torfjelde closed 11 months ago

torfjelde commented 11 months ago

The error thrown by the adjoint defined for ColVecs and RowVecs suggests that the cause is something internal to KernelFunctions.jl. But other packages, e.g. AbstractGPs.jl, also use ColVecs and RowVecs, so issues encountered in practice are often unrelated to what the adjoint error indicates, e.g. https://github.com/JuliaGaussianProcesses/AbstractGPs.jl/issues/344

I ran into this yesterday and, because of the error message, spent a fair bit of time looking for bugs in the kernel-related code rather than looking elsewhere. Hence the PR :)

My wording in the error message can probably do with an improvement?

EDIT: Oh and why does this adjoint definition exist? Why not just pull back the vector of vectors to a matrix? AFAIK the adjoint is still valid?

codecov[bot] commented 11 months ago

Codecov Report

All modified lines are covered by tests :white_check_mark:

| Files | Coverage Δ |
|---|---|
| src/chainrules.jl | 45.67% <ø> (-23.46%) :arrow_down: |

... and 22 files with indirect coverage changes

devmotion commented 11 months ago

> EDIT: Oh and why does this adjoint definition exist? Why not just pull back the vector of vectors to a matrix? AFAIK the adjoint is still valid?

To avoid surprisingly slow code. ColVecs and RowVecs wrap a concatenated matrix of the vectors and usually it is faster to work with that underlying matrix instead of constructing a vector of vectors.

The code was introduced in #84 initially (moved to ChainRules in #208) but as it was copied from Stheno.jl maybe @willtebbutt has some additional comments.
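
To make the fast-path/slow-path distinction concrete, here is a minimal sketch of a wrapper in the spirit of ColVecs. The type `SketchColVecs` is a hypothetical stand-in written for illustration; it is not the definition used in KernelFunctions.jl, and it assumes each column of the matrix is one observation.

```julia
# Hypothetical stand-in for KernelFunctions.ColVecs; not the package's actual definition.
struct SketchColVecs{T} <: AbstractVector{Vector{T}}
    X::Matrix{T}  # the concatenated matrix; each column is one element of the vector
end

# Minimal AbstractVector interface: one element per column of X.
Base.size(v::SketchColVecs) = (size(v.X, 2),)
Base.getindex(v::SketchColVecs, i::Int) = v.X[:, i]

X = reshape(collect(1.0:6.0), 2, 3)  # 2 features, 3 observations
x = SketchColVecs(X)

# Fast path: operate on the wrapped matrix directly.
fast = vec(sum(x.X; dims=1))

# Slow path: materialise a Vector{Vector{Float64}} first, then operate on it.
slow = map(sum, collect(x))
```

Both paths give the same column sums; the difference is only that the slow path allocates one small vector per observation before doing any work.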

st-- commented 11 months ago

Looks good to me, can you also bump the patch version?

torfjelde commented 11 months ago

> To avoid surprisingly slow code.

That's what I suspected, but is it worth it given how difficult it can be to debug these adjoint issues for most users? :grimacing:

devmotion commented 11 months ago

If it is difficult to debug with the error, it would probably be even more difficult to notice the problem, and to find the cause of the slow performance, without it 😛

torfjelde commented 11 months ago

True! But KernelFunctions.jl generally supports these slow paths given that it's all defined on AbstractVector, no? And since user-defined methods can easily touch this, e.g. through AbstractGPs.CustomMean, this error breaks the entire call rather than just making the user-defined method slow.

torfjelde commented 11 months ago

I think the error makes complete sense in the scenario where all usages of ColVecs and RowVecs are internal, but that it becomes somewhat more nuanced when it can interact with external code.

torfjelde commented 11 months ago

For example, the issue referenced above is caused by us silently taking the "slow path", and we take it because there exists a default implementation for AbstractGPs.mean_vector. So I guess my question boils down to: why allow this slow default method to be hit but not the slow pullback?

EDIT: I realize AbstractGPs is a different package, but since it's under the same org + probably done by the same people, I'm guessing the decisions are somewhat related :)

devmotion commented 11 months ago

Without looking into all the details, if AbstractGPs takes a slow path, that's a bug in AbstractGPs that should be fixed there. No method in KernelFunctions should support slow paths for ColVecs and RowVecs. But of course we want to be as generic as possible so we define methods for AbstractVector - we just always want to take the optimal path for ColVecs/RowVecs.

willtebbutt commented 11 months ago

> Without looking into all the details, if AbstractGPs takes a slow path, that's a bug in AbstractGPs that should be fixed there. No method in KernelFunctions should support slow paths for ColVecs and RowVecs. But of course we want to be as generic as possible so we define methods for AbstractVector - we just always want to take the optimal path for ColVecs/RowVecs.

I agree with this. I can definitely see your point, @torfjelde, in that if a slow path is taken, it might be nice for the code not to fall over. Moreover, if the forwards-pass takes a slow path, then I am completely fine with the reverse-pass also taking a slow path. What I was primarily trying to guard against was the forwards-pass taking the fast path, and the pullback somehow hitting the slow path (I believe that this bit me several times when I hit a piratic rrule).

I'll add another comment to the PR, as I think we could clarify this even further.

torfjelde commented 11 months ago

> if AbstractGPs takes a slow path, that's a bug in AbstractGPs that should be fixed there. No method in KernelFunctions should support slow paths for ColVecs and RowVecs

Ah gotcha, then it makes sense :+1:

> What I was primarily trying to guard against was the forwards-pass taking the fast path, and the pullback somehow hitting the slow path (I believe that this bit me several times when I hit a piratic rrule).

Yeah I can see how an incorrect rrule could cause issues, so it def seems sensible :+1: To me the confusion was mainly that, when a forward-pass hits the slow path, we should be happy with the reverse also taking the slow path.

But so this is in general a "restriction" with ColVecs and RowVecs then? You're not allowed to "automatically" AD through "slow paths", such as the example from AbstractGPs?

EDIT: As in, without either defining an explicit overload for the forward-pass or defining a custom adjoint.

willtebbutt commented 11 months ago

> But so this is in general a "restriction" with ColVecs and RowVecs then? You're not allowed to "automatically" AD through "slow paths", such as the example from AbstractGPs?

Depends on the kind of slow path. If, when you AD the slow path, you produce a Vector{Vector{T}} cotangent for a ColVecs/RowVecs, then no. If it's a slow path, but you get the right type, it'll work fine.
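
The distinction above can be sketched with plain multiple dispatch. Everything in this block is a hypothetical stand-in: `SketchTangent` and `sketch_pullback` are illustrative names, not ChainRulesCore types and not the actual rule in src/chainrules.jl. It only shows the pattern of accepting a structural tangent while rejecting a vector-of-vectors cotangent.

```julia
# Hypothetical sketch: plain-Julia stand-ins, not the actual ChainRulesCore types
# or the rule defined in KernelFunctions.jl.

# Stand-in for a structural tangent that carries a cotangent for the wrapped matrix.
struct SketchTangent{TX<:AbstractMatrix}
    X::TX
end

# Right type: hand the matrix cotangent straight through.
sketch_pullback(Δ::SketchTangent) = Δ.X

# Wrong type: a vector of vectors fails loudly instead of being silently converted.
function sketch_pullback(::AbstractVector{<:AbstractVector{<:Real}})
    throw(ArgumentError("received a vector-of-vectors cotangent; expected a structural tangent"))
end

ok = sketch_pullback(SketchTangent(ones(2, 3)))

# Record whether the vector-of-vectors cotangent was rejected.
failed = try
    sketch_pullback([ones(2) for _ in 1:3])
    false
catch err
    err isa ArgumentError
end
```

With this shape, whether AD works depends only on the type of the cotangent that reaches the pullback, not on how fast the forward pass was.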

torfjelde commented 11 months ago

> If, when you AD the slow path, you produce a Vector{Vector{T}} cotangent for a ColVecs/RowVecs, then no. If it's a slow path, but you get the right type, it'll work fine.

With "slow path" I mean when a method that works with AbstractVector or similar receives a ColVecs, e.g. map. In these scenarios, the evaluation/forward pass is allowed and works just fine, even though you're hitting a slow path, but (specifically reverse-mode) AD is not allowed. Have I understood correctly?

willtebbutt commented 11 months ago

> With "slow path" I mean when a method that works with AbstractVector or similar receives a ColVecs, e.g. map. In these scenarios, the evaluation/forward pass is allowed and works just fine, even though you're hitting a slow path, but (specifically reverse-mode) AD is not allowed. Have I understood correctly?

Kind of. I'm just saying it's more specific than that: the restriction only bites when the reverse-pass returns a Vector{Vector{T}} rather than a Tangent{ColVecs{T}} or whatever. It happens to be the case that this tends to align with when the forwards-pass does something slow and annoying, but you could imagine situations in which this isn't the case. e.g. map(sum, ColVecs(randn(10, 20))) isn't going to be particularly slow. My point is that it's not so much to do with slowness, as it is type errors.
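
As a rough illustration of the two cotangent shapes for the map(sum, ...) example, the sketch below uses `eachcol` as a stand-in for iterating a ColVecs, and assumes a scalar loss `sum(y)` purely so the cotangent is easy to write down by hand. No AD package is involved; the variable names are illustrative only.

```julia
X = randn(10, 20)

# Forward pass: map(sum, ColVecs(X)) amounts to summing each column of X.
# Column views stand in for the ColVecs elements here.
y = map(sum, eachcol(X))

# For the assumed scalar loss sum(y), every entry of X has derivative 1.
# Matrix-shaped cotangent, matching the wrapped matrix:
ΔX_matrix = ones(size(X))

# The same information as a Vector{Vector{Float64}}, one vector per column:
ΔX_vecs = [ones(size(X, 1)) for _ in 1:size(X, 2)]

# Converting between the two is possible, but costs an extra concatenation.
ΔX_from_vecs = reduce(hcat, ΔX_vecs)
```

The two representations carry the same numbers; the point in the discussion is which of the two types arrives at the pullback.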

torfjelde commented 11 months ago

> It happens to be the case that this tends to align with when the forwards-pass does something slow and annoying, but you could imagine situations in which this isn't the case. e.g. map(sum, ColVecs(randn(10, 20))) isn't going to be particularly slow. My point is that it's not so much to do with slowness, as it is type errors.

Sure, I'm with you on that; I was referring to "slow paths" because I thought that was the original motivation as to why this rrule was implemented as is.

But is it understandable that it's somewhat confusing that something like map(sum, ColVecs(randn(10, 20))) is allowed but computing the pullback of this is not?

willtebbutt commented 11 months ago

> But is it understandable that it's somewhat confusing that something like map(sum, ColVecs(randn(10, 20))) is allowed but computing the pullback of this is not?

It's entirely understandable, but I felt (still feel) that it's the lesser of two evils -- it's opting for a loud error when performance is bad (read: catastrophic) on the reverse-pass, rather than allowing the code to run very slowly, but eventually produce an answer.

torfjelde commented 11 months ago

Gotcha, gotcha :+1:

Thanks for explaining the decision! I'm curious about this stuff partially because of the discussions we've had regarding more general batching structures in the past.

willtebbutt commented 11 months ago

Something has broken with AD 🤦. It's clearly not this PR's fault, so I'm going to merge and tag a release anyway.