TomographicImaging / CIL

A versatile python framework for tomographic imaging
https://tomographicimaging.github.io/CIL/
Apache License 2.0

SIRT update and objective methods #1514

Closed: epapoutsellis closed this issue 6 months ago

epapoutsellis commented 11 months ago

The update method of SIRT

https://github.com/TomographicImaging/CIL/blob/fb95d17772aea6d6e1f9b0e3f5642afafebc0d1e/Wrappers/Python/cil/optimisation/algorithms/SIRT.py#L193-L198

assumes $\frac{1}{2}\|Ax-d\|^{2}$, but the objective method computes $\|Ax-d\|^{2}$:

https://github.com/TomographicImaging/CIL/blob/fb95d17772aea6d6e1f9b0e3f5642afafebc0d1e/Wrappers/Python/cil/optimisation/algorithms/SIRT.py#L212

In addition, if another constraint is passed (e.g., TV), it is used in the update method but not in update_objective.
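To make the scaling mismatch concrete, here is a small NumPy check. It is a sketch with a hypothetical random `A`, `d` and `x`, not CIL code, and it ignores SIRT's diagonal weights: the step direction $A^T(Ax-d)$ is exactly the gradient of $\frac{1}{2}\|Ax-d\|^2$, while $\|Ax-d\|^2$ is twice that value, so reporting it next to the $\frac{1}{2}$-scaled update doubles the objective.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((6, 4))   # hypothetical small system matrix
d = rng.standard_normal(6)        # hypothetical data
x = rng.standard_normal(4)        # hypothetical current iterate

def half_ls(x):
    """0.5 * ||A x - d||^2: the function whose gradient the update uses."""
    return 0.5 * np.sum((A @ x - d) ** 2)

def full_ls(x):
    """||A x - d||^2: the value the objective method reported."""
    return np.sum((A @ x - d) ** 2)

# Analytic gradient of the 1/2-scaled least-squares term
grad = A.T @ (A @ x - d)

# Central finite differences of half_ls along each coordinate direction
eps = 1e-6
fd = np.array([(half_ls(x + eps * e) - half_ls(x - eps * e)) / (2 * eps)
               for e in np.eye(4)])

print(np.allclose(grad, fd, atol=1e-5))        # True: gradient belongs to the 1/2 version
print(np.isclose(full_ls(x), 2 * half_ls(x)))  # True: reported value is twice as large
```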

MargaretDuff commented 8 months ago

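Hi @epapoutsellis! Thanks for this. I have been taking a look this afternoon. Do you have any examples of using TV with SIRT? I am trying to compare with FISTA, but the SIRT algorithm seems to produce just zeros:

```python
import numpy as np
from cil.utilities import dataexample
from cil.optimisation.operators import IdentityOperator
from cil.optimisation.functions import LeastSquares, TotalVariation
from cil.optimisation.algorithms import SIRT, FISTA

data = dataexample.SIMPLE_PHANTOM_2D.get(size=(128, 128))
ig = data.geometry
A = IdentityOperator(ig)
constraint = TotalVariation()
initial = ig.allocate('random', seed=5)

# SIRT with TV passed as the constraint
sirt = SIRT(initial=initial, operator=A, data=data, max_iteration=100, constraint=constraint)
sirt.run(100, verbose=2)

# FISTA on 0.5*||Ax - d||^2 + TV for comparison
f = LeastSquares(A, data, c=0.5)
fista = FISTA(initial=initial, f=f, g=constraint, max_iteration=100)
fista.run(100, verbose=2)

# originally written as unittest assertions (self.assertNumpyArrayAlmostEqual / self.assertAlmostEqual)
np.testing.assert_array_almost_equal(fista.x.as_array(), sirt.x.as_array())
np.testing.assert_almost_equal(fista.loss[-1], sirt.loss[-1])
```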
MargaretDuff commented 8 months ago

I think this test is failing because of #1650

epapoutsellis commented 8 months ago

Yes, I just tested it. It was working at some point. Even ISTA + WeightedLS + Preconditioner fails, so I think it is due to #1650.

epapoutsellis commented 8 months ago

But it does work with FGP_TV

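```python
from cil.plugins.ccpi_regularisation.functions import FGP_TV
from cil.optimisation.functions import TotalVariation
from cil.optimisation.algorithms import SIRT

# `initial`, `A` and `data2D` are defined earlier in the author's session (not shown)
alpha = 1e-6
G = alpha * FGP_TV(max_iteration=100, device="gpu")
# G = alpha * TotalVariation(max_iteration=100)

sirt = SIRT(initial=initial, operator=A, data=data2D,
            update_objective_interval=20, constraint=G,
            max_iteration=200)
# sirt.fix_weights()
sirt.run(verbose=1)
```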
epapoutsellis commented 8 months ago

One way to unit test this is to use a large regularisation parameter, since in that case we know the solution analytically (here $A$ is the identity): a constant image equal to the mean of the data, $$\frac{1}{|\Omega|}\int_{\Omega} \text{data}\,dx$$
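The reasoning behind that analytic solution can be checked in two steps: a very large TV weight drives the solution towards a constant image (TV is zero for constants), and among constant images $u = c$ the fidelity $\frac{1}{2}\|c - d\|^2$ has derivative $\sum (c - d)$, which vanishes at $c = \text{mean}(d)$. The NumPy sketch below only checks the second step, using a hypothetical random array in place of the phantom; it does not run TV or any CIL algorithm.

```python
import numpy as np

# Hypothetical stand-in for the phantom data
rng = np.random.default_rng(1)
d = rng.random((32, 32))

def fidelity(c):
    """0.5 * ||c - d||^2 for the constant image u = c (TV(u) = 0 here)."""
    return 0.5 * np.sum((c - d) ** 2)

# Brute-force search over constants between min(d) and max(d)
candidates = np.linspace(d.min(), d.max(), 2001)
best = candidates[np.argmin([fidelity(c) for c in candidates])]

print(best, d.mean())  # the grid minimiser lies within one grid step of the mean
```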

For the example below, everything looks correct with TotalVariation, both without and with warm start. However, the objective values reported by SIRT are wrong, because the constraint is not included in the objective.

```python
from cil.utilities import dataexample
from cil.optimisation.operators import IdentityOperator
from cil.optimisation.functions import L2NormSquared, LeastSquares, TotalVariation
from cil.plugins.ccpi_regularisation.functions import FGP_TV
from cil.optimisation.algorithms import PDHG, FISTA, SIRT
import matplotlib.pyplot as plt
import numpy as np

data = dataexample.SIMPLE_PHANTOM_2D.get(size=(128,128))
ig = data.geometry
initial = ig.allocate("random", seed=5)
A = IdentityOperator(ig)
F = 0.5*L2NormSquared(b=data)
# G = 1e6*FGP_TV(max_iteration=1000, device="gpu")
G = 1e6*TotalVariation(max_iteration=1000)

pdhg = PDHG(f = F, g=G, operator = A, update_objective_interval=5, max_iteration=20)
pdhg.run(verbose=1)

f = LeastSquares( A, data, c = 0.5)
fista = FISTA(initial=initial, f=f, g=G, max_iteration=20, update_objective_interval=5)
fista.run(verbose=1)

sirt = SIRT(initial = initial, operator=A, data=data, max_iteration=20, constraint=G, update_objective_interval=5)
sirt.run(verbose=1)

# The original used `TotalVariationNew`, which is not defined in this snippet (presumably a
# development version of TotalVariation). Released CIL exposes `lower` and `warm_start` on TotalVariation.
G = 1e6*TotalVariation(max_iteration = 100, lower=0., warm_start=True)
sirt1 = SIRT(initial = initial, operator=A, data=data, max_iteration=20, constraint=G, update_objective_interval=5)
sirt1.run(verbose=1)

print("Solution for large reg parameter should converge to {}".format(np.mean(data)))

# Compare a horizontal line profile (row 64) of each solution
plt.figure()
plt.plot(pdhg.solution.array[64,:], label="PDHG")
plt.plot(fista.solution.array[64,:], label="FISTA")
plt.plot(sirt.solution.array[64,:], label="SIRT")
plt.plot(sirt1.solution.array[64,:], label="SIRT-warm start")
plt.legend()
plt.show()

np.testing.assert_allclose(fista.solution.array, sirt.solution.array, atol=1e-4)
np.testing.assert_allclose(pdhg.solution.array, sirt.solution.array, atol=1e-4)
```
Output (runs in the order PDHG, FISTA, SIRT, SIRT with warm start):

```
     Iter   Max Iter  Time(s)/Iter            Objective
        0         20         0.000          1.79200e+03
        5         20         2.314          4.05011e+06
       10         20         2.256          4.17628e+06
       15         20         2.296          4.18019e+06
       20         20         2.347          4.18034e+06
-------------------------------------------------------
Stop criterion has been reached.

     Iter   Max Iter  Time(s)/Iter            Objective
        0         20         0.000          8.38525e+09
        5         20         2.478          4.18053e+06
       10         20         2.406          4.18053e+06
       15         20         2.421          4.18053e+06
       20         20         2.395          4.18053e+06
-------------------------------------------------------
Stop criterion has been reached.

     Iter   Max Iter  Time(s)/Iter            Objective
        0         20         0.000          1.79200e+03
        5         20         2.425          1.30913e+03
       10         20         2.396          1.30913e+03
       15         20         2.385          1.30913e+03
       20         20         2.351          1.30913e+03
-------------------------------------------------------
Stop criterion has been reached.

     Iter   Max Iter  Time(s)/Iter            Objective
        0         20         0.000          1.79200e+03
        5         20         0.248          1.13241e+03
       10         20         0.258          1.26455e+03
       15         20         0.249          1.27843e+03
       20         20         0.247          1.27982e+03
-------------------------------------------------------
Stop criterion has been reached.

Solution for large reg parameter should converge to 0.25
```
[Figure (screenshot, 2024-01-18): row-64 line profiles of the PDHG, FISTA, SIRT and SIRT warm-start solutions]
epapoutsellis commented 7 months ago

@MargaretDuff I was wrong about update_objective and the constraint. Since SIRT is basically (projected) gradient descent on a quadratic, the objective is just the LeastSquares term, and we should not add the constraint to it. Including it would assume that we can take the gradient of, for example, the TV term, which is not possible.
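A minimal NumPy sketch of that view of SIRT, not CIL's implementation: it assumes the textbook SIRT weights (inverse column sums `D` and inverse row sums `M` of a non-negative `A`) and uses a box projection via `np.clip` as a stand-in for `constraint`. The constraint only enters as a step applied after each weighted gradient step, and the tracked objective is just the data term $\frac{1}{2}\|Ax-b\|^2$. For a non-indicator constraint such as TV, that step would be its proximal map rather than a projection.

```python
import numpy as np

def sirt(A, b, x0, n_iter=500, lower=None, upper=None):
    """Projected SIRT sketch: x <- P_C(x + D A^T M (b - A x)).

    Assumes A is non-negative with no zero rows or columns.
    `lower`/`upper` define a hypothetical box constraint C.
    Returns the final iterate and the history of 0.5 * ||A x - b||^2.
    """
    D = 1.0 / A.sum(axis=0)   # inverse column sums
    M = 1.0 / A.sum(axis=1)   # inverse row sums
    x = x0.astype(float).copy()
    objective = []
    for _ in range(n_iter):
        r = b - A @ x                      # residual
        x = x + D * (A.T @ (M * r))        # weighted gradient step on the quadratic
        if lower is not None or upper is not None:
            x = np.clip(x, lower, upper)   # constraint applied by projection only
        objective.append(0.5 * np.sum((A @ x - b) ** 2))  # data term only
    return x, objective

# Usage on a small consistent system with solution [1, 2]
A = np.array([[2.0, 1.0], [1.0, 3.0]])
b = A @ np.array([1.0, 2.0])
x, obj = sirt(A, b, x0=np.zeros(2), lower=0.0)
print(np.round(x, 6))  # → [1. 2.]
```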

Also, from the above experiments, TotalVariation in SIRT works and warm starting is more accurate.

MargaretDuff commented 7 months ago

Thanks @epapoutsellis - I will remove the constraint from the update objective in the PR https://github.com/TomographicImaging/CIL/pull/1658.

> Also, from the above experiments, TotalVariation in SIRT works and warm starting is more accurate.

Also, in the PR https://github.com/TomographicImaging/CIL/pull/1658 I changed the code so that CIL TotalVariation is not calculated in place, so it should now work, and I added a unit test for it. I will check with warm starting.

MargaretDuff commented 7 months ago

Hi Vaggelis, looking at this, I am confused when you say:

> Also, from the above experiments, TotalVariation in SIRT works and warm starting is more accurate.

In the warm-start example you show above, the solution is zeros, which differs from FISTA and PDHG. Why do you say it is "more accurate"?

epapoutsellis commented 7 months ago

The solutions from all algorithms should be close to np.mean(data) = 0.25. You can check this with CIL_CVXPY:

```python
import cvxpy
import numpy as np
from cil.utilities import dataexample

data = dataexample.SIMPLE_PHANTOM_2D.get(size=(128, 128))

u_cvx = cvxpy.Variable(data.shape)
fidelity = 0.5*cvxpy.sum_squares(u_cvx - data.array)
regulariser = 1e6*cvxpy.tv(u_cvx)   # cvxpy's total variation atom
obj_primal = cvxpy.Minimize(regulariser + fidelity)
prob_primal = cvxpy.Problem(obj_primal, constraints=[])
primal_tv = prob_primal.solve(verbose=False, solver=cvxpy.MOSEK)  # MOSEK requires a licence
print(np.min(u_cvx.value), np.max(u_cvx.value))
```

```
0.24999999997841021 0.2499999999784153
```

Are you running with the current master?

MargaretDuff commented 7 months ago

Thanks for your help, Vaggelis! I have added the unit test:

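```python
import unittest
import numpy as np
from cil.utilities import dataexample
from cil.optimisation.operators import IdentityOperator
from cil.optimisation.functions import TotalVariation
from cil.optimisation.algorithms import SIRT

class TestSIRT(unittest.TestCase):
    def test_SIRT_with_TV_warm_start(self):
        data = dataexample.SIMPLE_PHANTOM_2D.get(size=(128, 128))
        ig = data.geometry
        A = IdentityOperator(ig)
        constraint = 1e6*TotalVariation(warm_start=True, max_iteration=100)
        initial = ig.allocate('random', seed=5)
        sirt = SIRT(initial=initial, operator=A, data=data, max_iteration=150, constraint=constraint)
        sirt.run(25, verbose=0)
        # original used CIL's test helper: self.assertNumpyArrayAlmostEqual(..., 3)
        np.testing.assert_array_almost_equal(sirt.x.as_array(), ig.allocate(0.25).as_array(), decimal=3)
```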