spacetx / starfish

starfish: unified pipelines for image-based transcriptomics
https://spacetx-starfish.readthedocs.io/en/latest/
MIT License

Possible parameter inversion in LearnTransform.Translation class #2005

Open guoyang-github opened 8 months ago

guoyang-github commented 8 months ago

In the LearnTransform.Translation class in the latest version of starfish, I think the 'reference_image' and 'moving_image' arguments in the following call may have been swapped:

    shift, error, phasediff = phase_cross_correlation(
        reference_image=target_image.data,
        moving_image=reference_image.data,
        upsample_factor=self.upsampling
    )

Based on the function's definition and its typical usage, I suspect the correct code should be:

    shift, error, phasediff = phase_cross_correlation(
        reference_image=reference_image.data,
        moving_image=target_image.data,
        upsample_factor=self.upsampling
    )
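To make the argument order concrete, here is a minimal numpy sketch of integer-pixel phase correlation. It is a simplified stand-in, not skimage's implementation (no upsampling, no error or phasediff outputs), and it assumes the same sign convention that skimage documents for phase_cross_correlation: the returned shift is what registers moving_image onto reference_image. The function name phase_correlation_shift and the synthetic images are hypothetical.

```python
import numpy as np


def phase_correlation_shift(reference: np.ndarray, moving: np.ndarray) -> np.ndarray:
    """Integer-pixel phase correlation (simplified stand-in for
    skimage.registration.phase_cross_correlation: no upsampling, no error output).

    Returns the (row, col) shift that registers `moving` onto `reference`,
    i.e. np.roll(moving, shift, axis=(0, 1)) reproduces `reference`.
    """
    ref_freq = np.fft.fftn(reference)
    mov_freq = np.fft.fftn(moving)
    # Cross-power spectrum; normalizing keeps only the phase difference.
    product = ref_freq * mov_freq.conj()
    product /= np.maximum(np.abs(product), 1e-12)
    correlation = np.abs(np.fft.ifftn(product))
    # The correlation peak sits at the shift (modulo the image size).
    shape = np.array(correlation.shape)
    peak = np.array(np.unravel_index(np.argmax(correlation), correlation.shape))
    # Peaks past the midpoint wrap around and represent negative shifts.
    wrapped = peak > shape // 2
    peak[wrapped] -= shape[wrapped]
    return peak


rng = np.random.default_rng(0)
reference = rng.random((64, 64))
# Hypothetical "moving" image: the reference displaced by +5 rows and -3 columns.
moving = np.roll(reference, shift=(5, -3), axis=(0, 1))

print(phase_correlation_shift(reference, moving).tolist())  # [-5, 3]
print(phase_correlation_shift(moving, reference).tolist())  # [5, -3]
```

Under this convention, passing dots_image as reference_image and each round as moving_image gives the shift that moves that round onto dots_image. Swapping the two arguments only flips the sign of the shift, which is the same pattern as in the r: 0 to r: 3 results reported in this thread.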

Thanks for your attention to this matter.

berl commented 8 months ago

Hi @guoyang-github, thanks for taking a look at this. For this (and your other helpful bug report), can you provide a permalink to the code you're describing? In this case, can you also show an example where this produces unexpected behavior in the starfish examples or use cases?

guoyang-github commented 8 months ago

Hi berl, thanks for your reply. Here is the link to the code: https://github.com/spacetx/starfish/blob/master/starfish/core/image/_registration/LearnTransform/translation.py

Let's take a code example:

    import starfish
    from starfish.image import LearnTransform
    from starfish.types import Axes

    exp = starfish.Experiment.from_json("https://d2nhj9g34unfro.cloudfront.net/browse/formatted/20180926/iss_breast/experiment.json")

    fov = exp['fov_000']
    # Max-project over channels and z-planes so each round becomes a single 2-D image
    primary_image = fov.get_image(starfish.FieldOfView.PRIMARY_IMAGES).reduce({Axes.CH, Axes.ZPLANE}, func="max")
    dots_image = fov.get_image('dots')

    # Learn one translation per round, using the dots image as the reference stack
    learn_translation = LearnTransform.Translation(reference_stack=dots_image, axes=Axes.ROUND, upsampling=100)
    transforms_list = learn_translation.run(primary_image, verbose=True)

    For r: 0, Shift: [ 5.76 -22.81], Error: 0.6056388649515209
    For r: 1, Shift: [ 1.86 -22.16], Error: 0.6367408899871313
    For r: 2, Shift: [ -3.18 -21.87], Error: 0.6762644068856953
    For r: 3, Shift: [ -4.34 -14.9 ], Error: 0.6818544147316385

During image registration, I think dots_image should serve as the reference stack, and each round of primary_image should be moved to align with dots_image. So in phase_cross_correlation, the reference_image parameter should be dots_image and the moving_image parameter should be target_image (primary_image). Am I understanding this correctly? If I swap the two parameters, the results are:

    For r: 0, Shift: [-5.76 22.81], Error: 0.6056388648526686
    For r: 1, Shift: [-1.86 22.16], Error: 0.6367408898832333
    For r: 2, Shift: [ 3.18 21.87], Error: 0.6762644066836463
    For r: 3, Shift: [ 4.34 14.9 ], Error: 0.6818544147341965

Here's another problem, concerning the order of X and Y. According to the function's documentation, the axis ordering of the returned shift is consistent with the axis order of the input array, i.e. (row, column) = (Y, X) for a 2-D image. So for r: 3, for example, Y = -4.34 and X = -14.9. However, when I print transforms_list:

    print(transforms_list)

the output is:

    tile indices: {<Axes.ROUND: 'r'>: 0}
    translation: y=22.81, x=-5.76, rotation: 0.0, scale: 1.0
    tile indices: {<Axes.ROUND: 'r'>: 1}
    translation: y=22.16, x=-1.8599999999999999, rotation: 0.0, scale: 1.0
    tile indices: {<Axes.ROUND: 'r'>: 2}
    translation: y=21.87, x=3.18, rotation: 0.0, scale: 1.0
    tile indices: {<Axes.ROUND: 'r'>: 3}
    translation: y=14.9, x=4.34, rotation: 0.0, scale: 1.0

The X and Y values appear to be swapped. Here is the relevant code from https://github.com/spacetx/starfish/blob/master/starfish/core/image/_registration/transforms_list.py:

    def __repr__(self) -> str:
        translation_strings = [
            f"tile indices: {t[0]}\ntranslation: y={t[2].translation[0]}, "
            f"x={t[2].translation[1]}, rotation: {t[2].rotation}, scale: {t[2].scale}"
            for t in self.transforms
        ]
        return "\n".join(translation_strings)

Should the code be adjusted to:

    def __repr__(self) -> str:
        translation_strings = [
            f"tile indices: {t[0]}\ntranslation: x={t[2].translation[0]}, "
            f"y={t[2].translation[1]}, rotation: {t[2].rotation}, scale: {t[2].scale}"
            for t in self.transforms
        ]
        return "\n".join(translation_strings)
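A minimal numpy sketch of the convention the proposed __repr__ relies on, assuming skimage's SimilarityTransform layout: params is a 3x3 homogeneous matrix whose translation column params[0:2, 2] is (tx, ty), acting on (x, y) = (column, row) coordinates. phase_cross_correlation instead returns its shift in (row, column) = (y, x) order, so the shift must be reversed before it becomes a translation. The helper names translation_matrix, shift_yx_to_translation_xy and describe are hypothetical, not starfish functions, and the sketch ignores any sign change starfish may apply when it builds the transform.

```python
import numpy as np


def translation_matrix(tx: float, ty: float) -> np.ndarray:
    """3x3 homogeneous matrix for a pure translation acting on (x, y) points,
    laid out like skimage's SimilarityTransform.params (translation in column 2)."""
    return np.array([[1.0, 0.0, tx],
                     [0.0, 1.0, ty],
                     [0.0, 0.0, 1.0]])


def shift_yx_to_translation_xy(shift_yx) -> tuple:
    """phase_cross_correlation returns (row, col) = (y, x); reverse it to (x, y)."""
    y, x = shift_yx
    return (x, y)


def describe(params: np.ndarray) -> str:
    # params[0:2, 2] is (tx, ty): index 0 is x and index 1 is y.
    tx, ty = params[0:2, 2]
    return f"translation: x={tx}, y={ty}"


shift = np.array([-4.34, -14.9])  # the r: 3 shift from the first run, in (y, x) order
params = translation_matrix(*shift_yx_to_translation_xy(shift))
print(describe(params))  # translation: x=-14.9, y=-4.34

# Mapping the origin (x=0, y=0) shows which component moves along which axis.
print((params @ np.array([0.0, 0.0, 1.0]))[:2].tolist())  # [-14.9, -4.34]
```

With this layout, translation[0] is the x component, which is why the proposed __repr__ labels it as x.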

Thanks.