JEFworks-Lab / STalign

Python tool for alignment of spatial transcriptomics (ST) data using diffeomorphic metric mapping
https://jef.works/STalign/
GNU General Public License v3.0

How to choose parameters? #19

Open · grst opened this issue 1 year ago

grst commented 1 year ago

In the tutorial, different values for the sigmas are chosen for each example, but no rationale is given for these choices, except for sigmaM, where the API documentation says:

As a common example (rule of thumb), you could choose this parameter to be the variance of the pixels in your target image.
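If I understand correctly, that would translate to something like this (with J being my normalized target image):

import numpy as np

# rule of thumb from the API docs: pixel variance of the target image
sigmaM = np.var(J.ravel())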

Do you have any recommendations on how to find appropriate parameters for my data?


More concretely, I am trying to align an IHC image with an ISH image. For testing, I'm focusing on this region of the image, converted to grayscale: [image]

With default parameters, the model doesn't converge, so I suspect that I need to tune the parameters, but I'm unsure where to start.

Here are the diagnostic plots generated:

[diagnostic plots]

kpclifton commented 1 year ago

Hi Gregor,

Thank you for your question about choosing appropriate parameters for your data, and thank you for including your images and diagnostic plots.

Based on the Weights plot, I believe that for your case the default sigmaM value is too high to effectively utilize the Gaussian Mixture Model component of STalign, which labels pixels as background, matching tissue, or artifact based on pixel intensity. If the model is tuned correctly, then in the Weights plot the background will be blue, the tissue to be matched red, and artifacts green.
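To illustrate the idea with a toy sketch (this is not STalign's actual implementation, just the general Gaussian-mixture weighting; all values here are made up for illustration):

import numpy as np
from scipy.stats import norm

# toy example: posterior weights of three Gaussian components --
# background (B), matching tissue (M), artifact (A) -- per pixel intensity
pixels = np.array([0.02, 0.55, 0.95])   # example intensities
means = np.array([0.0, 0.6, 1.0])       # hypothetical component means
stds = np.array([0.05, 0.3, 0.05])      # hypothetical component sigmas

lik = norm.pdf(pixels[None, :], means[:, None], stds[:, None])
weights = lik / lik.sum(axis=0)         # posterior weight of each component
labels = np.array(["B", "M", "A"])[weights.argmax(axis=0)]
print(labels)                           # dominant component per pixel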

For the sigmaM, sigmaA, and sigmaB values, I would suggest tuning them to be on the scale of the standard deviation of the pixel intensities of the target image.

To do this, I usually plot a histogram of the pixel intensity values, from which I can estimate the sigmas as roughly half the width of the peaks that correspond to background, matching tissue, and artifacts. In the example below, the background is the peak around 0 and the matching tissue is the peak at 0.6; if there were artifacts, they would likely be around 1. So I would begin by estimating sigmaB = 0.05, sigmaM = 0.3, and sigmaA = 0.05. From my experience, these values do not need to be precise, but they should at least be the correct order of magnitude for your intensity values.

Intensity Histogram
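In code, this estimation might look roughly like the sketch below (assuming J is the normalized target image as a NumPy array; the 0.2 and 0.9 cutoffs are read off the example histogram and are only illustrative):

import numpy as np

# split pixels by the histogram peaks they belong to, then use the spread
# of each group as a starting sigma; precise values are not critical
pix = J.ravel()
background = pix[pix < 0.2]                  # peak around 0
matching = pix[(pix >= 0.2) & (pix < 0.9)]   # peak around 0.6
artifact = pix[pix >= 0.9]                   # peak around 1, if present

sigmaB = np.std(background)
sigmaM = np.std(matching)
sigmaA = np.std(artifact) if artifact.size else 0.05  # fall back to a guess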

Additionally, based on your smoothing kernel plot, I would suggest next tuning the spatial smoothness constant a. This value has units of length, so it should be set relative to the size of the images of interest. Your images seem smaller than the examples in the tutorials, so I would reduce a until you get a smoothing kernel plot with color variation like the example below. As you tune a, keep in mind that small values may overfit noise and large values may lead to low accuracy.

smoothing kernel
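As a rough starting point (my own heuristic, not a rule from the package), you could scale a to the coordinate extent of your target image, assuming YJ and XJ are its coordinate arrays:

# a has units of length, so derive it from the target image's extent
# rather than relying on the default; the 1/10 fraction is just a guess
extent_y = YJ.max() - YJ.min()   # height of the target image in its units
extent_x = XJ.max() - XJ.min()   # width of the target image in its units
a = min(extent_y, extent_x) / 10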

Hopefully this gives you some direction to begin exploring the parameter space.

Best, Kalen

grst commented 1 year ago

Hi Kalen,

thanks for your detailed response! With this information, I could make it work somewhat, but the result is still not satisfying. Is there anything else you'd suggest optimizing? I have also been wondering how to choose the learning rate epV: the default value is 2000, but in the tutorial you selected 10.


I have now switched from the b/w image to RGB to get closer to my real use case: [image]

I chose the first as the "source" because it's the more "complete" image (the right image is missing a section in the bottom left).

import matplotlib.pyplot as plt

# histogram of the target image's pixel intensities
fig, ax = plt.subplots()
ax.hist(J.ravel(), bins=50)

[intensity histogram]

Clipping away the peak on the right, I very roughly determined the standard deviation of the "signal peak" as follows:

>>> np.std(J.ravel()[J.ravel() < 0.8])
0.1475

Using these parameters

import torch

out = STalign.LDDMM(
    [YI, XI],
    I,
    [YJ, XJ],
    J,
    device="cuda:0",
    niter=5000,
    sigmaM=0.15,  # standard deviation of the signal peak (also tried 0.3)
    sigmaB=0.01,  # standard deviation of the background peak
    sigmaA=0.01,  # standard deviation of the artifact peak
    epV=10,
    a=50,
    muB=torch.tensor([215, 215, 215]) / 255.0,  # "grey" background determined with an eyedropper
    muA=torch.tensor([0, 0, 0]),  # use black as the artifact color (?)
)

I arrive at the following result: [resulting alignment images]
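As an aside, instead of the eyedropper, I suppose muB could also be estimated programmatically from a border strip of the image; here is a rough sketch (assuming J is an (H, W, 3) RGB array scaled to [0, 1] and that the border is mostly background):

import numpy as np
import torch

# median color of a 10-pixel border strip as the background estimate;
# adjust the axis order if your image layout differs
border = np.concatenate([
    J[:10].reshape(-1, 3),
    J[-10:].reshape(-1, 3),
    J[:, :10].reshape(-1, 3),
    J[:, -10:].reshape(-1, 3),
])
muB = torch.tensor(np.median(border, axis=0))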