TomographicImaging / CIL

A versatile python framework for tomographic imaging
https://tomographicimaging.github.io/CIL/
Apache License 2.0

SIRF and CIL Optimisation #1522

Closed epapoutsellis closed 5 months ago

epapoutsellis commented 11 months ago

In SIRF (3.5.0), new functions were added. They are children of the Prior base class and have methods such as __call__ and gradient. So far there are

Please note that there are also objective functions/classes that are CIL-friendly.

In principle, all of the above can and should be usable within the CIL optimisation module. However, for the (proximal) gradient algorithms that we have in CIL, e.g., GD, ISTA and FISTA, this is not possible. This is because we cannot combine a CIL Function with a SIRF Prior, even though they share common class methods.

For example, if we want to run the following (preconditioned, projected gradient) iteration

$$x^{k+1} = P_{\{x>0\}}\left(x^{k} - \alpha_{k} D(x^{k}) \nabla f(x^{k})\right)$$

where f is the sum of a CIL function (KullbackLeibler) and one of the above priors, this is not possible. An error is raised in the __add__ method of the Function class

https://github.com/TomographicImaging/CIL/blob/c3dd1b6cbd5ee95571a52bbbe34b7ad593ce15f9/Wrappers/Python/cil/optimisation/functions/Function.py#L118-L132

after falling through all the if/else branches of the __add__ method in the SumFunction class.
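To make the failure concrete, here is a minimal, self-contained reproduction. `Function`, `SumFunction` and `SumScalarFunction` below are toy stand-ins written for illustration, not the real CIL classes, and we assume the fall-through branch raises `ValueError("Not implemented")`; the exact exception in CIL may differ. `LeastSquares` is a toy data term swapped in for KullbackLeibler, and `QuadraticPrior` is a hypothetical stand-in for a SIRF Prior: it has `__call__` and `gradient` but does not inherit from `Function`.

```python
from numbers import Number

import numpy as np


class Function:
    """Toy stand-in for a CIL Function (not the real class)."""

    def __add__(self, other):
        # Only Function instances and plain scalars are accepted;
        # anything else falls through to the error branch.
        if isinstance(other, Function):
            return SumFunction(self, other)
        elif isinstance(other, Number):
            return SumScalarFunction(self, other)
        else:
            raise ValueError("Not implemented")


class SumFunction(Function):
    def __init__(self, *functions):
        self.functions = functions

    def __call__(self, x):
        return sum(f(x) for f in self.functions)

    def gradient(self, x):
        return sum(f.gradient(x) for f in self.functions)


class SumScalarFunction(Function):
    def __init__(self, function, constant):
        self.function = function
        self.constant = constant

    def __call__(self, x):
        return self.function(x) + self.constant

    def gradient(self, x):
        return self.function.gradient(x)


class LeastSquares(Function):
    """Toy data term 0.5 * ||x - b||^2 (swapped in for KullbackLeibler)."""

    def __init__(self, b):
        self.b = b

    def __call__(self, x):
        return 0.5 * float(np.sum((x - self.b) ** 2))

    def gradient(self, x):
        return x - self.b


class QuadraticPrior:
    """Hypothetical stand-in for a SIRF Prior: has __call__ and gradient,
    but is not a Function subclass."""

    def __init__(self, beta):
        self.beta = beta

    def __call__(self, x):
        return 0.5 * self.beta * float(np.sum(x ** 2))

    def gradient(self, x):
        return self.beta * x


data_term = LeastSquares(np.array([1.0, -2.0, 3.0]))
prior = QuadraticPrior(beta=0.1)

try:
    f = data_term + prior
except ValueError as err:
    print("combination failed:", err)  # combination failed: Not implemented
```

The prior provides exactly the methods a gradient algorithm needs, yet the sum is rejected purely because of its type.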

However, if we add a flag or a static method to the Function class

    @staticmethod
    def isfunction(function):
        attrs = ["__call__", "gradient"] # cover cases for SIRF(Prior)Functions, and RegTk
        return all(hasattr(function, attr) for attr in attrs)

and update

https://github.com/TomographicImaging/CIL/blob/c3dd1b6cbd5ee95571a52bbbe34b7ad593ce15f9/Wrappers/Python/cil/optimisation/functions/Function.py#L126-L128

with

        if isinstance(other, Function) or self.isfunction(other):
            return SumFunction(self, other)

the above problem can be solved for any given prior. The figure below is an attempt to reproduce some results from https://pubmed.ncbi.nlm.nih.gov/36044488/. @KrisThielemans, @zeljkozeljko
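A rough sketch of the proposed change, again using toy stand-ins rather than the real CIL classes: `Function.__add__` accepts anything that passes `isfunction`, and the resulting sum is plugged into the projected, preconditioned gradient iteration from the top of this issue. Assumptions: `LeastSquares` replaces KullbackLeibler, `QuadraticPrior` stands in for a SIRF Prior, the step size is constant (alpha_k = alpha), and the diagonal preconditioner D(x) = x + eps (EM-style) is a hypothetical choice. The projection is implemented with `np.maximum(..., 0)`, i.e. onto x >= 0.

```python
import numpy as np


class Function:
    """Toy stand-in for a CIL Function (not the real class)."""

    @staticmethod
    def isfunction(function):
        # Duck-typing check proposed above.
        attrs = ["__call__", "gradient"]
        return all(hasattr(function, attr) for attr in attrs)

    def __add__(self, other):
        # Proposed change: accept any object exposing __call__ and gradient.
        if isinstance(other, Function) or self.isfunction(other):
            return SumFunction(self, other)
        raise ValueError("Not implemented")


class SumFunction(Function):
    def __init__(self, *functions):
        self.functions = functions

    def __call__(self, x):
        return sum(f(x) for f in self.functions)

    def gradient(self, x):
        return sum(f.gradient(x) for f in self.functions)


class LeastSquares(Function):
    """Toy data term 0.5 * ||x - b||^2 (swapped in for KullbackLeibler)."""

    def __init__(self, b):
        self.b = b

    def __call__(self, x):
        return 0.5 * float(np.sum((x - self.b) ** 2))

    def gradient(self, x):
        return x - self.b


class QuadraticPrior:
    """Hypothetical SIRF-Prior stand-in; NOT a Function subclass."""

    def __init__(self, beta):
        self.beta = beta

    def __call__(self, x):
        return 0.5 * self.beta * float(np.sum(x ** 2))

    def gradient(self, x):
        return self.beta * x


def projected_preconditioned_gradient(f, x0, alpha=0.5, n_iter=200, eps=1e-3):
    """x_{k+1} = P(x_k - alpha * D(x_k) * grad f(x_k)) with the hypothetical
    diagonal preconditioner D(x) = x + eps and P = projection onto x >= 0."""
    x = x0.copy()
    for _ in range(n_iter):
        D = x + eps
        x = np.maximum(x - alpha * D * f.gradient(x), 0.0)
    return x


b = np.array([1.0, -2.0, 3.0])
f = LeastSquares(b) + QuadraticPrior(beta=0.1)  # now returns a SumFunction
x = projected_preconditioned_gradient(f, x0=np.ones(3))
print(x)  # approximately [0.909 0. 2.727]
```

For this toy objective the constrained minimiser is b / (1 + beta) where b > 0 and 0 elsewhere, which is what the iterate settles on.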

[Figure: SIRF_bench]

KrisThielemans commented 11 months ago

An alternative would be to derive CIL Functions for each prior, but the above solution is much neater!

KrisThielemans commented 6 months ago

In fact, I think the Pythonic way (certainly the easiest way) is not to check at all. If the function or its gradient gets called, it'll work as long as the object supports it; otherwise it'll throw a (somewhat cryptic) error. I suppose we could catch that error in the __call__ or gradient implementation of SumFunction.

I think it also resolves the issue of the other methods of Function. Do we insist they all need to be there? That'd make life very difficult. So, if they get called, we'll get an error anyway.
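The "don't check, fail when called" approach could look roughly like this. `SumFunction` here is a simplified stand-in, not CIL's class: it does no type checks, and its `gradient` only converts the cryptic `AttributeError` into a `TypeError` that names the member lacking `gradient`. `WithGradient` and `ValueOnly` are hypothetical toy members.

```python
import numpy as np


class SumFunction:
    """Toy stand-in for CIL's SumFunction (not the real class).
    Members only need to support the methods that actually get called."""

    def __init__(self, *functions):
        self.functions = functions

    def __call__(self, x):
        return sum(f(x) for f in self.functions)

    def gradient(self, x):
        try:
            return sum(f.gradient(x) for f in self.functions)
        except AttributeError as err:
            missing = [type(f).__name__ for f in self.functions
                       if not hasattr(f, "gradient")]
            if not missing:
                # The AttributeError came from inside a member's gradient; keep it.
                raise
            raise TypeError(
                "SumFunction.gradient: member(s) without gradient(): "
                + ", ".join(missing)
            ) from err


class WithGradient:
    def __call__(self, x):
        return float(np.sum(x ** 2))

    def gradient(self, x):
        return 2 * x


class ValueOnly:
    """Has __call__ but no gradient."""

    def __call__(self, x):
        return float(np.sum(np.abs(x)))


x = np.array([1.0, -2.0])
f = SumFunction(WithGradient(), ValueOnly())
print(f(x))  # 8.0: evaluation works even though one member has no gradient
try:
    f.gradient(x)
except TypeError as err:
    print(err)  # names ValueOnly as the member without gradient()
```

Re-raising the original error when every member does have `gradient` avoids masking a genuine bug inside one of the members' own implementations.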

So, I suggest

    if isinstance(other, (SumScalarFunction, ConstantFunction, Number)):
        return SumScalarFunction(self, other)
    return SumFunction(self, other)

(That is, if you really need SumScalarFunction, which I didn't check).

MargaretDuff commented 6 months ago

I think we discussed changing the lines to:

    if isinstance(other, (SumScalarFunction, ConstantFunction, Number)):
        return SumScalarFunction(self, other)
    else:
        return SumFunction(self, other)
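A toy sketch of the resulting dispatch: scalars and scalar-like functions go to `SumScalarFunction`, and everything else goes to `SumFunction` with no type check. All classes below are minimal stand-ins written for illustration (including the hypothetical `ToyFunction` and `DuckTypedPrior`), not CIL's implementations.

```python
from numbers import Number


class Function:
    """Toy stand-in for a CIL Function (not the real class)."""

    def __add__(self, other):
        # Scalars and scalar-like functions get the scalar wrapper;
        # anything else is accepted and only has to work when called.
        if isinstance(other, (SumScalarFunction, ConstantFunction, Number)):
            return SumScalarFunction(self, other)
        else:
            return SumFunction(self, other)


class SumFunction(Function):
    def __init__(self, *functions):
        self.functions = functions


class SumScalarFunction(Function):
    def __init__(self, function, constant):
        self.function = function
        self.constant = constant


class ConstantFunction(Function):
    def __init__(self, constant):
        self.constant = constant


class ToyFunction(Function):
    pass


class DuckTypedPrior:
    """Not a Function subclass; only provides __call__ and gradient."""

    def __call__(self, x):
        return 0.0

    def gradient(self, x):
        return 0.0


f = ToyFunction()
print(type(f + 2.5).__name__)                  # SumScalarFunction
print(type(f + ConstantFunction(1)).__name__)  # SumScalarFunction
print(type(f + DuckTypedPrior()).__name__)     # SumFunction
```

The trade-off is that adding an unsuitable object no longer fails when the sum is built; the error only surfaces once the sum is evaluated or differentiated.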