mikgroup / sigpy

Python package for signal processing, with emphasis on iterative methods
BSD 3-Clause "New" or "Revised" License

Creating an app with multi-regularizers (wavelet and total variation) #109

Closed: joeyplum closed this issue 2 years ago

joeyplum commented 2 years ago

Hi SigPy group,

First, thank you for creating this powerful software. I've found it easy to use and very helpful for understanding compressed sensing (CS) in MRI.

I've been working on a problem for a few weeks now: I want to build an app that can solve a multi-regularizer problem. Specifically, I want to solve the following optimization problem:

\min_x \frac{1}{2} \| A x - y \|_2^2 + \lambda_{W} \| W x \|_1 + \lambda_{TV} \| G x \|_1

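where A is the sampling operator, W is a wavelet transform, G is the finite-difference operator, and \lambda_{W} and \lambda_{TV} are the regularization weights for the wavelet and total variation terms, respectively. A version of this problem also appears in Miki Lustig's 2007 Sparse MRI paper.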
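Stacking the two sparsifying operators turns the two regularizers into a single composite term. A short derivation in our own notation (K and h are introduced here for illustration; they are not SigPy names):

```latex
K = \begin{bmatrix} W \\ G \end{bmatrix}, \qquad
h(z_1, z_2) = \lambda_{W} \| z_1 \|_1 + \lambda_{TV} \| z_2 \|_1

\frac{1}{2} \| A x - y \|_2^2 + \lambda_{W} \| W x \|_1 + \lambda_{TV} \| G x \|_1
  = \frac{1}{2} \| A x - y \|_2^2 + h(K x)

\operatorname{prox}_{t h}(v_1, v_2)
  = \left( \operatorname{prox}_{t \lambda_{W} \| \cdot \|_1}(v_1),\;
           \operatorname{prox}_{t \lambda_{TV} \| \cdot \|_1}(v_2) \right)
```

Because h is separable across the two blocks, its prox is just the two soft-thresholding steps applied side by side, so a solver that accepts one regularization operator and one prox can, in principle, handle both terms at once.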

Have you solved this problem in SigPy before, or do you have any advice? So far, I have tried reformulating the problem with a separate temporary copy of x for each regularizer. However, I haven't been able to work out how to get from there to the final solution (i.e. how to recombine the temporary variables). Perhaps it is my inexperience with Python, but I can't figure out where to go next.
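The "temporary variable" idea can be completed with ADMM: introduce z1 = Wx and z2 = Gx, and the copies are recombined in the x-update, a single linear solve that pulls x toward the data and toward both auxiliary variables. Below is a minimal dense NumPy sketch, assuming real-valued data, a random orthonormal matrix standing in for the wavelet transform, and a 1D forward-difference matrix for G. The function names and parameters (`rho`, `n_iter`) are our own illustrations, not SigPy API.

```python
import numpy as np


def soft_threshold(v, t):
    """Prox of t * ||.||_1 for real-valued v (elementwise soft thresholding)."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def objective(A, y, W, G, lam_w, lam_tv, x):
    """0.5*||A x - y||^2 + lam_w*||W x||_1 + lam_tv*||G x||_1"""
    return (0.5 * np.sum((A @ x - y) ** 2)
            + lam_w * np.sum(np.abs(W @ x))
            + lam_tv * np.sum(np.abs(G @ x)))


def admm_two_regularizers(A, y, W, G, lam_w, lam_tv, rho=1.0, n_iter=500):
    """Scaled-form ADMM with one auxiliary variable per regularizer:
    z1 = W x (wavelet term) and z2 = G x (TV term)."""
    z1, u1 = np.zeros(W.shape[0]), np.zeros(W.shape[0])
    z2, u2 = np.zeros(G.shape[0]), np.zeros(G.shape[0])
    # The x-update recombines the copies: one linear solve that pulls x
    # toward the data and toward both z1 and z2 at the same time.
    M = A.T @ A + rho * (W.T @ W + G.T @ G)
    Aty = A.T @ y
    x = np.zeros(A.shape[1])
    for _ in range(n_iter):
        x = np.linalg.solve(M, Aty + rho * (W.T @ (z1 - u1) + G.T @ (z2 - u2)))
        Wx, Gx = W @ x, G @ x
        z1 = soft_threshold(Wx + u1, lam_w / rho)    # prox of the wavelet term
        z2 = soft_threshold(Gx + u2, lam_tv / rho)   # prox of the TV term
        u1 = u1 + Wx - z1                            # scaled dual updates
        u2 = u2 + Gx - z2
    return x


# Small synthetic 1D example (all values are illustrative).
rng = np.random.default_rng(0)
n, m = 32, 20
A = rng.standard_normal((m, n)) / np.sqrt(m)
x_true = np.repeat([0.0, 1.0, -0.5, 0.5], n // 4)   # piecewise-constant signal
y = A @ x_true + 0.01 * rng.standard_normal(m)
W, _ = np.linalg.qr(rng.standard_normal((n, n)))    # orthonormal stand-in for a wavelet
G = np.diff(np.eye(n), axis=0)                      # forward differences, (n-1) x n
lam_w, lam_tv = 0.01, 0.1
x_hat = admm_two_regularizers(A, y, W, G, lam_w, lam_tv)
```

Because W is orthonormal here, W^T W = I and the x-update matrix is always positive definite; with a real wavelet operator on a large image you would replace the dense solve with an iterative one such as conjugate gradient.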

Thanks for your help, and I look forward to hearing back from you!

Joey

Copied below is the most recent version of an app I tried to create, using the existing SigPy apps as a template (please be critical):

import sigpy as sp
import sigpy.mri as mr
from sigpy.mri.app import _estimate_weights  # private helper used by SigPy's MRI apps


class TVWaveletRecon(sp.app.LinearLeastSquares):
    r"""L1 wavelet and total variation regularized reconstruction.

    Wavelet regularization is good at preserving edges and low-contrast
    information, while TV is efficient at suppressing noise and streaking
    artifacts.

    Considers the problem

    .. math::
        \min_x \frac{1}{2} \| A x - y \|_2^2 + \lambda_W \| W x \|_1 + \lambda_{TV} \| G x \|_1

    where A is the sampling operator,
    W is the wavelet operator,
    G is the finite difference operator,
    x is the image, and y are the k-space measurements.

    Args:
        y (array): k-space measurements.
        mps (array): sensitivity maps.
        lamdaW (float): regularization parameter for the wavelet component.
        lamdaTV (float): regularization parameter for the finite difference component.
        weights (float or array): weights for data consistency.
        coord (None or array): coordinates.
        wave_name (str): wavelet name.
        device (Device): device to perform reconstruction.
        coil_batch_size (int): batch size to process coils.
            Only affects memory usage.
        comm (Communicator): communicator for distributed computing.
        show_pbar (bool): whether to show a progress bar.
        transp_nufft (bool): passed through to the Sense operator.
        **kwargs: Other optional arguments.

    References:
        Lustig, M., Donoho, D., & Pauly, J. M. (2007).
        Sparse MRI: The application of compressed sensing for rapid MR imaging.
        Magnetic Resonance in Medicine, 58(6), 1182-1195.

        Zhu, Z., Wahid, K., Babyn, P., Cooper, D., Pratt, I., & Carter, Y. (2013).
        Improved compressed sensing-based algorithm for sparse-view CT image reconstruction.
        Computational and Mathematical Methods in Medicine.
        doi:10.1155/2013/185750

    """

    def __init__(self, y, mps, lamdaW, lamdaTV,
                 weights=None, coord=None,
                 wave_name='db4', device=sp.cpu_device,
                 coil_batch_size=None, comm=None, show_pbar=True,
                 transp_nufft=False, **kwargs):
        weights = _estimate_weights(y, weights, coord)
        if weights is not None:
            y = sp.to_device(y * weights**0.5, device=device)
        else:
            y = sp.to_device(y, device=device)

        A = mr.linop.Sense(mps, coord=coord, weights=weights,
                           comm=comm, coil_batch_size=coil_batch_size,
                           transp_nufft=transp_nufft)
        img_shape = mps.shape[1:]

        # Wavelet
        W = sp.linop.Wavelet(img_shape, wave_name=wave_name)
        # Finite difference
        G = sp.linop.FiniteDifference(A.ishape)

        # Prox of lamdaW * ||W x||_1, applied in the wavelet domain
        # (same construction as SigPy's L1WaveletRecon).
        proxg1 = sp.prox.UnitaryTransform(sp.prox.L1Reg(W.oshape, lamdaW), W)
        # Prox of lamdaTV * ||z||_1, applied to z = G x.
        proxg2 = sp.prox.L1Reg(G.oshape, lamdaTV)

        if comm is not None:
            show_pbar = show_pbar and comm.rank == 0

        # Without G, LinearLeastSquares evaluates g on x itself,
        # so both terms need their own operator here.
        def g(input):
            device = sp.get_device(input)
            xp = device.xp
            with device:
                return (lamdaW * xp.sum(xp.abs(W(input))).item()
                        + lamdaTV * xp.sum(xp.abs(G(input))).item())

        # Call super().__init__(...) to run the __init__(...) of the parent class,
        # sp.app.LinearLeastSquares (wavelet prox only).
        super().__init__(A, y, proxg=proxg1, g=g, show_pbar=show_pbar, **kwargs)

        # With G passed, LinearLeastSquares evaluates g on G(x), so only the
        # TV term can be computed from the input here.
        def h(input):
            device = sp.get_device(input)
            xp = device.xp
            with device:
                return lamdaTV * xp.sum(xp.abs(input)).item()

        # NOTE: this second call re-initializes the app from scratch, so the
        # wavelet prox configured above is discarded and only the TV term is
        # used when the app runs. Both regularizers have to be combined into
        # a single proxg/G pair instead.
        super().__init__(A, y, proxg=proxg2, g=h, G=G, show_pbar=show_pbar, **kwargs)
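Calling `super().__init__` twice re-initializes the app, so only the second (TV) configuration takes effect. One way to keep both terms is to give the solver a single stacked operator [W; G] together with a single prox acting blockwise on its output. The NumPy sketch below illustrates that idea with a primal-dual hybrid gradient (Chambolle-Pock) loop on small dense matrices; the matrices, step sizes, and function names are illustrative assumptions, not SigPy code. In SigPy, the analogous construction would presumably be `G=sp.linop.Vstack([W, G])` with `proxg=sp.prox.Stack([...])` passed to `sp.app.LinearLeastSquares`, which as far as we know switches to a primal-dual solver when `G` is given; please check this against the documentation for your SigPy version.

```python
import numpy as np


def objective(A, y, W, G, lam_w, lam_tv, x):
    """0.5*||A x - y||^2 + lam_w*||W x||_1 + lam_tv*||G x||_1"""
    return (0.5 * np.sum((A @ x - y) ** 2)
            + lam_w * np.sum(np.abs(W @ x))
            + lam_tv * np.sum(np.abs(G @ x)))


def pdhg_stacked(A, y, W, G, lam_w, lam_tv, n_iter=3000):
    """Chambolle-Pock / PDHG treating both regularizers as ONE function
    of the stacked output K x = [W x; G x]."""
    K = np.vstack([W, G])                              # single stacked regularization operator
    lam = np.concatenate([np.full(W.shape[0], lam_w),  # per-entry L1 weights
                          np.full(G.shape[0], lam_tv)])
    L_norm = np.linalg.norm(np.vstack([A, K]), 2)      # spectral norm of [A; W; G]
    sigma = tau = 0.99 / L_norm                        # sigma * tau * L_norm**2 < 1
    x = np.zeros(A.shape[1])
    x_bar = x.copy()
    p = np.zeros(A.shape[0])                           # dual variable for the data term
    q = np.zeros(K.shape[0])                           # dual variable for [W x; G x]
    for _ in range(n_iter):
        # Dual prox of the data term f(v) = 0.5*||v - y||^2 (its conjugate is quadratic).
        p = (p + sigma * (A @ x_bar - y)) / (1 + sigma)
        # Dual prox of the stacked weighted L1 norm: one clip covers both regularizers.
        q = np.clip(q + sigma * (K @ x_bar), -lam, lam)
        # Primal step; there is no separate primal term, so its prox is the identity.
        x_new = x - tau * (A.T @ p + K.T @ q)
        x_bar = 2 * x_new - x                          # over-relaxation with theta = 1
        x = x_new
    return x


# Small synthetic 1D example (all values are illustrative).
rng = np.random.default_rng(0)
n, m = 32, 20
A = rng.standard_normal((m, n)) / np.sqrt(m)
x_true = np.repeat([0.0, 1.0, -0.5, 0.5], n // 4)    # piecewise-constant signal
y = A @ x_true + 0.01 * rng.standard_normal(m)
W, _ = np.linalg.qr(rng.standard_normal((n, n)))     # orthonormal stand-in for a wavelet
G = np.diff(np.eye(n), axis=0)                       # forward differences, (n-1) x n
lam_w, lam_tv = 0.01, 0.1
x_hat = pdhg_stacked(A, y, W, G, lam_w, lam_tv)
```

The per-entry bound vector `lam` is what lets a single clip act as the (conjugate) prox for both regularizers, which is the same separability that makes a stacked prox work.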
sidward commented 2 years ago

Hi Joey,

I do not think I will have time to write this method. Once you understand the math, it's not too difficult to implement it using prox functions. Please see Equation 5.3 of Parikh and Boyd, Proximal Algorithms: https://web.stanford.edu/~boyd/papers/pdf/prox_algs.pdf
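As a concrete prox-function example: the wavelet term in the app above is handled with `sp.prox.UnitaryTransform`, which applies the L1 prox in the wavelet domain and maps the result back. The NumPy sketch below (our own function names; real-valued data and a random orthonormal matrix standing in for the wavelet are assumptions) shows why this works: for an orthonormal W, the prox of t||Wx||_1 at v is W^T applied to the soft-thresholded Wv. A finite-difference matrix G is not orthonormal, so the TV term cannot be handled the same way and needs G passed to the solver instead.

```python
import numpy as np


def prox_l1(v, t):
    """Prox of t * ||.||_1: elementwise soft thresholding (real-valued)."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def prox_l1_unitary(v, t, W):
    """Prox of t * ||W .||_1 for orthonormal W: W^T prox_l1(W v, t).
    Valid only because W^T W = I (||x - v|| = ||W x - W v||)."""
    return W.T @ prox_l1(W @ v, t)


def prox_objective(x, v, t, W):
    """The function the prox minimizes: t*||W x||_1 + 0.5*||x - v||^2."""
    return t * np.sum(np.abs(W @ x)) + 0.5 * np.sum((x - v) ** 2)


rng = np.random.default_rng(1)
n, t = 8, 0.3
W, _ = np.linalg.qr(rng.standard_normal((n, n)))   # orthonormal stand-in for a wavelet
v = rng.standard_normal(n)
p = prox_l1_unitary(v, t, W)
```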

If I have time, I will re-open this request and try to address it. I am happy to help clarify any questions as well.

joeyplum commented 2 years ago

Thanks for your reply and advice, @sidward! I'll take a look at this paper and try to work out the methods with prox functions. I'll keep you updated when/if I get to a solution.