JEFworks-Lab / STalign

Python tool for alignment of spatial transcriptomics (ST) data using diffeomorphic metric mapping
https://jef.works/STalign/
GNU General Public License v3.0

GPU bug? #32

Closed (andrewjkwok closed this issue 6 days ago)

andrewjkwok commented 6 days ago

Hi,

Thanks for this great tool.

I have been able to run STalign successfully on my CPU, but I get an error when I try to use my NVIDIA GPU. Specifically, when I run:

transform = STalign.LDDMM_3D_to_slice(
    xI,I,xJ,J, 
    T=T,L=L,
    nt=4,niter=2000,
    device=device,
    sigmaA = sigmaA, #standard deviation of artifact intensities
    sigmaB = sigmaB, #standard deviation of background intensities
    sigmaM = sigmaM, #standard deviation of matching tissue intensities
    muA = muA, #average of artifact intensities
    muB = muB #average of background intensities
)

where device = 'cuda:0'

I receive the following error:


---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File <timed exec>:2

File ~/Projects/MERFISH_test_2023/software/test_run_python3.10/lib/python3.10/site-packages/STalign/STalign.py:1554, in LDDMM_3D_to_slice(xI, I, xJ, J, pointsI, pointsJ, L, T, A, v, xv, a, p, expand, nt, niter, diffeo_start, epL, epT, epV, sigmaM, sigmaB, sigmaA, sigmaR, sigmaP, device, dtype, muA, muB)
   1552 # now the E step, update the weights
   1553 WM = pi[0]* torch.exp( -torch.sum((fAI - J)**2,0)/2.0/sigmaM**2 )/np.sqrt(2.0*np.pi*sigmaM**2)**J.shape[0]
-> 1554 WA = pi[1]* torch.exp( -torch.sum((muA[...,None,None,None] - J)**2,0)/2.0/sigmaA**2 )/np.sqrt(2.0*np.pi*sigmaA**2)**J.shape[0]
   1555 WB = pi[2]* torch.exp( -torch.sum((muB[...,None,None,None] - J)**2,0)/2.0/sigmaB**2 )/np.sqrt(2.0*np.pi*sigmaB**2)**J.shape[0]
   1556 WS = WM+WB+WA

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Given that the code runs fine when the device argument is set to 'cpu', could you help me work out what the issue might be?
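The error message says that at least one tensor in the E-step expression is on cuda:0 and another is on the CPU. A quick way to narrow this down is to check which of the tensors you pass in are not on the target device. The helper below is a minimal, hypothetical sketch (it is not part of STalign); it only inspects `torch.Tensor` arguments and skips plain numbers such as `sigmaA`.

```python
import torch

def tensors_off_device(device, **values):
    """Hypothetical helper: names of tensor arguments not on `device`.

    Pass an explicit index such as 'cuda:0': torch.device('cuda') does
    not compare equal to a tensor's torch.device('cuda:0').
    """
    target = torch.device(device)
    return [
        name
        for name, value in values.items()
        if isinstance(value, torch.Tensor) and value.device != target
    ]

# Example: muA/muB were created on the CPU, but the GPU is the target.
muA = torch.tensor([3, 3, 3], device='cpu')
muB = torch.tensor([0, 0, 0], device='cpu')
print(tensors_off_device('cuda:0', muA=muA, muB=muB, sigmaA=2))
```

Constructing `torch.device('cuda:0')` does not require a GPU, so this check also runs on a CPU-only machine.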

andrewjkwok commented 6 days ago

That was a very silly mistake on my part! The code works fine.

I was following the tutorial here: https://jef.works/STalign/notebooks/merfish-allen3Datlas-alignment.html

The error was simply that, in an earlier step where muA and muB are defined, I still had device set to 'cpu':

sigmaA = 2 #standard deviation of artifact intensities
sigmaB = 2 #standard deviation of background intensities
sigmaM = 2 #standard deviation of matching tissue intensities
muA = torch.tensor([3,3,3],device=device) #average of artifact intensities
muB = torch.tensor([0,0,0],device=device) #average of background intensities
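If muA and muB have already been created while device was 'cpu', moving them to the new device avoids having to rerun the earlier cell. The snippet below is a sketch of that approach. The fallback to 'cpu' when no GPU is available is an assumption added so that the snippet runs anywhere; it is not part of the tutorial.

```python
import torch

# Order that triggers the error: muA/muB are built while the device is
# still the CPU, and the device is switched to the GPU afterwards.
muA = torch.tensor([3, 3, 3], device='cpu')  # average of artifact intensities
muB = torch.tensor([0, 0, 0], device='cpu')  # average of background intensities

# Assumption: use the first GPU if present, otherwise stay on the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Fix: move the existing tensors to the same device that is passed to
# STalign.LDDMM_3D_to_slice. Tensor.to() returns the tensor unchanged
# when it is already on that device, so this is safe in either case.
muA = muA.to(device)
muB = muB.to(device)
```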

It might be worth editing this in the tutorial to avoid similar issues in the future.