import linear_operator
import torch
class DiagLinearOperator(linear_operator.LinearOperator):
r"""
A LinearOperator representing a diagonal matrix.
"""
def __init__(self, diag):
# diag: the vector that defines the diagonal of the matrix
self.diag = diag
def _matmul(self, v):
return self.diag.unsqueeze(-1) * v
def _size(self):
return torch.Size([*self.diag.shape, self.diag.size(-1)])
def _transpose_nonbatch(self):
return self # Diagonal matrices are symmetric
# this function is optional, but it will accelerate computation
def logdet(self):
return self.diag.log().sum(dim=-1)
# ...
D = DiagLinearOperator(torch.tensor([1., 2., 3.]))
# Represents the matrix
# [[1., 0., 0.],
# [0., 2., 0.],
# [0., 0., 3.]]
torch.matmul(D, torch.tensor([4., 5., 6.]))
# Returns [4., 10., 18.]
Stack trace/error message
Traceback (most recent call last):
File "/home/jagh/codes/ng/src/a.py", line 31, in <module>
torch.matmul(D, torch.tensor([4., 5., 6.]))
File "/home/jagh/.conda/envs/trot/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 2970, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/jagh/.conda/envs/trot/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 1839, in matmul
return Matmul.apply(self.representation_tree(), other, *self.representation())
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jagh/.conda/envs/trot/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 2072, in representation_tree
return LinearOperatorRepresentationTree(self)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jagh/.conda/envs/trot/lib/python3.11/site-packages/linear_operator/operators/linear_operator_representation_tree.py", line 8, in __init__
self._differentiable_kwarg_names = linear_op._differentiable_kwargs.keys()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'DiagLinearOperator' object has no attribute '_differentiable_kwargs'
Expected Behavior
Snippet should return [4., 10., 18.]
Additional context
I added self._differentiable_kwargs = { some dict }, which seems by pass the problem, but I get another message with self._nondifferentiable_kwargs I don't know how to setup. Did I miss something?
🐛 Bug
To reproduce
I took the snippet from the README
Stack trace/error message
Expected Behavior
Snippet should return
[4., 10., 18.]
Additional context
I added
self._differentiable_kwargs = { some dict }
, which seems by pass the problem, but I get another message withself._nondifferentiable_kwargs
I don't know how to setup. Did I miss something?