ckolbPTB opened 1 year ago
Suggestion: We should make the operators valid torch.autograd.Function subclasses.
I had a look at MIRTorch but I don't fully understand what torch.autograd requires.
For linear operators they don't use it but give some pseudo-code as guidance: https://github.com/guanhuaw/MIRTorch/blob/f38dc662f521a66796d65710bedba6609904b088/mirtorch/linear/linearmaps.py#L19
For CG it is used but by defining a separate class in addition to their "normal" CG class: https://github.com/guanhuaw/MIRTorch/blob/f38dc662f521a66796d65710bedba6609904b088/mirtorch/alg/cg.py#L4
@fzimmermann89 do you have a better example of what this would look like for our operator classes?
I will need some time to think about this to form an opinion. It would be nice to either (in the common case)
But torch.autograd.Function works by defining a static forward and backward, so we will have to do some wrapping to define operator instances.
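For context, a minimal sketch of what such wrapping could look like (the class name and the apply_forward/apply_adjoint methods are placeholders, not an agreed interface):

```python
import torch


class _WrappedOperatorFunction(torch.autograd.Function):
    """Sketch: because forward/backward are static, the operator instance
    has to be passed in as an extra argument."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, operator) -> torch.Tensor:
        ctx.operator = operator  # non-tensor state can be stored on ctx
        return operator.apply_forward(x)  # apply_forward: placeholder method name

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # One return value per forward argument; None for the operator argument.
        return ctx.operator.apply_adjoint(grad_output), None
```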
We might ignore this for a 0.1 version and first discuss and fix the overall interface of our operators.
Random thoughts:
Some things I like about the MIRTorch approach that we should consider:
Instead of deriv, we should have an optional backward function implementing the vjp and have it picked up by autograd if implemented. If it is not implemented for a non-linear operator, autograd will be used to calculate it based on the forward. For linear operators, this should be .adjoint. We should also add a .jacobian that is either implemented by the non-linear operator, or pytorch.autograd.functional.jacobian is used, which in turn uses the vjp that might be implemented (or autograd is used). For linear operators, this would return the adjoint operator.
```python
from abc import ABC, abstractmethod
from copy import deepcopy

import torch
from torch import nn


class Operator(ABC, nn.Module):
    @abstractmethod
    def forward(self, data: torch.Tensor) -> torch.Tensor:
        ...


class LinearOperator(Operator):
    @abstractmethod
    def _adjoint(self, data: torch.Tensor) -> torch.Tensor:
        ...

    @property
    def adjoint(self) -> 'LinearOperator':
        # Return a copy of the operator with forward and _adjoint swapped.
        operator = deepcopy(self)
        operator._adjoint, operator.forward = operator.forward, operator._adjoint
        return operator
```
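As a rough illustration of the Jacobian fallback mentioned above (a sketch only; the jacobian method name and its dense return value are assumptions), the base class could default to autograd and let subclasses override:

```python
import torch
from torch import nn


class Operator(nn.Module):
    """Sketch only: an autograd-based default Jacobian that subclasses can override."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def jacobian(self, x: torch.Tensor) -> torch.Tensor:
        # Default: dense Jacobian via autograd, which in turn uses vjps of forward.
        # A non-linear operator can override this with an explicit implementation;
        # a linear operator could instead return (a representation of) its adjoint.
        return torch.autograd.functional.jacobian(self.forward, x)
```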
Closed by PR #104
I am reopening this to collect some more thoughts about linear operators and autograd.
- Maybe we can add a default adjoint to LinearOperator that uses a (really inefficient) autograd call to compute the adjoint and prints a warning. This could be useful for quick prototyping without having to implement both operations (see the first sketch after this list).
- We could rename our forward and adjoint in the operator implementations to _forward_impl and _adjoint_impl, and add a forward and adjoint to the LinearOperator base class that call a custom autograd.Function which uses _adjoint_impl as the backward of the forward and vice versa, instead of relying on autograd. An operator implementation can then either implement forward and adjoint and use autograd, or only implement the *_impl methods and let the base class handle the backward logic (see the second sketch after this list). This is useful, for example, for sparse matrix multiplication: if the forward uses a CSR matrix, the autograd backward is a factor of 10 slower than a backward that uses the adjoint, which can work on a copy of the matrix already converted to CSC layout in the init.
- Alternatively, we could apply the second bullet point manually in the implementation of each LinearOperator where we think it matters.
- We should test for some often used operators whether autograd is slower than using the adjoint.
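A minimal sketch of the first point, assuming the LinearOperator knows its input shape (the _input_shape attribute here is hypothetical) and using torch.autograd.functional.vjp; for a linear operator the vjp does not depend on the linearization point, so it can be evaluated at zeros:

```python
import warnings

import torch
from torch import nn


class LinearOperator(nn.Module):
    """Sketch only: a slow autograd-based default adjoint for prototyping."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def adjoint(self, y: torch.Tensor) -> torch.Tensor:
        # Subclasses should override this; the default is only meant for prototyping.
        warnings.warn('Using a slow autograd-based adjoint; implement adjoint explicitly for speed.')
        # For a linear operator A, the vjp equals A^H y at any point,
        # so it can be evaluated at a zero input of matching shape.
        x0 = torch.zeros(self._input_shape, dtype=y.dtype)  # _input_shape: assumed attribute
        _, vjp_value = torch.autograd.functional.vjp(self.forward, x0, v=y)
        return vjp_value
```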
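A sketch of the second point (the _forward_impl/_adjoint_impl names follow the suggestion above; passing the bound methods into the autograd.Function is just one possible way to wire it up):

```python
import torch
from torch import nn


class _ForwardWithExplicitBackward(torch.autograd.Function):
    """Sketch: an autograd node whose backward is the explicitly implemented adjoint."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, forward_fn, adjoint_fn) -> torch.Tensor:
        ctx.adjoint_fn = adjoint_fn
        return forward_fn(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # One return value per forward argument; None for the two callables.
        return ctx.adjoint_fn(grad_output), None, None


class LinearOperator(nn.Module):
    """Sketch of the proposed base class logic."""

    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def _adjoint_impl(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The backward of the forward is the adjoint implementation.
        return _ForwardWithExplicitBackward.apply(x, self._forward_impl, self._adjoint_impl)

    def adjoint(self, x: torch.Tensor) -> torch.Tensor:
        # The backward of the adjoint is the forward implementation.
        return _ForwardWithExplicitBackward.apply(x, self._adjoint_impl, self._forward_impl)
```

A subclass then only implements the two *_impl methods; autograd never traces through them, and the explicit adjoint is used for the backward.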
Create template classes for linear and non-linear operators:
Check list: