j-wilson closed this issue 1 year ago.
The problem seems to stem from `ConstantMulLinearOperator._getitem`. The version below appears to work, but I am not sure how its runtime profile compares to the existing implementation. We still index into `base_linear_op` and `constant` directly. If I'm not mistaken, this new version may be faster (at least in the case considered here) because we now multiply `instance` (a `SumBatchLinearOperator`) by `constant` rather than `instance.base_linear_op`.
```python
# Proposed replacement for ConstantMulLinearOperator._getitem
def _getitem(self, row_index, col_index, *batch_indices):
    # NOTE TO FUTURE SELF:
    # This custom __getitem__ is actually very important!
    # It prevents constructing an InterpolatedLinearOperator when one isn't needed
    # This affects runtimes by up to 5x on simple exact GPs
    # Run __getitem__ on the base_linear_op and the constant
    base_linear_op = self.base_linear_op._getitem(row_index, col_index, *batch_indices)
    # Expand the constant to the full batch shape so it can be indexed like the base
    constant = self._constant.expand(self.batch_shape)[batch_indices]
    # Re-wrap the indexed pieces in a new constant-multiplied operator
    return type(self)(base_linear_op=base_linear_op, constant=constant)
```
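To make the indexing pattern above concrete, here is a minimal NumPy sketch. It is not the linear_operator implementation: `ToyConstantMul` is a hypothetical stand-in that stores a dense `base` array of shape `(*batch_shape, n, m)` and a `constant` broadcastable to `batch_shape`, and `np.broadcast_to` plays the role of `torch.Tensor.expand`. Its `_getitem` follows the same steps as the proposed fix: index the base directly, expand the constant to the full batch shape before indexing it with the batch indices, then re-wrap both in the same type. It assumes row and column indices are slices, so the result is still a (batched) matrix.

```python
import numpy as np


class ToyConstantMul:
    """Hypothetical lazy "constant * base" operator (illustration only).

    base:     array of shape (*batch_shape, n, m)
    constant: array (or scalar) broadcastable to batch_shape
    """

    def __init__(self, base, constant):
        self.base = np.asarray(base)
        self.constant = np.asarray(constant)

    @property
    def batch_shape(self):
        # Everything except the two trailing matrix dimensions.
        return self.base.shape[:-2]

    def _getitem(self, row_index, col_index, *batch_indices):
        # Index the base directly (here it is just a dense array).
        base = self.base[(*batch_indices, row_index, col_index)]
        # Expand the constant to the full batch shape *before* indexing, so a
        # scalar or size-1 constant lines up with the base's batch indices.
        constant = np.broadcast_to(self.constant, self.batch_shape)[batch_indices]
        # Re-wrap lazily instead of multiplying everything out.
        return type(self)(base, constant)

    def to_dense(self):
        # Broadcast the per-batch constant over the two matrix dimensions.
        return self.base * self.constant[..., None, None]


# Usage: index lazily, then compare with indexing the dense analogue.
rng = np.random.default_rng(0)
op = ToyConstantMul(rng.standard_normal((3, 4, 5)), 2.0)
sub = op._getitem(slice(0, 2), slice(1, 4), slice(1, None))
print(np.allclose(sub.to_dense(), op.to_dense()[1:, 0:2, 1:4]))  # True
```

Expanding first matters when the constant is a scalar or has size-1 batch dimensions: in this sketch, indexing an unexpanded 0-d constant with a batch slice raises an `IndexError`, and slicing a shape-`(1,)` constant selects the wrong entries, whereas the expanded view can be indexed exactly like the base.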
Your fix seems reasonable, and I also suspect that it is faster :) Want to throw up a PR for this?
Fixed by #37.
🐛 Bug
To reproduce
Expected Behavior
The code is expected to behave in the same way its dense analogue would.