mikgroup / sigpy

Python package for signal processing, with emphasis on iterative methods
BSD 3-Clause "New" or "Revised" License

sigpy.mri.util.tseg_off_res_b_ct() for a 3D image and 3D coordinates #139

Open joeyplum opened 7 months ago

joeyplum commented 7 months ago

I am trying to create the B and Ct matrices needed for time-segmented off-resonance compensation, specifically for 3D B0 arrays and 3D coordinates (shape: N_excitations, N_samples, N_dimensions).
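The data shapes involved can be sketched with plain NumPy. The sizes below are hypothetical. As far as I know, a SigPy NUFFT built from coordinates shaped (N_excitations, N_samples, N_dimensions) returns k-space of shape coord.shape[:-1]. Because every excitation shares the same readout timing, one column of the temporal interpolator B (length N_samples) can be broadcast across all excitations.

```python
import numpy as np

# Hypothetical sizes, for illustration only.
n_exc, n_samp, n_dim, lseg = 8, 64, 3, 5
rng = np.random.default_rng(0)

# Non-Cartesian coordinates: one trajectory per excitation.
coord = rng.uniform(-16, 16, size=(n_exc, n_samp, n_dim))

# k-space has one value per (excitation, sample), i.e. coord.shape[:-1].
kshape = coord.shape[:-1]
kspace = rng.normal(size=kshape) + 1j * rng.normal(size=kshape)

# Temporal interpolator: one row per readout sample, one column per segment.
b = rng.normal(size=(n_samp, lseg)) + 1j * rng.normal(size=(n_samp, lseg))

# Column 0 (shape (n_samp,)) broadcasts over the excitation axis.
weighted = kspace * b[:, 0]
print(weighted.shape)  # (8, 64)
```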

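I can get sigpy.mri.util.tseg_off_res_b_ct() to work for a 2-dimensional B0 array, but not for a 3-dimensional B0 array. My guess is that the np.concatenate() calls are causing the problem: for a 2D array, np.concatenate(b0) flattens it into a vector, but for a 3D array it only merges the first two axes, so the later matrix products no longer have compatible shapes.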

To fix this, I wrote my own function in which the np.concatenate() calls are replaced with np.ravel().
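The shape problem can be reproduced with plain NumPy. This is a minimal sketch with hypothetical array sizes and time nodes; phase_per_segment is a hypothetical helper that mirrors the Ct computation (an outer product of the time-segment nodes with 2πi·b0). With np.concatenate, a 3D map becomes a 2D array and the matrix product raises a ValueError, while np.ravel gives the expected (lseg, N_voxels) result.

```python
import numpy as np

rng = np.random.default_rng(0)
tl = np.linspace(0.0, 5e-3, 4)  # hypothetical time-segment nodes (s)

b0_2d = rng.normal(scale=50.0, size=(4, 5))     # 2D field map (Hz)
b0_3d = rng.normal(scale=50.0, size=(4, 5, 6))  # 3D field map (Hz)

# 2D: np.concatenate joins the rows, giving the same vector as np.ravel.
assert np.array_equal(np.concatenate(b0_2d), np.ravel(b0_2d))

# 3D: np.concatenate only merges the first two axes.
print(np.concatenate(b0_3d).shape)  # (20, 6)
print(np.ravel(b0_3d).shape)        # (120,)


def phase_per_segment(b0_flat):
    """Hypothetical helper: exp(-tl x 2*pi*i*b0), as in the Ct computation."""
    b0_v = np.expand_dims(2j * np.pi * b0_flat, axis=0)
    return np.exp(-np.expand_dims(tl, axis=1) @ b0_v)


# With np.ravel: (4, 1) @ (1, 120) -> (4, 120).
print(phase_per_segment(np.ravel(b0_3d)).shape)  # (4, 120)

# With np.concatenate the operand is (1, 20, 6), so matmul fails.
try:
    phase_per_segment(np.concatenate(b0_3d))
except ValueError as err:
    print("np.concatenate version fails:", err)
```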

Please feel free to comment if you have also seen this issue before.

joeyplum commented 6 months ago

In case it is useful, here is my updated function. Note that I have also added a T2* map as an input, since T2* decay can be treated with the same method as off-resonance.
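The model behind this can be summarized as follows. This is a sketch under the usual time-segmentation assumptions, ignoring the NUFFT's exact sign and scaling conventions. Writing T2* decay and off-resonance as one complex rate per voxel is what lets both maps be handled the same way:

$$
s(t) = \sum_{n} m(\mathbf{r}_n)\, e^{-z_n t}\, e^{-i 2\pi \mathbf{k}(t)\cdot\mathbf{r}_n},
\qquad
z_n = \frac{1}{T_2^*(\mathbf{r}_n)} + i\, 2\pi\, \Delta f(\mathbf{r}_n)
$$

$$
e^{-z_n t} \approx \sum_{l=1}^{L} b_l(t)\, e^{-z_n t_l}
\quad\Longrightarrow\quad
A \approx \sum_{l=1}^{L} B_l\, F\, C_l
$$

Here $t_l$ are the $L$ = `lseg` segment nodes, $C_l = \mathrm{diag}(e^{-z_n t_l})$ is column $l$ of `ct`, and $B_l = \mathrm{diag}(b_l(t))$ is column $l$ of `b`. The coefficients are a weighted least-squares fit over histogram bins $z_k$ with counts $h_k$:

$$
\mathbf{b}(t) = \arg\min_{\mathbf{b}} \sum_{k} h_k \Big| e^{-z_k t} - \sum_{l=1}^{L} b_l\, e^{-z_k t_l} \Big|^2
= \big(W^{1/2} C_h^{\mathsf T}\big)^{+}\, W^{1/2}\, \mathbf{e}(t)
$$

with $W = \mathrm{diag}(h_k)$, $(C_h)_{lk} = e^{-z_k t_l}$ (the array `ch`) and $\mathbf{e}(t)_k = e^{-z_k t}$, which corresponds to the `pinv(w @ ch.T) @ w` line.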

import matplotlib.pyplot as plt
import numpy as np
import sigpy as sp


def tseg_b_ct(F, b0, t2star, bins, lseg, readout, plot=False):
    """Builds a time-segmented B0/T2* encoding operator from B and Ct.

    Args:
        F (Linop): Fourier encoding linear operator (e.g. NUFFT) with
            output shape (N_excitations, N_samples).
        b0 (array): off-resonance field map (Hz), shape F.ishape.
        t2star (array): T2* map (s), shape F.ishape.
        bins (int): number of histogram bins to use.
        lseg (int): number of time segments.
        readout (float): readout duration (s).
        plot (bool): if True, plot the B and Ct bases.

    Returns:
        Linop: A = sum_l B_l * F * Ct_l, where

        - **B_l** multiplies k-space by the temporal interpolator for
          segment l (column l of B, shape (N_samples, lseg)).
        - **Ct_l** multiplies the image by the off-resonance phase and
          T2* decay at segment l (column l of Ct, shape (N_voxels, lseg)).
    """

    # Readout sample times; F.oshape = (N_excitations, N_samples)
    N_samp = F.oshape[1]
    t = np.linspace(0, readout, N_samp)
    # Histogram of the angular off-resonance 2*pi*b0 (rad/s)
    hist_wt_b0, bin_edges_b0 = np.histogram(
        2 * np.pi * np.real(np.ravel(b0)), bins
    )
    # Histogram of R2* = 1/T2* (1/s). The minus sign of exp(-z*t) is
    # applied later, when b and ct are computed.
    hist_wt_t2star, bin_edges_t2star = np.histogram(
        np.ravel(1 / t2star), bins
    )

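    # Build B and Ct
    # Bin centers are the midpoints of adjacent bin edges
    bin_centers_b0 = (bin_edges_b0[:-1] + bin_edges_b0[1:]) / 2
    bin_centers_t2star = (bin_edges_t2star[:-1] + bin_edges_t2star[1:]) / 2
    # Combined weight per bin. Note: this pairs the i-th R2* bin with the
    # i-th B0 bin (two 1D histograms), not a joint 2D histogram.
    hist_wt = hist_wt_t2star + hist_wt_b0
    # Complex decay rate per bin: z = R2* + i*2*pi*df
    zk = bin_centers_t2star + 1j * bin_centers_b0
    # Time-segment nodes, evenly spaced from 0 to readout
    tl = np.linspace(0, readout, lseg)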
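    # Off-resonance phase / T2* decay at each time segment for each bin,
    # shape (lseg, bins)
    ch = np.exp(-np.expand_dims(tl, axis=1) @ np.expand_dims(zk, axis=0))
    # Weighted least squares: min_b ||W^(1/2) (exp(-zk t) - ch^T b(t))||^2
    w = np.diag(np.sqrt(hist_wt))
    p = np.linalg.pinv(w @ np.transpose(ch)) @ w
    b = p @ np.exp(
        -np.expand_dims(zk, axis=1) @ np.expand_dims(t, axis=0)
    )
    b = np.transpose(b)  # (N_samples, lseg)
    # Complex decay rate of every voxel: z = i*2*pi*b0 + 1/T2*
    b0_v = np.expand_dims((2j * np.pi * np.ravel(b0)) +
                          np.ravel(1 / t2star), axis=0)
    # Phase/decay of every voxel at each time segment, (N_voxels, lseg)
    ct = np.transpose(np.exp(-np.expand_dims(tl, axis=1) @ b0_v))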

    # Plot
    if plot:
        fig, (ax1, ax2) = plt.subplots(
            2, sharex=False, figsize=(6, 3), dpi=100)
        ax1.plot(np.real(b), color='g')
        ax1.plot(np.imag(b), color='r')
        ax1.set_ylabel('b')
        ax1.set_xlabel("sample number")
        ax2.plot(np.real(ct[:, 0]), color='c')
        ax2.plot(np.imag(ct[:, 0]), color='m')
        ax2.set_ylabel('ct')
        ax2.set_xlabel("voxel index")
        plt.show()

    for ii in range(lseg):
        Bi = sp.linop.Multiply(F.oshape, b[:, ii])
        Cti = sp.linop.Multiply(F.ishape, ct[:, ii].reshape(F.ishape))

        # Accumulate A = sum_i Bi * F * Cti over the time segments
        if ii == 0:
            A = Bi * F * Cti
        else:
            A = A + Bi * F * Cti

    return A
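The weighted least-squares step can also be checked on its own, without SigPy. The sketch below is a NumPy-only version with one stated difference from the function above: it bins (1/T2*, 2π·Δf) jointly with np.histogram2d, so each complex rate z_k is an actual (R2*, off-resonance) pair, rather than adding two 1D histograms bin by bin. The name tseg_interpolator, the bin counts, and the simulated maps are hypothetical. It builds b (N_samples × lseg) and ct (N_voxels × lseg) the same way and measures how well ct @ b.T reproduces the exact exp(-z t) for every voxel and sample.

```python
import numpy as np


def tseg_interpolator(z, t, lseg, bins=(16, 16)):
    """Weighted least-squares time-segmentation interpolator (sketch).

    z: complex decay rate per voxel, z = 1/T2* + 1j*2*pi*df (1/s).
    t: readout sample times (s).
    Returns b (len(t), lseg) and ct (len(z), lseg) such that
    exp(-outer(z, t)) is approximated by ct @ b.T.
    """
    # Joint 2D histogram of (R2*, 2*pi*df) instead of two paired 1D ones.
    hist, edges_r, edges_w = np.histogram2d(z.real, z.imag, bins=bins)
    centers_r = (edges_r[:-1] + edges_r[1:]) / 2
    centers_w = (edges_w[:-1] + edges_w[1:]) / 2
    zk = (centers_r[:, None] + 1j * centers_w[None, :]).ravel()
    hk = hist.ravel()
    keep = hk > 0  # empty bins have zero weight, so drop them
    zk, hk = zk[keep], hk[keep]

    tl = np.linspace(t[0], t[-1], lseg)  # time-segment nodes
    ch = np.exp(-np.outer(tl, zk))  # (lseg, K)
    sqrt_w = np.sqrt(hk)[:, None]  # W^(1/2) as a column, (K, 1)
    # All readout times at once: b(t) = pinv(W^1/2 ch^T) W^1/2 e(t)
    p = np.linalg.pinv(sqrt_w * ch.T) @ (sqrt_w * np.exp(-np.outer(zk, t)))
    b = p.T  # (N_samples, lseg)
    ct = np.exp(-np.outer(z, tl))  # (N_voxels, lseg)
    return b, ct


# Hypothetical test maps: 2000 voxels, 5 ms readout, 8 segments.
rng = np.random.default_rng(0)
df = rng.normal(0.0, 40.0, 2000)  # off-resonance (Hz)
r2s = rng.uniform(10.0, 50.0, 2000)  # R2* = 1/T2* (1/s)
z = r2s + 2j * np.pi * df
t = np.linspace(0.0, 5e-3, 200)

b, ct = tseg_interpolator(z, t, lseg=8)
exact = np.exp(-np.outer(z, t))
rel_err = np.linalg.norm(ct @ b.T - exact) / np.linalg.norm(exact)
print(f"relative error: {rel_err:.2e}")
```

If the error is too large for a given readout length and field map, increasing lseg is typically the first thing to try; the joint histogram only changes which rates the fit is weighted toward.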