SyneRBI / PETRIC

PET Image Reconstruction Challenge 2024
https://www.ccpsynerbi.ac.uk/events/petric/

use mode="staggered" in partitioner of main_BSREM.py for consistency #70

Closed: gschramm closed this issue 1 month ago

gschramm commented 1 month ago

Update 1: The difference disappears if I use 1 instead of 7 subsets, so I guess the issue is related to the definition of the subsets. What is the easiest way to see which views are involved in a subset acquisition model / subset obj_func?

Update 2: The issue is related to the mode of the partitioner. If it is set to `staggered`, the difference also vanishes for more than one subset. It might be better to change the partitioner mode in `main_BSREM.py` here. `main_ISTA.py` already uses `staggered`.
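To make the difference between the two modes concrete, here is a minimal pure-Python sketch of how view indices can be split into subsets. The helper `split_views` is hypothetical: it only mimics the idea of contiguous ("sequential") versus interleaved ("staggered") subsets and is not the SIRF-Contribs `partition_indices` code, whose handling of uneven splits may differ.

```python
def split_views(num_views, num_subsets, staggered):
    """Hypothetical helper: split view indices 0..num_views-1 into num_subsets subsets.

    staggered=False: contiguous blocks of views ("sequential").
    staggered=True: every num_subsets-th view, starting at the subset index ("staggered").
    """
    views = list(range(num_views))
    if staggered:
        return [views[k::num_subsets] for k in range(num_subsets)]
    # contiguous blocks; the first (num_views % num_subsets) blocks get one extra view
    base, extra = divmod(num_views, num_subsets)
    subsets, start = [], 0
    for k in range(num_subsets):
        size = base + (1 if k < extra else 0)
        subsets.append(views[start:start + size])
        start += size
    return subsets


sequential = split_views(14, 7, staggered=False)
staggered = split_views(14, 7, staggered=True)
print(sequential[0], staggered[0])  # [0, 1] [0, 7]

# with a single subset, both modes contain every view and therefore agree
assert split_views(14, 1, staggered=False) == split_views(14, 1, staggered=True)
```

Subset 0 thus covers different views in the two modes, so objective functions from one mode cannot reproduce updates built from subsets of the other mode, while with a single subset both modes coincide (consistent with Update 1).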

Hi all,

to test whether I set up the (data fidelity) subset objective functions correctly, I compared the multiplicative OSEM update from `main_OSEM.py`

$$ x^+ = \frac{x}{A^T 1} A^T \frac{y}{Ax + s} $$

with my additive OSEM update using the subset objective function gradient (see below).

$$ x^+ = x + \frac{x}{A^T 1} \nabla_x \log L(x) $$
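For one subset with system matrix $A$ (here including the multiplicative factors), measured data $y$ and additive term $s$, the gradient of the Poisson log-likelihood shows that the two updates are algebraically identical (all products and divisions elementwise). A mismatch can therefore only come from the two updates using different $A$, i.e. different views:

$$
\begin{aligned}
\log L(x) &= \sum_i \left[ y_i \log (Ax+s)_i - (Ax+s)_i \right] + \text{const} \\
\nabla_x \log L(x) &= A^T \frac{y}{Ax+s} - A^T 1 \\
x + \frac{x}{A^T 1} \nabla_x \log L(x) &= x + \frac{x}{A^T 1} A^T \frac{y}{Ax+s} - x = \frac{x}{A^T 1} A^T \frac{y}{Ax+s}
\end{aligned}
$$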

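Right now, I don't understand why the two give different updates. The image below (`test.png`, central slice) shows the multiplicative update `x1`, the additive update `x2` and their difference `x2 - x1`:

[Image: test.png, panels: multiplicative update x1 | additive update x2 | difference x2 - x1]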
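The equivalence, and how it breaks, can be checked on a toy problem. In this hypothetical NumPy sketch a small random non-negative matrix stands in for the projector (one row per "view") and subsets are row index sets. Using the same subset for the sensitivity and the gradient reproduces the multiplicative update; taking the gradient from a subset built with the other partition mode does not, which mirrors the behaviour reported above.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy problem (hypothetical): 14 "views" (rows), 5 image voxels, 7 subsets
num_views, num_voxels, num_subsets = 14, 5, 7
A = rng.uniform(0.1, 1.0, size=(num_views, num_voxels))  # system matrix incl. mult. factors
x_true = rng.uniform(1.0, 2.0, size=num_voxels)
s = np.full(num_views, 0.5)  # additive term
y = rng.poisson(A @ x_true + s).astype(float)  # measured counts


def mult_update(x, rows):
    """Multiplicative OSEM update restricted to the given rows (one subset)."""
    As = A[rows]
    return x / (As.T @ np.ones(len(rows))) * (As.T @ (y[rows] / (As @ x + s[rows])))


def loglik_grad(x, rows):
    """Gradient of the subset Poisson log-likelihood: A^T (y / (Ax + s)) - A^T 1."""
    As = A[rows]
    return As.T @ (y[rows] / (As @ x + s[rows])) - As.T @ np.ones(len(rows))


def add_update(x, rows, grad_rows):
    """Additive update: sensitivity from `rows`, gradient from `grad_rows`."""
    sens = A[rows].T @ np.ones(len(rows))
    return x + x / sens * loglik_grad(x, grad_rows)


x = np.ones(num_voxels)
staggered_0 = list(range(0, num_views, num_subsets))  # views [0, 7]
sequential_0 = list(range(num_views // num_subsets))  # views [0, 1]

x1 = mult_update(x, staggered_0)
x2_same = add_update(x, staggered_0, staggered_0)
x2_mixed = add_update(x, staggered_0, sequential_0)

print(np.allclose(x1, x2_same))   # True: same subset in both updates
print(np.allclose(x1, x2_mixed))  # False: gradient uses different views
```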

from cil.optimisation.algorithms import Algorithm
from petric import Dataset
from sirf.contrib.partitioner import partitioner
from sirf.contrib.partitioner.partitioner import partition_indices
import sirf.STIR as STIR
import matplotlib.pyplot as plt


class Submission(Algorithm):
    """
    OSEM algorithm example.
    NB: In OSEM, the multiplicative term cancels in the back-projection of the quotient of measured & estimated data
    (so this is used here for efficiency). Note that a similar optimisation can be used for all algorithms using the Poisson log-likelihood.
    NB: OSEM does not use `data.prior` and thus does not converge to the MAP reference used in PETRIC.
    NB: this example does not use the `sirf.STIR` Poisson objective function.
    NB: see https://github.com/SyneRBI/SIRF-Contribs/tree/master/src/Python/sirf/contrib/BSREM
    """

    def __init__(
        self,
        data: Dataset,
        num_subsets: int = 7,
        update_objective_interval: int = 10,
        **kwargs
    ):
        """
        Initialisation function, setting up data & (hyper)parameters.
        NB: in practice, `num_subsets` should likely be determined from the data.
        This is just an example. Try to modify and improve it!
        """

        self.subset = 0
        self.x = data.OSEM_image.clone()

        #############################################################################
        #############################################################################
        #############################################################################
        #############################################################################

        # NB: without mode="staggered" the partitioner uses its default (sequential) mode,
        # so subset i here does not contain the same views as partitions_idxs[i] below
        self._data_sub, self._acq_models, self._obj_funs = partitioner.data_partition(
            data.acquired_data,
            data.additive_term,
            data.mult_factors,
            num_subsets,
            initial_image=data.OSEM_image,
        )
        # WARNING: modifies prior strength with 1/num_subsets (as currently needed for BSREM implementations)
        #data.prior.set_penalisation_factor(
        #    data.prior.get_penalisation_factor() / num_subsets
        #)
        #data.prior.set_up(data.OSEM_image)

        # for f in self._obj_funs:  # add prior evenly to every objective function
        #    f.set_prior(data.prior)

        #############################################################################
        #############################################################################
        #############################################################################
        #############################################################################

        self._acquisition_models = []
        self._prompts = []
        self._sensitivities = []

        # find views in each subset
        # (note that SIRF can currently only do subsets over views)
        views = data.mult_factors.dimensions()[2]
        partitions_idxs = partition_indices(
            num_subsets, list(range(views)), stagger=True
        )

        # for each subset: find data, create acq_model, and create subset_sensitivity (backproj of 1)
        for i in range(num_subsets):
            prompts_subset = data.acquired_data.get_subset(partitions_idxs[i])
            additive_term_subset = data.additive_term.get_subset(partitions_idxs[i])
            multiplicative_factors_subset = data.mult_factors.get_subset(
                partitions_idxs[i]
            )

            acquisition_model_subset = STIR.AcquisitionModelUsingParallelproj()
            acquisition_model_subset.set_additive_term(additive_term_subset)
            acquisition_model_subset.set_up(prompts_subset, self.x)

            subset_sensitivity = acquisition_model_subset.backward(
                multiplicative_factors_subset
            )
            # add a small number to avoid NaN in division
            subset_sensitivity += subset_sensitivity.max() * 1e-6

            self._acquisition_models.append(acquisition_model_subset)
            self._prompts.append(prompts_subset)
            self._sensitivities.append(subset_sensitivity)

        super().__init__(update_objective_interval=update_objective_interval, **kwargs)
        self.configured = True  # required by Algorithm

    def update(self):
        x_cur = self.x

        denom = self._acquisition_models[self.subset].forward(x_cur) + 1e-4
        # divide measured data by estimate (ignoring mult_factors!)
        quotient = self._prompts[self.subset] / denom

        # mult. OSEM update
        x1 = x_cur * (
            self._acquisition_models[self.subset].backward(quotient)
            / self._sensitivities[self.subset]
        )

        # additive OSEM update using gradient of subset objective function
        x2 = x_cur + (x_cur / self._sensitivities[self.subset]) * self._obj_funs[
            self.subset
        ].gradient(x_cur)

        # compare the central slice of both updates and their difference
        a1 = x1.as_array()
        a2 = x2.as_array()
        d = a2 - a1
        sl = a1.shape[0] // 2
        vmax = a1[sl, :, :].max()

        fig, ax = plt.subplots(1, 3, figsize=(15, 5), tight_layout=True)
        ax[0].imshow(a1[sl, :, :], vmin=0, vmax=vmax, cmap="Greys")
        ax[0].set_title("multiplicative update x1")
        ax[1].imshow(a2[sl, :, :], vmin=0, vmax=vmax, cmap="Greys")
        ax[1].set_title("additive update x2")
        ax[2].imshow(d[sl, :, :], vmin=-0.05 * vmax, vmax=0.05 * vmax, cmap="bwr")
        ax[2].set_title("x2 - x1")
        fig.show()
        fig.savefig("test.png")

        # keep the multiplicative update and move on to the next subset (as in main_OSEM.py)
        self.x = x1
        self.subset = (self.subset + 1) % len(self._prompts)

    def update_objective(self):
        """The objective value is not needed for this comparison, so return 0 (as in main_OSEM.py)."""
        return 0

KrisThielemans commented 1 month ago

Thanks @gschramm. The partitioner code is somewhat experimental, and participants would likely have to modify it for best performance. `staggered` is definitely better than `sequential`, but I think that even then the order of the subsets isn't great (it is really just "sequential" as well). The Herman-Meyer order would make much more sense for an ordered subsets algorithm. I don't think this is in the SIRF-Contribs version, but there's one in https://github.com/TomographicImaging/CIL/blob/3a164f1ff685d7e3689c123021361243511b66fd/Wrappers/Python/cil/optimisation/utilities/sampler.py#L551
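For reference, a Herman-Meyer style subset order can be sketched as a mixed-radix digit reversal over the prime factors of the number of subsets. `herman_meyer_order` below is a hypothetical stand-alone helper, not the CIL sampler linked above; taking the prime factors in ascending order is an assumption, and other conventions give different (equally valid) orders for some subset counts.

```python
def prime_factors(n):
    """Prime factors of n in ascending order, with multiplicity."""
    factors = []
    p = 2
    while p * p <= n:
        while n % p == 0:
            factors.append(p)
            n //= p
        p += 1
    if n > 1:
        factors.append(n)
    return factors


def herman_meyer_order(num_subsets):
    """Hypothetical helper: Herman-Meyer style order via mixed-radix digit reversal.

    Subset counter k is written in mixed radix (least significant digit first,
    radices = prime factors) and the digits are re-weighted in reverse significance.
    """
    factors = prime_factors(num_subsets)
    order = []
    for k in range(num_subsets):
        idx, weight, rem = 0, num_subsets, k
        for f in factors:
            weight //= f
            idx += (rem % f) * weight
            rem //= f
        order.append(idx)
    return order


print(herman_meyer_order(8))  # [0, 4, 2, 6, 1, 5, 3, 7]
print(herman_meyer_order(6))  # [0, 3, 1, 4, 2, 5]
```

For a prime number of subsets, such as the 7 used above, this construction has a single digit and returns the plain sequential order 0, 1, ..., 6, so it only reorders subsets when the subset count has more than one prime factor.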

gschramm commented 1 month ago

Thanks Kris!