KrisThielemans opened 4 years ago
@jakobsj I mistyped your id above.
For non-affine operators, the gradient depends on the current "input", hence the `set_input` suggestion above. However, it then makes sense to have
- `void set_input(x)`
- `void forward()`
- `void backward(y)`
The advantage of this is that we nearly always want to call `backward` after `forward`, so the output of `forward` could be stored and doesn't have to be recomputed for `backward`.
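A minimal sketch of this pattern in Python, assuming a hypothetical element-wise exponential operator (`ExpOperator`) acting on plain lists; it is not SIRF code. Because the derivative of exp(x) is exp(x), `backward` can reuse the cached output of `forward`, and `set_input` invalidates that cache.

```python
import math


class ExpOperator:
    """Hypothetical element-wise exp operator that caches its forward output."""

    def __init__(self):
        self._input = None
        self._output = None      # cached result of forward()
        self.forward_calls = 0   # number of actual exp evaluations (illustrative)

    def set_input(self, x):
        self._input = list(x)
        self._output = None      # a new input invalidates the cache

    def forward(self):
        if self._output is None:
            self.forward_calls += 1
            self._output = [math.exp(v) for v in self._input]
        return self._output

    def backward(self, y):
        # The Jacobian of exp is diag(exp(x)), so J^T y = exp(x) * y element-wise;
        # the cached forward output is reused instead of recomputed.
        out = self.forward()
        return [o * yi for o, yi in zip(out, y)]


op = ExpOperator()
op.set_input([0.0, 1.0])
op.forward()
grad = op.backward([1.0, 2.0])
print(op.forward_calls)  # 1: backward reused the cached output
```

The counter `forward_calls` is also hypothetical; it only exists to show that the expensive evaluation happens once per input.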
An example class interface (with somewhat different naming) is discussed in #739, with a concrete header at https://github.com/SyneRBI/SIRF/blob/add_DeformationModel/src/Registration/cReg/include/sirf/Reg/DeformationModel.h
Generic test for the `backward` operation (using numerical gradients):
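operator.set_input(input)
result1 = operator.backward(out_tilde)   # analytic result
out = operator.forward()
result2 = input.clone()
for each element index v of input:       # cycle over all input elements
    input_shifted = input.clone()
    input_shifted[v] += epsilon
    operator.set_input(input_shifted)
    d_out = (operator.forward() - out) / epsilon
    result2[v] = hermitian_product(d_out, out_tilde)
operator.set_input(input)                # restore the original input
# result1 and result2 should agree up to O(epsilon)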
Note: for complex variables, `hermitian_product` needs to be `sum(conj(d_out)*out_tilde)`.
This is obviously going to be very slow (for a domain of size N, it needs N `forward` operations).
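The test above can be sketched in runnable Python, assuming a hypothetical non-linear operator f(x) = (x0*x1, sin(x0)) whose `backward` applies the transposed Jacobian analytically. `ProductSinOperator`, `numerical_backward` and `epsilon = 1e-6` are illustrative choices; `hermitian_product` conjugates its first argument as in the note, which is a no-op for real floats.

```python
import math


class ProductSinOperator:
    """Hypothetical non-linear operator f(x) = (x0*x1, sin(x0))."""

    def set_input(self, x):
        self._x = list(x)

    def forward(self):
        x0, x1 = self._x
        return [x0 * x1, math.sin(x0)]

    def backward(self, y):
        # Jacobian J = [[x1, x0], [cos(x0), 0]]; return J^T y
        x0, x1 = self._x
        return [x1 * y[0] + math.cos(x0) * y[1], x0 * y[0]]


def hermitian_product(a, b):
    # sum(conj(a) * b); conjugate() leaves real floats unchanged
    return sum(ai.conjugate() * bi for ai, bi in zip(a, b))


def numerical_backward(operator, x, out_tilde, epsilon=1e-6):
    operator.set_input(x)
    out = operator.forward()
    result = list(x)
    for v in range(len(x)):              # one extra forward call per element
        shifted = list(x)
        shifted[v] += epsilon
        operator.set_input(shifted)
        d_out = [(o1 - o0) / epsilon for o1, o0 in zip(operator.forward(), out)]
        result[v] = hermitian_product(d_out, out_tilde)
    operator.set_input(x)                # restore the original input
    return result


op = ProductSinOperator()
x = [0.3, -1.2]
out_tilde = [0.5, 2.0]
op.set_input(x)
analytic = op.backward(out_tilde)
numeric = numerical_backward(op, x, out_tilde)
max_diff = max(abs(a - n) for a, n in zip(analytic, numeric))
print(max_diff)  # small, of order epsilon
```

With a forward difference the discrepancy is expected to be of order `epsilon`, so a tolerance well above that (e.g. 1e-4) is a reasonable, if arbitrary, pass criterion.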
For the mathematicians, this assumes that domain and range are vector spaces. If they're not, we'd need to think about differentiable manifolds...
Once we have #737, it would make a lot of sense to decide on a similar base class hierarchy for non-linear operators. We could have

`Operator` -> `DifferentiableOperator` -> `LinearOperator`

`Operator` would have `forward`. `DifferentiableOperator` would have `backward`, corresponding to multiplication with the transpose of the Jacobian (i.e. the matrix of partial derivatives) at the current input (as it depends on the `input`). Using the above signature implies we have a `set_input`.
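In formulas, assuming real-valued data (for complex data the transpose becomes the conjugate transpose, consistent with the `hermitian_product` note), `backward` at the current input $x$ of an operator $f$ would compute the following; the symbols $f$, $J$ and $A$ are introduced here only for illustration.

```latex
J(x)_{ij} = \frac{\partial f_i}{\partial x_j}(x), \qquad
\mathrm{backward}(y) = J(x)^{T} y, \qquad
\big(J(x)^{T} y\big)_j = \sum_i \frac{\partial f_i}{\partial x_j}(x)\, y_i .
```

For a linear operator $f(x) = A x$ the Jacobian is $A$ for every $x$, so `backward(y)` reduces to $A^{T} y$ and no longer depends on the input. The numerical test above approximates column $j$ of $J(x)$ by $(f(x + \epsilon e_j) - f(x))/\epsilon$.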
`LinearOperator` would have `adjoint`, which would just be `backward(y)`, and `direct(x)`, which would be `forward` evaluated at input `x`.

For the fun of it (?) we could have a `Function` (with scalar output), where `gradient = backward(1)`.
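A rough Python sketch of this hierarchy, reusing the names proposed above but assuming plain lists for data; it is not the SIRF or CIL API. `Scale` and `SumOfSquares` are hypothetical examples of a `LinearOperator` and a `Function`.

```python
from abc import ABC, abstractmethod


class Operator(ABC):
    """Maps the current input (set via set_input) to an output."""

    def set_input(self, x):
        self._x = list(x)

    @abstractmethod
    def forward(self):
        ...


class DifferentiableOperator(Operator):
    @abstractmethod
    def backward(self, y):
        """Multiply y with the transposed Jacobian at the current input."""


class LinearOperator(DifferentiableOperator):
    def direct(self, x):
        self.set_input(x)
        return self.forward()

    def adjoint(self, y):
        # The Jacobian of a linear operator does not depend on the input,
        # so backward(y) is the adjoint regardless of the current input.
        return self.backward(y)


class Function(DifferentiableOperator):
    """Differentiable operator with scalar output."""

    def gradient(self):
        return self.backward(1.0)


class Scale(LinearOperator):
    """Hypothetical example: x -> a * x."""

    def __init__(self, a):
        self.a = a

    def forward(self):
        return [self.a * v for v in self._x]

    def backward(self, y):
        return [self.a * v for v in y]


class SumOfSquares(Function):
    """Hypothetical example: x -> sum(x_i^2), with gradient 2 * x."""

    def forward(self):
        return sum(v * v for v in self._x)

    def backward(self, y):
        return [2.0 * v * y for v in self._x]


s = Scale(3.0)
print(s.direct([1.0, 2.0]))   # [3.0, 6.0]
print(s.adjoint([1.0, 1.0]))  # [3.0, 3.0]
f = SumOfSquares()
f.set_input([1.0, 2.0])
print(f.gradient())           # [2.0, 4.0]
```

Making `forward` and `backward` abstract means only the concrete subclasses can be instantiated, while `direct`, `adjoint` and `gradient` come for free from the base classes.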
@paskino @epapoutsellis jakobsj, do you have this stuff for `CIL::Operator`? Could you point us to it? Maybe you have an opinion?