TomographicImaging / CIL

A versatile Python framework for tomographic imaging
https://tomographicimaging.github.io/CIL/
Apache License 2.0

Callback in Algorithm #1584

Closed. epapoutsellis closed this issue 7 months ago.

epapoutsellis commented 9 months ago

During the Stochastic Project, I designed the following improvement for the callback used in the run method of our Algorithm class.

To instantiate an Algorithm we mainly pass instances of Functions and Operators, so I thought: why not pass another class that is responsible for the callbacks? At the moment, the callback used in the run method is restricted to the iteration and x attributes and the get_last_objective() method. https://github.com/TomographicImaging/CIL/blob/0418ed2abd42c085db25aa3d0405b4b16e6591e4/Wrappers/Python/cil/optimisation/algorithms/Algorithm.py#L286-L287 This is not flexible: it gives the user access to only those three quantities.
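
For reference, the current call site on master is essentially the following (paraphrasing the linked lines; the exact code there may differ slightly):

if callback is not None:
    callback(self.iteration, self.get_last_objective(), self.x)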

The design I have is in the following lines:

https://github.com/epapoutsellis/StochasticCIL/blob/477d51fd94c9d6554ca143b86e4b95a21c60835d/Wrappers/Python/cil/optimisation/algorithms/Algorithm.py#L311-L313

Instead of using a function that takes 3 inputs, I pass a callable class that acts on the algorithm, which in this case is self, i.e., the instantiated algorithm. In this way, the user has access to every attribute/method of the Algorithm class.

What can users do with it? Anything they want.
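
As a rough sketch of what the AlgorithmDiagnostics base class used below could look like (the actual definition on the linked branch may differ; the verbose flag is taken from the example):

class AlgorithmDiagnostics(object):
    '''Sketch of a callback base class: a callable acting on the algorithm instance.'''

    def __init__(self, verbose=0):
        self.verbose = verbose

    def __call__(self, algorithm):
        # subclasses override this; `algorithm` is the instantiated Algorithm (self in run)
        raise NotImplementedError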

How do I use it for the Stochastic Project?

class ComputeGradient(AlgorithmDiagnostics):

    def __init__(self):
        super(ComputeGradient, self).__init__(verbose=0)

    def __call__(self, algo):
        # lazily create a list on the algorithm to store the gradient-norm history
        if not hasattr(algo, "gradf"):
            setattr(algo, "gradf", [])
        # norm of the gradient of f at the current iterate
        algo.gradf.append(algo.f.gradient(algo.get_output()).norm())

and the user can then call plt.plot(ista.gradf) to get

(Screenshot: plot of the gradient norms stored in ista.gradf against iteration number.)
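
For context, wiring this up might look roughly as follows; the exact run keyword on that branch (callback vs callbacks) is an assumption here:

import matplotlib.pyplot as plt
from cil.optimisation.algorithms import ISTA

# f, g, initial and step_size are assumed to be set up already
ista = ISTA(initial=initial, f=f, g=g, step_size=step_size)

# hypothetical keyword; the name and signature on the branch may differ
ista.run(500, callback=ComputeGradient())

plt.plot(ista.gradf)   # gradient-norm history collected by the callback
plt.show()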

https://github.com/TomographicImaging/CIL/blob/0418ed2abd42c085db25aa3d0405b4b16e6591e4/Wrappers/Python/cil/optimisation/algorithms/Algorithm.py#L83-L90

A Logger class passed as a callback can take care of this. wandb also has this functionality. @gschramm used TensorBoard here; I had some problems with TensorBoard and its GUI.
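
For example, a minimal logger in the same style could be the following (a sketch only; print could equally be replaced by the logging module or wandb.log):

class PrintObjective(AlgorithmDiagnostics):
    '''Sketch: report the last objective value each time the callback fires.'''

    def __init__(self):
        super(PrintObjective, self).__init__(verbose=0)

    def __call__(self, algo):
        # print() could be swapped for logging, wandb.log, TensorBoard, etc.
        print("iter {}: objective {:.6g}".format(
            algo.iteration, algo.get_last_objective()))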

Finally, a similar design is used for the Proximal Gradient Algorithm, Preconditioner and StepSize classes, but I will create separate issues for them.

Note: at the moment, all the callbacks are executed every update_objective_interval iterations. This could easily be changed if we want, but it is not important in my opinion. For example, we could compute the SSIM every 10 iterations, plot the current iterate every 2 iterations, and the mean of $x_{k}$ every 50 iterations.
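
One way to support per-callback intervals under this design would be to let each diagnostic own its interval, e.g. (a sketch, not existing code):

class IntervalDiagnostic(AlgorithmDiagnostics):
    '''Sketch: run on_interval(algo) only every `interval` iterations.'''

    def __init__(self, interval=1):
        super(IntervalDiagnostic, self).__init__(verbose=0)
        self.interval = interval

    def __call__(self, algo):
        if algo.iteration % self.interval == 0:
            self.on_interval(algo)

    def on_interval(self, algo):
        raise NotImplementedError

class SaveMean(IntervalDiagnostic):
    '''Sketch: record the mean of the current iterate, e.g. SaveMean(interval=50).'''

    def on_interval(self, algo):
        if not hasattr(algo, "means"):
            algo.means = []
        algo.means.append(algo.get_output().mean())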

casperdcl commented 9 months ago

I guess maybe we should have a new callbacks: Iterable[callable] kwarg (leaving the existing callback: callable kwarg as-is for backwards compatibility)? /CC @paskino
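
A sketch of how run could accept both, purely to illustrate the idea (not the final API):

def run(self, iterations, callbacks=None, callback=None, verbose=1):
    callbacks = list(callbacks) if callbacks is not None else []
    if callback is not None:
        # wrap the legacy 3-argument callback so both styles coexist
        callbacks.append(lambda algo: callback(
            algo.iteration, algo.get_last_objective(), algo.x))
    for _ in range(iterations):
        next(self)              # one iteration of the existing update loop
        for cb in callbacks:
            cb(self)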

casperdcl commented 8 months ago

Had a quick chat with @paskino @MargaretDuff.

I'm going to follow the https://keras.io/api/callbacks design and:

epapoutsellis commented 8 months ago

Are you going to move update_objective_interval from the signature of Algorithm and pass it to Callback?

paskino commented 8 months ago

Discussed with @epapoutsellis @MargaretDuff @jakobsj

It seems that the callback route could be used to define the stopping criterion?

MargaretDuff commented 8 months ago

> Discussed with @epapoutsellis @MargaretDuff @jakobsj
>
> It seems that the callback route could be used to define the stopping criterion?

Callbacks acting on the algorithm class can be used to raise StopIteration and so terminate the iterations. Keras has an EarlyStopping callback which stops training when the monitored metric stops improving: https://keras.io/api/callbacks/early_stopping/.
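
As a sketch of what that could look like here (names and thresholds are illustrative, and the run loop would need to catch StopIteration):

class EarlyStoppingObjective(AlgorithmDiagnostics):
    '''Sketch: stop when the objective has not improved by min_delta for `patience` calls.'''

    def __init__(self, min_delta=1e-6, patience=5):
        super(EarlyStoppingObjective, self).__init__(verbose=0)
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.wait = 0

    def __call__(self, algo):
        current = algo.get_last_objective()
        if self.best is None or self.best - current > self.min_delta:
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                # the run loop would wrap callback calls in try/except StopIteration and break
                raise StopIteration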

There was a bit of a discussion with two options: