JuliaGaussianProcesses / KernelFunctions.jl

Julia package for kernel functions for machine learning
https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/
MIT License

Add lazy Kronecker product for matrix kernels if Kronecker.jl is loaded #364

Status: Closed (closed by Crown421 3 years ago)

Crown421 commented 3 years ago

Summary: Following #354, this PR adds lazy Kronecker products for the IndependentMOKernel and the IntrinsicCoregionMOKernel via an optional dependency on Kronecker.jl. It also incorporates comments and thoughts from the previous PR.
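The following small NumPy sketch shows why a lazy Kronecker representation is possible for these two kernels. It is not KernelFunctions.jl code. It assumes the coregionalization form k((x, i), (y, j)) = B[i, j] * k(x, y) and output-major ordering of the multi-output inputs (all inputs for output 1 first, then output 2, and so on). Under those assumptions, the full matrix equals kron(B, K), and independent outputs are the special case B = I. The helpers `se_kernel` and `mo_kernelmatrix_dense` are hypothetical names used only for illustration.

```python
import numpy as np

def se_kernel(x, y, lengthscale=1.0):
    """Squared-exponential kernel matrix between 1-D input arrays x and y."""
    d = x[:, None] - y[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)

def mo_kernelmatrix_dense(x, B):
    """Dense multi-output kernel matrix, built entry by entry (hypothetical helper).

    Assumes k((x, i), (y, j)) = B[i, j] * k(x, y) and output-major ordering,
    so the row for output i and input p has index i * len(x) + p.
    """
    n, m = len(x), B.shape[0]
    K = se_kernel(x, x)
    full = np.empty((m * n, m * n))
    for i in range(m):
        for p in range(n):
            for j in range(m):
                for q in range(n):
                    full[i * n + p, j * n + q] = B[i, j] * K[p, q]
    return full

x = np.linspace(0.0, 1.0, 4)
B = np.array([[2.0, 0.5], [0.5, 1.0]])  # output covariance (positive definite)
K = se_kernel(x, x)

# Coregionalization: the dense matrix is exactly kron(B, K).
assert np.allclose(mo_kernelmatrix_dense(x, B), np.kron(B, K))
# Independent outputs are the special case B = I.
assert np.allclose(mo_kernelmatrix_dense(x, np.eye(2)), np.kron(np.eye(2), K))
```

A lazy Kronecker object only needs to keep B (m × m) and K (n × n) rather than the full (mn) × (mn) matrix.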

Proposed changes

What alternatives have you considered? The name kronecker_kernelmatrix is perhaps not ideal; an additional mo prefix or suffix may be needed. I think conflating this with kernelkronmat would be confusing.

I have also considered making the _mo_output_covariance function apply to the regular kernelmatrix as well, but this may conflict with #363 and would have to wait until both are resolved.

Breaking changes: None.

st-- commented 3 years ago

> I think conflating this with kernelkronmat would be confusing.

Having looked at kernelkronmat (and despaired a bit at its lack of documentation), I agree that conflating the two might be confusing (for one, they have quite different signatures and use cases). kernelkronmat is for when I want to evaluate a (1D) kernel on a multi-dimensional grid. It's actually more limited than it could be: for example, I could also get a Kronecker matrix for any product kernel where each component applies to exactly one of the dimensions...
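The point about product kernels can be checked numerically. The following NumPy sketch is an illustration only, not kernelkronmat itself. It builds a 2-D grid from 1-D coordinates g1 and g2, with the g2 coordinate varying fastest, and uses a product kernel whose two factors each act on exactly one coordinate. The dense kernel matrix over the grid then equals kron(K1, K2), where K1 and K2 are the per-dimension kernel matrices. The names `se` and `product_kernel_on_grid` are hypothetical.

```python
import numpy as np

def se(a, b, lengthscale):
    """1-D squared-exponential kernel matrix between coordinate arrays a and b."""
    d = a[:, None] - b[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)

def product_kernel_on_grid(g1, g2, l1=1.0, l2=0.5):
    """Dense matrix of k((a, b), (c, d)) = k1(a, c) * k2(b, d) over the grid g1 x g2.

    Grid points are enumerated with the g2 coordinate varying fastest.
    """
    points = np.array([(a, b) for a in g1 for b in g2])
    K1 = se(points[:, 0], points[:, 0], l1)  # factor acting on dimension 1 only
    K2 = se(points[:, 1], points[:, 1], l2)  # factor acting on dimension 2 only
    return K1 * K2  # elementwise product of the two factor kernels

g1 = np.array([0.0, 0.3, 1.0])
g2 = np.array([-1.0, 0.0, 0.5, 2.0])
dense = product_kernel_on_grid(g1, g2)
structured = np.kron(se(g1, g1, 1.0), se(g2, g2, 0.5))
assert np.allclose(dense, structured)
```

With the same 1-D kernel and grid in every dimension, this reduces to the repeated-Kronecker case. With different factors per dimension, it is the generalization suggested above.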

However, for the MO use case, the call is exactly the same as for plain kernelmatrix; you just get a more efficient object back. Would there be any reason not to simply have kernelmatrix return a Kronecker object if Kronecker.jl is loaded? This also relates to https://github.com/SebastianAment/CovarianceFunctions.jl/issues/2.
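One reason a Kronecker object is more efficient is that products with it never need the full matrix. The following NumPy sketch uses the standard identity (A ⊗ B) vec(X) = vec(B X Aᵀ) in its row-major form: kron(A, B) @ X.ravel() == (A @ X @ B.T).ravel(). It only illustrates the idea and does not describe how Kronecker.jl implements it internally. `kron_matvec` is a hypothetical name.

```python
import numpy as np

def kron_matvec(A, B, v):
    """Compute np.kron(A, B) @ v without forming the Kronecker product.

    Uses the row-major identity kron(A, B) @ X.ravel() == (A @ X @ B.T).ravel(),
    where X has shape (A.shape[1], B.shape[1]).
    """
    X = v.reshape(A.shape[1], B.shape[1])
    return (A @ X @ B.T).ravel()

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))    # e.g. an output covariance
B = rng.standard_normal((50, 50))  # e.g. a kernel matrix over the inputs
v = rng.standard_normal(3 * 50)

# Same result as the dense product, but only A and B are ever stored.
assert np.allclose(kron_matvec(A, B, v), np.kron(A, B) @ v)
```

For m outputs and n inputs, the dense matrix has (mn)² entries, while the two factors have m² + n². The structured matrix-vector product costs O(mn(m + n)) instead of O(m²n²).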

Crown421 commented 3 years ago

> However, for the MO use-case, the calling is exactly the same as for plain kernelmatrix - you just get a more efficient object back. Would there be any reason not to just have kernelmatrix return a Kronecker object, if Kronecker.jl is loaded?

This was my preferred option as well; I just had some trouble getting the function to overwrite correctly, and I also thought about making it more explicit to the user.

However, I was able to sort out those code issues and have now made some changes that, I think, clean up the code substantially.

What is still missing, though, is clear documentation about this, so I will add that soon.

devmotion commented 3 years ago

The test errors seem real: https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/pull/364/checks?check_run_id=3675219571#step:6:123 (and the same with the stable Julia version)

I assume this is caused by the hardcoded Eye{Bool} (previously Eye{eltype(Kfeatures)}), since booleans are not differentiable in ChainRules (and hence in Zygote).

Crown421 commented 3 years ago

> The test errors seem real: https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/pull/364/checks?check_run_id=3675219571#step:6:123 (and the same with the stable Julia version)
>
> I assume this is caused by the hardcoded Eye{Bool} (before it was Eye{eltype(Kfeatures)}) since booleans are not differentiable in ChainRules (and hence also Zygote).

This is very odd, because the tests passed previously and I don't think I changed anything that would affect this.

willtebbutt commented 3 years ago

Looking at the error, I think the best fix is probably just to make the pullbacks defined for pairwise with Delta around here accept any type, rather than just AbstractMatrix.

My reasoning is as follows: based on the CI logs that @devmotion linked, it looks like a ZeroTangent is somehow making its way into the pairwise_pullback, which means that a gradient has (correctly) been dropped somewhere earlier in the reverse pass. Usually I would expect Zygote to pick up on this and never call the pullback, but for some reason that's not happening here, so we need to handle the ZeroTangent manually.

Edit: I just tried this locally and can confirm that it works.

Crown421 commented 3 years ago

I just had a look at the successful test run from a week or so ago (https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/runs/3613631221), and while I am not 100% sure I am looking at the right things, I see that Zygote has gone from 0.6.21 to 0.6.22 (this happened 13h ago, in fact: https://github.com/FluxML/Zygote.jl/releases). One of the closed issues involves work on adjoints for specialized matrices. Could that have changed things for this package?

willtebbutt commented 3 years ago

> One of the closed issues includes work on adjoints for specialized matrices, which may have changed things for this package?

Yeah, that seems like a likely culprit, specifically https://github.com/FluxML/Zygote.jl/pull/1044. It's 100% a bug fix, though, so we've done that classic thing whereby KF only works properly because of a bug in Zygote.

Tbh, the rules in question should probably all be declared no-ops from an AD perspective using @non_differentiable, which will also fix the problem (I suspect they were written before @non_differentiable was available to us), but widening the set of acceptable cotangents from AbstractMatrix to Any will solve the problem for now.

Crown421 commented 3 years ago

> Tbh, the rules in question should probably all be declared no-ops from an AD perspective using @non_differentiable, which will also fix the problem (I suspect that they were written before we had @non_differentiable available to us) but widening the set of acceptable cotangents from AbstractMatrix to Any will solve the problem for now.

After some clumsy handling of git on my part, I have now made the changes. I hope I understood your suggestion correctly.

willtebbutt commented 3 years ago

Looks like it's nearly there. It just also needs the same kind of modification here, if I'm interpreting the CI logs correctly.

Crown421 commented 3 years ago

> Looks like it's nearly there. Just also needs the same kind of modification here if I'm interpreting the CI logs correctly.

OK, tests are passing now. I think this is ready to merge (I don't have permission to merge myself).

willtebbutt commented 3 years ago

Great. @Crown421 I've invited you to join the org. Please just make sure you've read the first part of ColPrac 🙂 if you haven't already.

DhairyaLGandhi commented 3 years ago

This doesn't seem great for Zygote, flexibility-wise. It shouldn't be returning ChainRules types, since most rules shouldn't be written to handle them. We should convert these to Zygote-friendly types wherever something like this happens.

willtebbutt commented 3 years ago

Sorry @DhairyaLGandhi I don't follow. Which bit in particular are you referring to?

DhairyaLGandhi commented 3 years ago

I mean needing adjoints to be aware of (or explicitly handle) ChainRules types.

willtebbutt commented 3 years ago

Oh, but we're just writing rrules, so we should only have to worry about ChainRules types, no?