Open · m-julian opened this issue 1 year ago
Hi @m-julian, I've found the bug (on our end) and I'm hoping to put up a PR with a fix later today or tomorrow.
Thanks, looking forward to it!
Hi @gpleiss, when are you planning to put up the PR? I am happy to test out the fix and check gradients are computed correctly.
Sorry, I got a bit backlogged on this PR. I'll try to have something up on Friday or this weekend.
This bug will be fixed in the next LinearOperator release.
🐛 Bug
I have implemented a KeOps periodic kernel in https://github.com/cornellius-gp/gpytorch/pull/2296 , however the `raw_lengthscale` parameter does not have gradients computed (see https://github.com/cornellius-gp/gpytorch/pull/2296#issuecomment-1466327296 ). I have managed to track the issue down to the matrix multiplication implemented in `LinearOperator.matmul`; see Expected Behavior. This matrix multiplication is called when doing `LinearOperator.sum` (as in the example below).

To reproduce
Code snippet to reproduce
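A minimal sketch of the kind of snippet that reproduces this, assuming the KeOps PeriodicKernel from PR #2296 is importable as `gpytorch.kernels.keops.PeriodicKernel` (the exact reproduction may differ):

```python
# Sketch of a reproduction (assumed import path; the KeOps PeriodicKernel
# comes from gpytorch PR #2296).
import torch
from gpytorch.kernels.keops import PeriodicKernel

torch.manual_seed(0)
# Use enough points that the KeOps code path is exercised (assumption:
# gpytorch's KeOps kernels fall back to plain torch for small inputs).
x = torch.randn(2000, 3)

k = PeriodicKernel()
covar = k(x, x)                # KeOps-backed LinearOperator
loss = covar.sum(dim=1).sum()  # LinearOperator.sum calls LinearOperator.matmul
loss.backward()

print(k.raw_period_length.grad)  # populated
print(k.raw_lengthscale.grad)    # None -- the bug
```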
Stack trace/error message
where `k.raw_lengthscale` is causing the issue.

Expected Behavior
Compute the gradients for both `k.raw_lengthscale` and `k.raw_period_length`. The exact place where the issue occurs is here: https://github.com/cornellius-gp/linear_operator/blob/92f7e3332664f8b1e2ec41552f00f163ab28112e/linear_operator/operators/_linear_operator.py#L2366 which calls the `LinearOperator.matmul` that loses track of the gradients. I am summing across the columns in the example, but the same issue occurs if summing across the rows. Adding the following check gives gradients for both `raw_lengthscale` and `raw_period_length`, as the custom `Matmul` is never called.
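A sketch of the kind of check meant here (the exact condition and placement are an assumption, not the author's exact patch); the idea is to route `matmul` around the custom autograd `Matmul` function:

```python
# Sketch (an assumption, not the exact patch). In linear_operator 0.3.0,
# LinearOperator.matmul routes through a custom autograd Function, roughly:
#
#     return Matmul.apply(self.representation_tree(), rhs, *self.representation())
#
# Returning the operator's differentiable _matmul instead keeps the whole
# computation on PyTorch's autograd tape, so the custom Matmul is never
# called and gradients reach raw_lengthscale and raw_period_length.

def matmul(self, rhs):
    return self._matmul(rhs)
```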
This is probably not the best solution; perhaps the `Matmul` forward/backward methods can be changed so that the gradients are computed correctly?

As a check, the same numbers are returned if the normal periodic kernel is used:
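A sketch of that comparison, swapping in the stock (non-KeOps) `gpytorch.kernels.PeriodicKernel`; with it, both gradients come out populated:

```python
# Same computation with the non-KeOps periodic kernel (sketch): the forward
# values match the KeOps version, and both gradients are populated.
import torch
import gpytorch

torch.manual_seed(0)
x = torch.randn(2000, 3)

k = gpytorch.kernels.PeriodicKernel()
loss = k(x, x).sum(dim=1).sum()
loss.backward()

print(k.raw_period_length.grad)  # populated
print(k.raw_lengthscale.grad)    # populated too
```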
System information
linear_operator version: 0.3.0
torch version: 1.13.1+cu117