Closed Crown421 closed 3 years ago
I think conflating this with kernelkronmat would be confusing.
Having looked at `kernelkronmat` (and despaired a bit at its lack of documentation), I agree that conflating the two might be confusing (for one, they have quite different signatures and use-cases). `kernelkronmat` is for when I want to evaluate a (1D) kernel on a multi-dimensional grid. It's actually more limited than it could be, e.g. I could get a Kronecker matrix also for any product kernel where each component applies to exactly one of the dimensions...
However, for the MO use-case, the call is exactly the same as for plain `kernelmatrix` - you just get a more efficient object back. Would there be any reason not to just have `kernelmatrix` return a `Kronecker` object, if Kronecker.jl is loaded? This relates to https://github.com/SebastianAment/CovarianceFunctions.jl/issues/2 as well.
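The remark about product kernels on a grid can be checked numerically. The following is a minimal sketch using only Julia's standard library, not the `kernelkronmat` API: the kernels `k1` and `k2`, the grids, and the point ordering (first coordinate varying fastest) are all assumptions chosen for illustration. It uses the dense `kron` rather than a lazy Kronecker.jl object.

```julia
using LinearAlgebra

# Hypothetical 1D squared-exponential kernels with different length scales.
k1(x, y) = exp(-abs2(x - y) / 2)
k2(x, y) = exp(-abs2(x - y) / (2 * 0.5^2))

xs = range(0.0, 1.0; length=4)   # grid along the first dimension
ys = range(0.0, 1.0; length=3)   # grid along the second dimension

# Per-dimension kernel matrices (4×4 and 3×3).
K1 = [k1(a, b) for a in xs, b in xs]
K2 = [k2(a, b) for a in ys, b in ys]

# All grid points, ordered so that the first coordinate varies fastest.
grid = vec([(x, y) for x in xs, y in ys])

# Product kernel: each factor sees exactly one coordinate.
kprod(p, q) = k1(p[1], q[1]) * k2(p[2], q[2])

# Dense 12×12 kernel matrix, built entry by entry.
K_dense = [kprod(p, q) for p in grid, q in grid]

# The same matrix assembled from the two small factors.
K_kron = kron(K2, K1)

# Holds for this particular ordering of the grid points.
matches = K_dense ≈ K_kron
```

The order of the factors in `kron` depends on how the grid points are enumerated; with the second coordinate varying fastest the factors would swap.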
This was my preferred option as well; I just had some issues getting the function to overwrite correctly, and had thought about making it more explicit to the user.
However, I was able to sort out the code issues that I had, and have now made some changes that I think clean up the code substantially.
I think what is still missing is clear documentation about this, so I will add that soon.
The test errors seem real: https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/pull/364/checks?check_run_id=3675219571#step:6:123 (and the same with the stable Julia version)
I assume this is caused by the hardcoded `Eye{Bool}` (before it was `Eye{eltype(Kfeatures)}`) since booleans are not differentiable in ChainRules (and hence also Zygote).
This is very odd, because previously the tests passed, and I didn't think that I changed anything that would affect this.
Looking at the error, I think the optimal fix is probably just to make the pullbacks defined for `pairwise` of `Delta` around here accept any type, rather than just `AbstractMatrix`s.
My reasoning is as follows: based on the CI logs that @devmotion linked, it looks like a `ZeroTangent` is somehow making its way into the `pairwise_pullback`, which means that a gradient has (correctly) been dropped somewhere earlier in the reverse pass. Usually I would expect Zygote to pick up on this and to never call the pullback, but it seems like that's not happening here for some reason and we need to handle the `ZeroTangent` manually.
edit: just tried this locally and can confirm that it works.
I just had a look at the successful test from a week (or so) ago (https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/runs/3613631221), and while I am not 100% sure I am looking at the right things, I see that Zygote has gone from 0.6.21 to 0.6.22 (this happened 13h ago in fact, https://github.com/FluxML/Zygote.jl/releases). One of the closed issues includes work on adjoints for specialized matrices, which may have changed things for this package?
Yeah, that seems like a likely culprit, specifically https://github.com/FluxML/Zygote.jl/pull/1044. It's 100% a bug fix though, so we've done that classic thing whereby KF only works properly because of a bug in Zygote.
Tbh, the rules in question should probably all be declared no-ops from an AD perspective using `@non_differentiable`, which will also fix the problem (I suspect that they were written before we had `@non_differentiable` available to us), but widening the set of acceptable cotangents from `AbstractMatrix` to `Any` will solve the problem for now.
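The effect of widening the cotangent type can be illustrated with plain Julia dispatch. This is only a sketch under stated assumptions: `ZeroCotangent` is a hypothetical stand-in for ChainRulesCore's `ZeroTangent`, `pairwise_delta`, `strict_rule` and `wide_rule` are made-up names, and nothing here is the actual KernelFunctions.jl rule. It does not show the `@non_differentiable` alternative.

```julia
# Hypothetical stand-in for an AD "zero cotangent" marker type.
struct ZeroCotangent end

# Toy forward function: pairwise comparison, constant with respect to its inputs.
pairwise_delta(xs) = [a == b for a in xs, b in xs]

# Rule whose pullback only accepts matrix cotangents (the restrictive signature).
function strict_rule(xs)
    strict_pullback(Δ::AbstractMatrix) = (nothing, nothing)
    return pairwise_delta(xs), strict_pullback
end

# Rule whose pullback accepts any cotangent (the widened signature).
function wide_rule(xs)
    wide_pullback(Δ::Any) = (nothing, nothing)
    return pairwise_delta(xs), wide_pullback
end

xs = [1, 2, 2]
_, strict_pb = strict_rule(xs)
_, wide_pb = wide_rule(xs)

strict_ok = strict_pb(ones(3, 3))      # a matrix cotangent is accepted
wide_ok = wide_pb(ZeroCotangent())     # the marker type is accepted too

# The restrictive pullback has no method for the marker type.
strict_failed = try
    strict_pb(ZeroCotangent())
    false
catch err
    err isa MethodError
end
```

Both pullbacks return no gradient either way; the only difference is whether the call dispatches at all when the AD system passes something other than a matrix.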
After some bad handling of git on my part, I have now made some changes; I hope I understood correctly what you suggested.
Looks like it's nearly there. Just also needs the same kind of modification here if I'm interpreting the CI logs correctly.
Ok, tests are passing now. This should be ready to merge now, I think (I don't have permission to merge).
Great. @Crown421 I've invited you to join the org. Please just make sure to have read the first bit of the ColPrac 🙂 if you've not already.
This doesn't seem great for Zygote flexibility wise. It shouldn't be returning ChainRules types since most rules shouldn't be written to handle them. We should convert these to Zygote friendly types where something like this happens.
Sorry @DhairyaLGandhi I don't follow. Which bit in particular are you referring to?
Needing to have adjoints be aware of (/explicitly handle) ChainRules's types, I mean.
Oh, but we're just writing `rrule`s, so we should only have to worry about ChainRules types, no?
Summary
Following #354 this PR adds lazy Kronecker products for the `IndependentMOKernel` and the `IntrinsicCoregionMOKernel` via an optional dependency on Kronecker.jl. Also includes comments and thoughts from the previous PR.
Proposed changes
`kronecker_kernelmatrix`
What alternatives have you considered?
The name for `kronecker_kernelmatrix` is perhaps not ideal, maybe an additional `mo` suffix/prefix is needed. I think conflating this with `kernelkronmat` would be confusing.
I have also considered making the `_mo_output_covariance` function apply also for the regular `kernelmatrix`, but this may cause conflicts with #363, and would have to be done once both are resolved.
Breaking changes
None.
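The Kronecker structure that the summary relies on can be sketched with the standard library alone. The code below does not use KernelFunctions.jl or Kronecker.jl: the kernel `k`, the output covariance `B`, and the input ordering (all features for output 1 first, then output 2, and so on) are assumptions for illustration, and `kron` builds a dense matrix where the PR would return a lazy object.

```julia
using LinearAlgebra

# Hypothetical scalar kernel shared by all outputs.
k(x, y) = exp(-abs2(x - y) / 2)

xs = [0.0, 0.5, 1.0, 1.5]   # 4 input features
nout = 3                    # number of outputs

# Multi-output inputs as (feature, output index) pairs, features varying fastest.
mo_inputs = vec([(x, p) for x in xs, p in 1:nout])

# Kernel matrix over the features only (4×4).
Kfeatures = [k(a, b) for a in xs, b in xs]

# Independent multi-output kernel: different outputs are uncorrelated.
k_indep(a, b) = a[2] == b[2] ? k(a[1], b[1]) : 0.0
K_indep_dense = [k_indep(a, b) for a in mo_inputs, b in mo_inputs]   # 12×12
K_indep_kron = kron(Matrix{Float64}(I, nout, nout), Kfeatures)

# Intrinsic coregionalisation: an output covariance B replaces the identity.
B = [1.0 0.5 0.0; 0.5 2.0 0.3; 0.0 0.3 1.5]
k_icm(a, b) = B[a[2], b[2]] * k(a[1], b[1])
K_icm_dense = [k_icm(a, b) for a in mo_inputs, b in mo_inputs]       # 12×12
K_icm_kron = kron(B, Kfeatures)

# Both identities hold for the ordering assumed above.
indep_matches = K_indep_dense ≈ K_indep_kron
icm_matches = K_icm_dense ≈ K_icm_kron
```

A lazy representation keeps only the 3×3 and 4×4 factors instead of the 12×12 matrix, which is where the efficiency gain discussed in the thread comes from.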