odlgroup / odl

Operator Discretization Library https://odlgroup.github.io/odl/
Mozilla Public License 2.0

algorithm is slowing down (again) #1314

Closed mehrhardt closed 6 years ago

mehrhardt commented 6 years ago

I am having an issue similar to #1291, but it is also somewhat different. After a certain number of iterations, the algorithm slows down massively. I have tried many things to figure out what is going wrong, without success. The smallest example in which I observe this behaviour is the following. This "minimal example" has two problems. First, it depends on the "spdhg" package, which is not yet part of ODL. Second, it is not really small. However, I was not able to create a smaller example that shows the same behaviour. Also, as my comments show, changing the example slightly in almost any way resolves the problem. Do you have any idea what might be going wrong here, or what to look at?

import odl.contrib.solvers.spdhg as spdhg
import odl

niter_target = 100
#n = 10   # works perfectly!
n = 100
X = odl.uniform_discr(min_pt=[-20, -20], max_pt=[20, 20], shape=[n, n])

geometry = odl.tomo.parallel_beam_geometry(X, num_angles=20, det_shape=20)
G = odl.tomo.RayTransform(X, geometry, impl='astra_cpu')
#G = odl.IdentityOperator(X) # works perfectly!
Y = G.range
groundtruth = X.one()

#sinogram = 1 * Y.one() # works perfectly!
sinogram = 100 * Y.one()
#factors = 1 * Y.one() # works perfectly!
factors = 34 * Y.one()

background = 10 * Y.one()
data = factors * sinogram
#f = odl.solvers.L2NormSquared(Y).translated(-background) # works perfectly!
f = odl.solvers.KullbackLeibler(Y, data).translated(-background)
A = factors * G

#g = odl.solvers.functional.ConstantFunctional(X, 0) # works perfectly!
#g = odl.solvers.L1Norm(X) # works perfectly!
g = spdhg.TotalVariationNonNegative(X, alpha=1e-1)
g.prox_options['p'] = None  # reset the FGP warm-start dual variable

callback = odl.solvers.CallbackPrintTiming(step=1, cumulative=False)

odl.solvers.pdhg(X.zero(), f, g, A, 1, 1, niter_target, callback=callback)

Output:

Time elapsed = 0.036 s
Time elapsed = 0.031 s
Time elapsed = 0.031 s
Time elapsed = 0.033 s
Time elapsed = 0.031 s
Time elapsed = 0.031 s
Time elapsed = 0.034 s
Time elapsed = 0.035 s
Time elapsed = 0.031 s
Time elapsed = 0.034 s
Time elapsed = 0.031 s
Time elapsed = 0.032 s
Time elapsed = 0.037 s
Time elapsed = 0.033 s
Time elapsed = 0.034 s
Time elapsed = 0.035 s
Time elapsed = 0.035 s
Time elapsed = 0.031 s
Time elapsed = 0.035 s
Time elapsed = 0.031 s
Time elapsed = 0.031 s
Time elapsed = 0.038 s
Time elapsed = 0.034 s
Time elapsed = 0.032 s
Time elapsed = 0.034 s
Time elapsed = 0.032 s
Time elapsed = 0.031 s
Time elapsed = 0.040 s
Time elapsed = 0.038 s
Time elapsed = 0.034 s
Time elapsed = 0.038 s
Time elapsed = 0.031 s
Time elapsed = 0.031 s
Time elapsed = 0.037 s
Time elapsed = 0.034 s
Time elapsed = 0.032 s
Time elapsed = 0.035 s
Time elapsed = 0.034 s
Time elapsed = 0.032 s
Time elapsed = 0.037 s
Time elapsed = 0.042 s
Time elapsed = 0.034 s
Time elapsed = 0.039 s
Time elapsed = 0.031 s
Time elapsed = 0.034 s
Time elapsed = 0.036 s
Time elapsed = 0.031 s
Time elapsed = 0.034 s
Time elapsed = 0.039 s
Time elapsed = 0.031 s
Time elapsed = 0.034 s
Time elapsed = 0.032 s
Time elapsed = 0.035 s
Time elapsed = 0.034 s
Time elapsed = 0.130 s
Time elapsed = 0.432 s
Time elapsed = 0.536 s
Time elapsed = 0.566 s
Time elapsed = 0.575 s
Time elapsed = 0.592 s
Time elapsed = 0.586 s
Time elapsed = 0.586 s
Time elapsed = 0.588 s
Time elapsed = 0.587 s
Time elapsed = 0.587 s
Time elapsed = 0.590 s
Time elapsed = 0.592 s
Time elapsed = 0.588 s
Time elapsed = 0.592 s
Time elapsed = 0.603 s
Time elapsed = 0.606 s
Time elapsed = 0.676 s
Time elapsed = 0.597 s
Time elapsed = 0.592 s
Time elapsed = 0.595 s
Time elapsed = 0.593 s
Time elapsed = 0.602 s
Time elapsed = 0.599 s
Time elapsed = 0.601 s
Time elapsed = 0.598 s
Time elapsed = 0.600 s
Time elapsed = 0.600 s
Time elapsed = 0.596 s
Time elapsed = 0.602 s
Time elapsed = 0.608 s
Time elapsed = 0.602 s
Time elapsed = 0.602 s
Time elapsed = 0.605 s
Time elapsed = 0.598 s
Time elapsed = 0.595 s
Time elapsed = 0.594 s
Time elapsed = 0.599 s
Time elapsed = 0.596 s
Time elapsed = 0.597 s
Time elapsed = 0.599 s
Time elapsed = 0.599 s
Time elapsed = 0.600 s
Time elapsed = 0.606 s
Time elapsed = 0.607 s
Time elapsed = 0.602 s
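To pin down where the slowdown starts in a timing log like this one, each per-iteration time can be compared against a baseline from the first few iterations. A minimal sketch: the helper name `find_slowdown_onset`, the 10-iteration baseline and the factor of 3 are arbitrary choices for illustration, not part of ODL.

```python
import statistics


def find_slowdown_onset(timings, baseline_iters=10, factor=3.0):
    """Return the 1-based index of the first iteration whose time exceeds
    ``factor`` times the median of the first ``baseline_iters`` timings,
    or None if there is no such iteration (hypothetical helper)."""
    if len(timings) < baseline_iters:
        raise ValueError('need at least baseline_iters timings')
    threshold = factor * statistics.median(timings[:baseline_iters])
    for i, t in enumerate(timings, start=1):
        if t > threshold:
            return i
    return None


# First 60 per-iteration timings (seconds) from the output above.
timings = [
    0.036, 0.031, 0.031, 0.033, 0.031, 0.031, 0.034, 0.035, 0.031, 0.034,
    0.031, 0.032, 0.037, 0.033, 0.034, 0.035, 0.035, 0.031, 0.035, 0.031,
    0.031, 0.038, 0.034, 0.032, 0.034, 0.032, 0.031, 0.040, 0.038, 0.034,
    0.038, 0.031, 0.031, 0.037, 0.034, 0.032, 0.035, 0.034, 0.032, 0.037,
    0.042, 0.034, 0.039, 0.031, 0.034, 0.036, 0.031, 0.034, 0.039, 0.031,
    0.034, 0.032, 0.035, 0.034, 0.130, 0.432, 0.536, 0.566, 0.575, 0.592,
]
print(find_slowdown_onset(timings))  # prints 55
```

A relative threshold is used instead of a fixed number of seconds so that the same check works on faster or slower machines.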

mehrhardt commented 6 years ago

This is now the "complete" example, which depends only on ODL:

from __future__ import print_function
import numpy as np
import odl

def total_variation(domain, grad=None):
    """ Total variation functional.

    Parameters
    ----------
    domain : odlspace
        domain of TV functional
    grad : gradient operator, optional
        Gradient operator of the total variation functional. This may be any
        linear operator, thereby generalizing TV. Default: forward
        differences with Neumann boundary conditions.

    Examples
    --------
    Check that the total variation of a constant is zero

    >>> import odl.contrib.spdhg as spdhg, odl
    >>> space = odl.uniform_discr([0, 0], [3, 3], [3, 3])
    >>> tv = spdhg.total_variation(space)
    >>> x = space.one()
    >>> tv(x) < 1e-10
    True
    """

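    if grad is None:
        grad = odl.Gradient(domain, method='forward', pad_mode='symmetric')
        # Upper bound on the operator norm of forward differences, stored as
        # an attribute because fgp_dual reads grad.norm as a number.
        grad.norm = 2 * np.sqrt(sum(1 / grad.domain.cell_sides**2))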

    f = odl.solvers.GroupL1Norm(grad.range, exponent=2)

    return f * grad

class TotalVariationNonNegative(odl.solvers.Functional):
    """ Total variation function with nonnegativity constraint and strongly
    convex relaxation.

    In formulas, this functional may represent

        alpha * |grad x|_1 + char_fun(x) + beta/2 |x|^2_2

    with regularization parameter alpha and strong convexity beta. In addition,
    the nonnegativity constraint is achieved with the characteristic function

        char_fun(x) = 0 if x >= 0 and infinity otherwise.

    Parameters
    ----------
    domain : odlspace
        domain of TV functional
    alpha : scalar, optional
        Regularization parameter, positive
    prox_options : dict, optional
        name: string, optional
            name of the method to perform the prox operator, default=FGP
        warmstart: boolean, optional
            Whether to warm start, i.e. start with the dual variable
            from the last call, default=True
        niter: int, optional
            number of iterations per call, default=5
        p: array, optional
            initial dual variable, default=zeros
        tol: float, optional
            convergence tolerance of FGP; None means a fixed number of
            iterations, default=None
    grad : gradient operator, optional
        Gradient operator to be used within the total variation functional.
        default: see total_variation
    strong_convexity : scalar, optional
        Strong convexity parameter beta, default=0
    """

    def __init__(self, domain, alpha=1, prox_options=None, grad=None,
                 strong_convexity=0):

        self.strong_convexity = strong_convexity

        # Copy the options so that instances do not share one mutable
        # default dict (which would also share the warm-start variable 'p').
        prox_options = {} if prox_options is None else dict(prox_options)
        prox_options.setdefault('name', 'FGP')
        prox_options.setdefault('warmstart', True)
        prox_options.setdefault('niter', 5)
        prox_options.setdefault('p', None)
        prox_options.setdefault('tol', None)

        self.prox_options = prox_options

        self.alpha = alpha
        self.tv = total_variation(domain, grad=grad)
        self.grad = self.tv.right
        self.nn = odl.solvers.IndicatorBox(domain, 0, np.inf)
        self.l2 = 0.5 * odl.solvers.L2NormSquared(domain)

        super(TotalVariationNonNegative, self).__init__(
            space=domain, linear=False, grad_lipschitz=0)

    def __call__(self, x):
        """ Characteristic function of the non-negative orthant

        Parameters
        ----------
        x : np.array
            vector / image

        Returns
        -------
        extended float (with infinity)
            Is the input in the non-negative orthant?

        Examples
        --------
        Check that the total variation of a constant is zero

        >>> import odl.contrib.spdhg as spdhg, odl
        >>> space = odl.uniform_discr([0, 0], [3, 3], [3, 3])
        >>> tvnn = spdhg.TotalVariationNonNegative(space, alpha=2)
        >>> x = space.one()
        >>> tvnn(x) < 1e-10
        True

        Check that negative functions are mapped to infty

        >>> import odl.contrib.spdhg as spdhg, odl, numpy as np
        >>> space = odl.uniform_discr([0, 0], [3, 3], [3, 3])
        >>> tvnn = spdhg.TotalVariationNonNegative(space, alpha=2)
        >>> x = -space.one()
        >>> np.isinf(tvnn(x))
        True
        """

        nn = self.nn(x)

        if nn == np.inf:
            return nn
        else:
            out = self.alpha * self.tv(x) + nn
            if self.strong_convexity > 0:
                out += self.strong_convexity * self.l2(x)
            return out

    def proximal(self, sigma):
        """ Prox operator of TV. It allows the proximal step length to be a vector
        of positive elements.

        Parameters
        ----------
        x : np.array
            vector / image

        Returns
        -------
        extended float (with infinity)
            Is the input in the non-negative orthant?

        Examples
        --------
        Check that the proximal operator is the identity for sigma=0

        >>> import odl.contrib.spdhg as spdhg, odl, numpy as np
        >>> space = odl.uniform_discr([0, 0], [3, 3], [3, 3])
        >>> tvnn = spdhg.TotalVariationNonNegative(space, alpha=2)
        >>> x = -space.one()
        >>> y = tvnn.proximal(0)(x)
        >>> (y-x).norm() < 1e-10
        True

        Check that negative functions are mapped to 0

        >>> import odl.contrib.spdhg as spdhg, odl, numpy as np
        >>> space = odl.uniform_discr([0, 0], [3, 3], [3, 3])
        >>> tvnn = spdhg.TotalVariationNonNegative(space, alpha=2)
        >>> x = -space.one()
        >>> y = tvnn.proximal(0.1)(x)
        >>> y.norm() < 1e-10
        True
        """

        if sigma == 0:
            return odl.IdentityOperator(self.domain)

        else:
            def tv_prox(z, out=None):

                if out is None:
                    out = z.space.zero()

                opts = self.prox_options

                sigma_ = np.copy(sigma)
                z_ = z.copy()

                if self.strong_convexity > 0:
                    sigma_ /= (1 + sigma * self.strong_convexity)
                    z_ /= (1 + sigma * self.strong_convexity)

                def proj_C(x, out=None):
                    return self.nn.proximal(1)(x, out=out)

                def proj_P(x, out=None):
                    norm = odl.solvers.GroupL1Norm(self.grad.range, exponent=2)
                    return norm.convex_conj.proximal(0)(x, out=out)

                if opts['name'] == 'FGP':
                    if opts['warmstart']:
                        if opts['p'] is None:
                            opts['p'] = self.grad.range.zero()

                        p = opts['p']
                    else:
                        p = self.grad.range.zero()

                    sigma_sqrt = np.sqrt(sigma_)

                    z_ /= sigma_sqrt
                    grad = sigma_sqrt * self.grad
                    grad.norm = sigma_sqrt * self.grad.norm
                    niter = opts['niter']
                    alpha = self.alpha
                    out[:] = fgp_dual(p, z_, alpha, niter, grad, proj_C,
                                      proj_P, tol=opts['tol'])

                    out *= sigma_sqrt

                    return out

                else:
                    raise NotImplementedError('Not yet implemented')

            return tv_prox

def fgp_dual(p, data, alpha, n_iter, grad, proj_C, proj_P, tol=None, **kwargs):
    """ Computes a solution to the ROF problem with the fast gradient
    projection algorithm.

    Parameters
    ----------
    p : np.array
        dual initial variable
    data : np.array
        noisy data / proximal point
    alpha : float
        regularization parameter
    n_iter : int
        number of iterations
    grad : gradient class
        class that supports grad(x), grad.adjoint(x), grad.norm
    proj_C : function
        projection onto the constraint set of the primal variable,
        e.g. non-negativity
    proj_P : function
        projection onto the constraint set of the dual variable,
        e.g. norm <= 1
    tol : float, optional
        Nonnegative tolerance for convergence. If None, the algorithm runs
        for a fixed number of iterations.

    Other Parameters
    ----------------
    callback : callable, optional
        Function called with the current iterate after each iteration.
    """

    # Callback object
    callback = kwargs.pop('callback', None)
    if callback is not None and not callable(callback):
        raise TypeError('`callback` {} is not callable'.format(callback))

    factr = 1 / (grad.norm**2 * alpha)

    q = p.copy()
    x = data.space.zero()

    t = 1.

    if tol is None:
        def convergence_eval(p1, p2):
            return False
    else:
        def convergence_eval(p1, p2):
            return (p1 - p2).norm() / p1.norm() < tol

    pnew = p.copy()

    if callback is not None:
        callback(p)

    for k in range(n_iter):
        t0 = t
        grad.adjoint(q, out=x)
        proj_C(data - alpha * x, out=x)
        grad(x, out=pnew)
        pnew *= factr
        pnew += q
        proj_P(pnew, out=pnew)

        converged = convergence_eval(p, pnew)

        if not converged:
            # update step size
            t = (1 + np.sqrt(1 + 4*t0**2))/2.

            # calculate next iterate
            q[:] = pnew + (t0 - 1)/t * (pnew - p)

        p[:] = pnew

        if converged:
            t = None
            break

        if callback is not None:
            callback(p)

    # get current image estimate
    x = proj_C(data - alpha * grad.adjoint(p))

    return x

import odl

niter_target = 100
#n = 10   # works perfectly!
n = 100
X = odl.uniform_discr(min_pt=[-20, -20], max_pt=[20, 20], shape=[n, n])

geometry = odl.tomo.parallel_beam_geometry(X, num_angles=20, det_shape=20)
G = odl.tomo.RayTransform(X, geometry, impl='astra_cpu')
#G = odl.IdentityOperator(X) # works perfectly!
Y = G.range
groundtruth = X.one()

#sinogram = 1 * Y.one() # works perfectly!
sinogram = 100 * Y.one()
#factors = 1 * Y.one() # works perfectly!
factors = 34 * Y.one()

background = 10 * Y.one()
data = factors * sinogram
#f = odl.solvers.L2NormSquared(Y).translated(-background) # works perfectly!
f = odl.solvers.KullbackLeibler(Y, data).translated(-background)
A = factors * G

#g = odl.solvers.functional.ConstantFunctional(X, 0) # works perfectly!
#g = odl.solvers.L1Norm(X) # works perfectly!
g = TotalVariationNonNegative(X, alpha=1e-1)
g.prox_options['p'] = None  # reset the FGP warm-start dual variable

callback = odl.solvers.CallbackPrintTiming(step=1, cumulative=False)

odl.solvers.pdhg(X.zero(), f, g, A, 1, 1, niter_target, callback=callback)

kohr-h commented 6 years ago

I did some profiling with the code. The hot spot is the power ufunc in the pointwise norm, which is called in the proximal of the GroupL1Norm convex conjugate. I tracked the slowdown all the way down to the call into Numpy, so there seems to be nothing weird going on in ODL. What I did notice is that at the point where the slowdown happens, the numbers passed to power become quite huge, jumping from around 1e+4 to 1e+9, so I assume that the power function simply slows down for such large numbers. It seems to me that the method is diverging and the numerics simply have a hard time with the large numbers.
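The hot path described here can be imitated without ODL. The pointwise Euclidean norm of a vector field is the square root of the sum of squared components, and a general-exponent implementation evaluates that root as `power(..., 0.5)`. The sketch below is plain NumPy; the function names, the 2 x 100 x 100 shape and the magnitudes 1e4 and 1e9 are illustrative assumptions, not taken from the actual run. It checks that the `power` and `sqrt` variants agree, so swapping one for the other should only affect speed.

```python
import numpy as np


def pointwise_norm_power(vf, p=2.0):
    """Pointwise p-norm of a vector field of shape (ncomp, ...) using only
    np.power, in the style of a general-exponent implementation."""
    tmp = np.zeros(vf.shape[1:])
    for comp in vf:
        tmp += np.power(np.abs(comp), p)
    return np.power(tmp, 1.0 / p)


def pointwise_norm_sqrt(vf):
    """Pointwise 2-norm of the same vector field via products and np.sqrt."""
    tmp = np.zeros(vf.shape[1:])
    for comp in vf:
        tmp += comp * comp
    return np.sqrt(tmp)


rng = np.random.default_rng(0)
for scale in (1e4, 1e9):
    # Two components, as for the gradient of a 100 x 100 image.
    vf = scale * rng.standard_normal((2, 100, 100))
    same = np.allclose(pointwise_norm_power(vf), pointwise_norm_sqrt(vf))
    print('scale {:.0e}: results agree: {}'.format(scale, same))
```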

mehrhardt commented 6 years ago

How did you profile this? It would be good if I can do this next time myself.

Do you have a small minimal example where you observe this? I tried this myself but I could not reproduce the problem. Perhaps I misunderstood?

import numpy as np
from timeit import default_timer as timer

x = 2 * np.ones((1000000, 1))

def nppower(x):
    np.power(x, 2, out=x)

for k in range(100):
    start = timer()
    nppower(x)
    end = timer()
    print('number: {}, time: {}'.format(x[0], end - start))

mehrhardt commented 6 years ago

Also, which quantities get this large? In my original example (not the "simplified" one above), all images look good and no divergence can be observed.

kohr-h commented 6 years ago

> How did you profile this? It would be good if I can do this next time myself.

I used the line_profiler module and inserted @profile decorators wherever I needed them (it gives nice line-by-line timings; see the line_profiler documentation). Then I studied the differences in behavior between 50 iterations (no slowdown) and 100 iterations (slowdown). The first thing I did was decorate all the functions in your example, and I observed that proj_P is what gets slower. From there I worked my way down, following the trail of the slowdown roughly along this path: GroupL1Norm.convex_conj.proximal -> PointwiseNorm -> ufuncs.power -> NumpyTensor.__array_ufunc__ -> call into the Numpy ufunc. That led me to the conclusion that the Numpy code was getting slower for some reason (I also didn't see any other slowdown).
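line_profiler has to be installed separately. As a coarser, standard-library-only starting point, cProfile reports per-function (not per-line) times, which is often enough to see which function's cumulative time grows between a short and a long run. The workload below is a hypothetical stand-in for the PDHG loop, not the actual ODL code; `run`, `pointwise_norm` and `profile_report` are illustrative names.

```python
import cProfile
import io
import pstats

import numpy as np


def pointwise_norm(vf):
    # Stand-in for the pointwise norm inside the dual projection.
    return np.power(np.sum(vf * vf, axis=0), 0.5)


def run(niter):
    vf = np.ones((2, 100, 100))
    for _ in range(niter):
        norm = pointwise_norm(vf)
        vf /= np.maximum(norm, 1.0)  # project each vector onto the unit ball
    return vf


def profile_report(niter, nlines=10):
    """Profile run(niter) and return the top entries by cumulative time."""
    profiler = cProfile.Profile()
    profiler.enable()
    run(niter)
    profiler.disable()
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    stats.sort_stats('cumulative').print_stats(nlines)
    return stream.getvalue()


# Compare a short and a long run, similar to 50 vs. 100 iterations above.
print(profile_report(50))
print(profile_report(100))
```

Once the slow function is known, line_profiler's `@profile` decorators can then be placed only on that function and its callees.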

To get to the root cause, I simply went into the NumpyTensor.__array_ufunc__ method and inserted print(ufunc, inputs, kwargs) just before the call into the Numpy ufunc. That of course produced a huge amount of output, but if you set the number of iterations to around 58, the amount of scrolling back in the console stays manageable. I could observe the values jump from 1e+4 to 1e+9 within 2 or 3 iterations, exactly where things get slow. You can probably do this more simply by adding a callback that prints the maximum of the current iterate.
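The simpler check suggested here can be written as a callback. ODL solvers call the callback with the current iterate, as the `callback` parameter of `fgp_dual` above also does. The class below is a hypothetical sketch (the name `MaxMonitor` and the default jump factor of 100 are assumptions) that records the maximum absolute value of each iterate and reports sudden jumps. It only assumes the iterate can be converted with `np.asarray`.

```python
import numpy as np


class MaxMonitor(object):
    """Hypothetical callback recording max |x| of each iterate."""

    def __init__(self, jump_factor=100.0, verbose=True):
        self.jump_factor = jump_factor
        self.verbose = verbose
        self.history = []

    def __call__(self, x):
        current = float(np.max(np.abs(np.asarray(x))))
        if self.verbose:
            print('iter {}: max |x| = {:.3e}'.format(len(self.history),
                                                      current))
        # Flag a jump relative to the previous iterate.
        if self.history and self.history[-1] > 0:
            ratio = current / self.history[-1]
            if ratio > self.jump_factor:
                print('  jump by a factor of {:.1e}'.format(ratio))
        self.history.append(current)


# Stand-alone demo with plain arrays; in the example above one could pass
# callback=MaxMonitor() to odl.solvers.pdhg instead of the timing callback.
monitor = MaxMonitor()
for scale in (1e4, 1.2e4, 1.5e4, 1e9):
    monitor(scale * np.ones(4))
```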

kohr-h commented 6 years ago

> Do you have a small minimal example where you observe this? I tried this myself but I could not reproduce the problem. Perhaps I misunderstood?

The exponent in the slow code is 0.5; that may make a difference. Also try bigger numbers.
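Following this hint, a benchmark should use the exponent 0.5 on large values and compare against `np.sqrt`. A minimal sketch using `timeit`; the helper name, array size, magnitudes and repeat count are arbitrary assumptions, and whether `np.power` is actually slower for large inputs depends on the platform's math library, so no particular timing is implied.

```python
import timeit

import numpy as np


def time_power_vs_sqrt(scale, size=1000000, number=10, seed=0):
    """Return (seconds for np.power(x, 0.5), seconds for np.sqrt(x)) on
    random data in [scale, 2 * scale); the input x is never modified."""
    rng = np.random.default_rng(seed)
    x = scale * (1.0 + rng.random(size))
    out = np.empty_like(x)
    t_power = timeit.timeit(lambda: np.power(x, 0.5, out=out), number=number)
    t_sqrt = timeit.timeit(lambda: np.sqrt(x, out=out), number=number)
    return t_power, t_sqrt


for scale in (1e0, 1e4, 1e9):
    t_power, t_sqrt = time_power_vs_sqrt(scale)
    print('scale {:.0e}: power {:.4f} s, sqrt {:.4f} s'.format(
        scale, t_power, t_sqrt))
```

Writing into a separate `out` array keeps the magnitude fixed across repetitions. In an in-place loop such as `np.power(x, .5, out=x)`, repeated square roots drive the values toward 1, so large magnitudes are only tested in the first few repetitions.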

mehrhardt commented 6 years ago

I still don't understand what is going wrong, but at least I have one example outside of ODL that shows this behaviour. On the positive side, one can replace power(0.5) with sqrt, which does not show this bad performance. How difficult would it be to change this in ODL?

import numpy as np
from timeit import default_timer as timer

x = 2 * np.ones((1000000, 1))
x1 = 2 * np.ones((1000000, 1))

def nppower(x):
    np.power(x, .5, out=x)

def npsqrt(x):
    np.sqrt(x, out=x)

for k in range(1000):
    start = timer()
    nppower(x)
    end = timer()
    print('number: {}, time: {}'.format(x[0], end - start))
    start = timer()
    npsqrt(x1)
    end = timer()
    print('sqrtnumber: {}, time: {}'.format(x1[0], end - start))

adler-j commented 6 years ago

Interestingly, there is a Numpy issue about this, which I have apparently already commented on. Sadly, it is 2 years old with no response.

The solution would be to special-case p=2 in this function:

https://github.com/odlgroup/odl/blob/0ab389fe81c35fdc9617ec60e156c08ebda6db2f/odl/operator/tensor_ops.py#L266-L287

Doing so should be really easy.

kohr-h commented 6 years ago

@adler-j beat me to the answer :-) Back when I implemented this, I also assumed that np.power had a special case for certain exponents, but that is obviously not the case. So the best way is to add another method, like the existing ones, for the special case p=2.

Edit: To clarify, I'd prefer an additional _call_vecfield_2 method instead of changing the general p one.
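The agreed fix, a dedicated code path for the exponent 2 next to the general one, can be sketched outside ODL as follows. This is not ODL's `PointwiseNorm` implementation; the class `PointwiseNormSketch` is hypothetical and its method names only mirror the `_call_vecfield_*` naming mentioned above.

```python
import numpy as np


class PointwiseNormSketch(object):
    """Hypothetical pointwise p-norm of a vector field (ncomp, ...)."""

    def __init__(self, exponent):
        self.exponent = float(exponent)

    def __call__(self, vf):
        # Dispatch to a special-cased method where one exists.
        if self.exponent == 1.0:
            return self._call_vecfield_1(vf)
        elif self.exponent == 2.0:
            return self._call_vecfield_2(vf)
        elif self.exponent == np.inf:
            return self._call_vecfield_inf(vf)
        else:
            return self._call_vecfield_p(vf)

    def _call_vecfield_1(self, vf):
        return np.sum(np.abs(vf), axis=0)

    def _call_vecfield_2(self, vf):
        # Squares via multiplication and the root via np.sqrt, so no
        # np.power call with a fractional exponent is needed.
        out = np.zeros(vf.shape[1:])
        for comp in vf:
            out += comp * comp
        return np.sqrt(out, out=out)

    def _call_vecfield_inf(self, vf):
        return np.max(np.abs(vf), axis=0)

    def _call_vecfield_p(self, vf):
        # General exponent, left unchanged: both steps use np.power.
        out = np.zeros(vf.shape[1:])
        for comp in vf:
            out += np.power(np.abs(comp), self.exponent)
        return np.power(out, 1.0 / self.exponent, out=out)


vf = np.array([[3.0, 0.0], [4.0, -2.0]])  # two components at two points
print(PointwiseNormSketch(2)(vf))  # prints [5. 2.]
```

Keeping `_call_vecfield_p` untouched and adding a separate method, as preferred above, leaves the behavior for all other exponents unchanged.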

adler-j commented 6 years ago

I fully agree with @kohr-h on the proposed solution.

mehrhardt commented 6 years ago

If this issue is still open, would one of you volunteer to implement this quick fix?

adler-j commented 6 years ago

I'll see if I can get this done.

mehrhardt commented 6 years ago

Many thanks, greatly appreciated!