JuliaGaussianProcesses / KernelFunctions.jl

Julia package for kernel functions for machine learning
https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/
MIT License

MaternKernel AD #450

Open willtebbutt opened 2 years ago

willtebbutt commented 2 years ago

https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/pull/425 makes it possible to AD through the MaternKernel by dropping the derivative w.r.t. ν. We tried to make that derivative return a NotImplemented, but Zygote doesn't appear to handle it properly: it returns nothing rather than NotImplemented.
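A minimal sketch of the pattern being described, assuming ChainRulesCore.jl and SpecialFunctions.jl are installed. The Matérn kernel involves the modified Bessel function K_ν, whose derivative w.r.t. x has a closed form but whose derivative w.r.t. the order ν is the hard part. `matern_bessel` is a hypothetical wrapper used only for illustration; this is not the rule from #425 and not part of the KernelFunctions API.

```julia
using ChainRulesCore
using SpecialFunctions: besselk

# Hypothetical stand-in for the Bessel factor inside the Matérn kernel.
matern_bessel(ν::Real, x::Real) = besselk(ν, x)

function ChainRulesCore.rrule(::typeof(matern_bessel), ν::Real, x::Real)
    y = matern_bessel(ν, x)
    function matern_bessel_pullback(ȳ)
        # Known identity: d/dx K_ν(x) = -(K_{ν-1}(x) + K_{ν+1}(x)) / 2
        x̄ = -unthunk(ȳ) * (besselk(ν - 1, x) + besselk(ν + 1, x)) / 2
        # Explicitly mark the ν derivative as missing instead of silently dropping it.
        ν̄ = @not_implemented("derivative of besselk w.r.t. the order ν")
        # The first entry is the tangent for the function object itself.
        return NoTangent(), ν̄, x̄
    end
    return y, matern_bessel_pullback
end

# Calling the rule directly shows one tangent per argument, plus one for the function.
y, back = rrule(matern_bessel, 1.5, 0.7)
Δ = back(1.0)
```

Returning a NotImplemented (rather than NoTangent) records that the ν derivative exists mathematically but has not been implemented, so an AD system that understands the distinction can error or warn when that tangent is actually needed.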

devmotion commented 2 years ago

Isn't that the convention in Zygote, and hence expected? In Zygote every undefined gradient is nothing, so it seems that it would be quite breaking if this convention were violated in some cases.

To me it doesn't seem like a Zygote issue; rather, nothing seems to be the correct result (regardless of whether one thinks that this nothing convention is a good thing or not; personally I think it's not :smile:).
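A small illustration of the nothing convention mentioned above, assuming Zygote.jl is installed. It only shows the case where an argument has no gradient at all; it does not exercise how Zygote treats a NotImplemented tangent coming from a ChainRules rule.

```julia
using Zygote

# x does not influence the output, so Zygote reports `nothing` for it
# rather than a zero or a special tangent type.
g_const = Zygote.gradient(x -> 1.0, 2.0)            # (nothing,)

# Per-argument: `a` gets an ordinary number, the unused `b` gets `nothing`.
g_mixed = Zygote.gradient((a, b) -> 2a, 1.0, 5.0)   # (2.0, nothing)
```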

willtebbutt commented 2 years ago

I'm okay with (although disappointed by) what Zygote is doing here. I think it's good to keep a record of this behaviour, so that we can a) point disgruntled users towards it, and b) point Zygote devs towards it if we feel so inclined.

devmotion commented 2 years ago

Ideally, the convention in Zygote would be changed but I don't think this will happen anytime soon (and possibly never).

But it's still not clear to me why Zygote being surprising here should affect whether we return NotImplemented or NoTangent. Other ChainRules-based ADs hopefully handle NotImplemented correctly, and for those the choice matters, so the implementation in #425 seems wrong.

devmotion commented 2 years ago

I take back my statement above. Apparently Zygote does return NotImplemented: https://github.com/FluxML/Zygote.jl/issues/1204#issuecomment-1098460916

Possibly the issue in #425 was that the pullback was incorrect (it was missing the NoTangent for the function itself); but IIRC Zygote just uses nothing for missing derivatives.