cornellius-gp / linear_operator

A LinearOperator implementation to wrap the numerical nuts and bolts of GPyTorch
MIT License

[Bug] Indexing `ConstantMulLinearOperator` with a `SumBatchLinearOperator` base operator #25

Closed: j-wilson closed this issue 1 year ago

j-wilson commented 2 years ago

🐛 Bug

To reproduce

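import torch
from linear_operator import operators as ops

A = ops.DenseLinearOperator(torch.rand(4, 3, 2, 2))   # batch of 4, each holding 3 blocks of size 2 x 2
B = ops.SumBatchLinearOperator(A, block_dim=-3)       # sums the 3 blocks: shape (4, 2, 2)
C = ops.ConstantMulLinearOperator(B, torch.rand([]))  # scales B by a scalar constant
C[:, -1:, :].to_dense()                               # raises the RuntimeError below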
The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 1
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-18-e4d937430d51> in <module>
----> 1 C[:, -1:, :].to_dense()
/mnt/xarfuse/uid-22150/d091ed77-seed-nspid4026533386_cgpid5352405-ns-4026533383/linear_operator/operators/sum_batch_linear_operator.py in to_dense(self)
     59 
     60     def to_dense(self):
---> 61         return self.base_linear_op.to_dense().sum(dim=-3)  # BlockLinearOperators always use dim3 for the block_dim
/mnt/xarfuse/uid-22150/d091ed77-seed-nspid4026533386_cgpid5352405-ns-4026533383/linear_operator/utils/memoize.py in g(self, *args, **kwargs)
     57         kwargs_pkl = pickle.dumps(kwargs)
     58         if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59             return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
     60         return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)
     61 
/mnt/xarfuse/uid-22150/d091ed77-seed-nspid4026533386_cgpid5352405-ns-4026533383/linear_operator/operators/constant_mul_linear_operator.py in to_dense(self)
    164     def to_dense(self):
    165         res = self.base_linear_op.to_dense()
--> 166         return res * self.expanded_constant
    167 
    168     @cached(name="root_decomposition")
RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 1
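
The call chain in the traceback (SumBatchLinearOperator.to_dense calling into ConstantMulLinearOperator.to_dense) suggests that the indexed result is a SumBatchLinearOperator whose base is a ConstantMulLinearOperator, i.e. the constant appears to have been pushed inside the block structure. The plain-PyTorch sketch below reproduces the same broadcasting error under that reading. The shapes are assumptions inferred from the repro (batch size 4, 3 blocks, rows indexed with -1:), not read from the library internals.

```python
import torch

torch.manual_seed(0)
# Assumed stand-in for the indexed base of the SumBatch result: batch 4, 3 blocks, 1 x 2
indexed_base = torch.rand(4, 3, 1, 2)
# Assumed stand-in for the constant, expanded to the *outer* batch shape (4,) plus two matrix dims
constant = torch.rand([]).expand(4).reshape(4, 1, 1)

try:
    # Broadcasting aligns trailing dims: (4, 3, 1, 2) vs (_, 4, 1, 1), so 3 clashes with 4 at dim 1
    indexed_base * constant
    message = None
except RuntimeError as err:
    message = str(err)

print(message)

# Multiplying after the block sum works, since the shapes now agree: (4, 1, 2) * (4, 1, 1)
scaled = indexed_base.sum(dim=-3) * constant
print(scaled.shape)  # torch.Size([4, 1, 2])
```

In this sketch the constant only lines up with the operator's batch shape once the block dimension has been summed away.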

Expected Behavior

The code should behave the same way its dense analogue would, i.e. `C[:, -1:, :].to_dense()` should equal `C.to_dense()[:, -1:, :]`.
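
For reference, the dense analogue can be written with plain PyTorch tensors. This is a sketch assuming, as line 61 of `sum_batch_linear_operator.py` in the traceback indicates, that SumBatchLinearOperator with `block_dim=-3` densifies to a sum over dimension -3, and that ConstantMulLinearOperator with a scalar constant densifies to an elementwise product.

```python
import torch

torch.manual_seed(0)
a = torch.rand(4, 3, 2, 2)  # dense tensor wrapped by A
c = torch.rand([])          # scalar constant used by C

b_dense = a.sum(dim=-3)     # dense analogue of B: shape (4, 2, 2)
c_dense = b_dense * c       # dense analogue of C: shape (4, 2, 2)

expected = c_dense[:, -1:, :]  # what C[:, -1:, :].to_dense() should return
print(expected.shape)          # torch.Size([4, 1, 2])
```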

j-wilson commented 2 years ago

The problem seems to stem from ConstantMulLinearOperator._getitem. The following version appears to work, but I am not sure how its runtime profile compares to the existing implementation. It still indexes into base_linear_op and the constant directly. If I'm not mistaken, this version may even be faster (in the particular case considered here), since it now multiplies the SumBatchLinearOperator instance by the constant rather than that instance's base_linear_op.

    def _getitem(self, row_index, col_index, *batch_indices):
        # NOTE TO FUTURE SELF:
        # This custom __getitem__ is actually very important!
        # It prevents constructing an InterpolatedLinearOperator when one isn't needed
        # This affects runtimes by up to 5x on simple exact GPs
        # Run __getitem__ on the base_linear_op and the constant
        base_linear_op = self.base_linear_op._getitem(row_index, col_index, *batch_indices)
        constant = self._constant.expand(self.batch_shape)[batch_indices]
        return type(self)(base_linear_op=base_linear_op, constant=constant)
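
The order of operations in the proposed `_getitem` (index the base operator, index the constant after expanding it to `self.batch_shape`, then rescale) can be checked against the dense analogue with plain tensors. This is a sketch under the repro's shapes; treating `C[:, -1:, :]` as `batch_indices = (slice(None),)` with row index `-1:` is an assumption about how the indexing is decomposed, and `[..., None, None]` is a dense stand-in for how the constant broadcasts over the matrix dimensions.

```python
import torch

torch.manual_seed(0)
a = torch.rand(4, 3, 2, 2)      # dense tensor wrapped by A
c = torch.rand([])              # scalar constant of C
batch_shape = torch.Size([4])   # batch shape of B = SumBatchLinearOperator(A, block_dim=-3)

batch_indices = (slice(None),)  # assumed batch part of C[:, -1:, :]
row_index, col_index = slice(-1, None), slice(None)

# Step 1: index the base operator B (dense stand-in: sum the blocks, then index rows/cols)
b_indexed = a.sum(dim=-3)[batch_indices + (row_index, col_index)]  # shape (4, 1, 2)

# Step 2: expand the constant to B's batch shape and index it, as in the proposed fix
constant = c.expand(batch_shape)[batch_indices]  # shape (4,)

# Step 3: rescale the already-summed result; the constant broadcasts over the matrix dims
result = b_indexed * constant[..., None, None]

# Compare against indexing the dense analogue of C directly
expected = (a.sum(dim=-3) * c)[:, -1:, :]
print(torch.allclose(result, expected))  # True
```

In this example, scaling after the block sum touches 4 × 1 × 2 = 8 entries, versus 4 × 3 × 1 × 2 = 24 if the constant were applied to the base before summing. That is consistent with, though not a measurement of, the possible speed-up mentioned in the comment.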
gpleiss commented 2 years ago

Your fix seems reasonable, and I also suspect that it is faster :) Want to throw up a PR for this?

JonathanWenger commented 1 year ago

Fixed by #37.