BayraktarLab / cell2location

Comprehensive mapping of tissue cell architecture via integrated single cell and spatial transcriptomics (cell2location model)
https://cell2location.readthedocs.io/en/latest/
Apache License 2.0
321 stars, 58 forks

Incompatible Tensor broadcasts with MPS accelerator - Mac M2 #221

Open NicolasSompairac opened 1 year ago

NicolasSompairac commented 1 year ago

Dear devs,

I have been trying to run cell2location's tutorial for mapping the human lymph node. I decided to run it on my Mac M2, but when it reaches the model training step, the kernel always crashes before training even starts (no tqdm progress bar appears). The error shown in the Jupyter terminal is the following:

loc("mps_multiply"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/a0876c02-1788-11ed-b9c4-96898e02b808/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":219:0)): error: input types 'tensor<34x10237xf32>' and 'tensor<1xi64>' are not broadcast compatible
LLVM ERROR: Failed to infer result type(s).

Is this a problem with the tensor initialisation from cell2location or a more general issue with MPS?

This definitely seems to come from GPU handling, as training runs fine if use_gpu is set to False.
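Since training works on the CPU with `use_gpu=False`, one option is to choose that flag from the platform instead of hard-coding it. A minimal sketch, assuming the only goal is to keep Apple Silicon Macs off the MPS backend and that other machines running the tutorial have a CUDA GPU. `is_apple_silicon` is a hypothetical helper, not part of cell2location or scvi-tools, and the `mod.train` call in the comment is an assumed mirror of the tutorial's training step:

```python
import platform


def is_apple_silicon(system=None, machine=None):
    """Hypothetical helper: True on an Apple Silicon Mac (Darwin on arm64).

    The arguments default to the current platform; passing them explicitly
    makes the check easy to test.
    """
    system = platform.system() if system is None else system
    machine = platform.machine() if machine is None else machine
    return system == "Darwin" and machine == "arm64"


# Stay on the CPU on M1/M2 Macs, where (per this thread) use_gpu=True ends up on MPS.
use_gpu = not is_apple_silicon()

# Assumed tutorial call (mod is the Cell2location model built earlier):
# mod.train(max_epochs=30000, batch_size=None, train_size=1, use_gpu=use_gpu)
```

This only automates the workaround already reported above; it does not make MPS itself work.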

adamgayoso commented 1 year ago

Not all tensor operations have been implemented for MPS, so I wouldn't expect this to work for a while. Even if it did work, I would not expect any sort of massive speed-up compared to a discrete NVIDIA GPU (though I would like to be wrong :))

NicolasSompairac commented 1 year ago

@adamgayoso Well, the thing is, I am not looking for a "faster" way of running it using MPS; I'm looking for a way to run it at all, since NVIDIA GPUs are out of the question on M2 Macs... As of now, it seems the tool can only run on the CPU on a Mac, and you can guess how much slower that is.

And in this case, the error doesn't seem to come from a missing implementation: I have seen such errors before, and they are stated quite clearly when they happen. Here it looks more like the data types (f32 vs i64) of the tensors being broadcast are set in a way that is incompatible with the defaults. So if there is a way to specify them somewhere, I would love some tips on how to do that :)
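The shapes in the error support this reading. Under the standard broadcasting rule used by NumPy and PyTorch (align trailing dimensions; each pair must be equal or contain a 1), a `(34, 10237)` tensor and a `(1,)` tensor are shape-compatible, so the rejection most plausibly comes from the element types (`f32` vs `i64`) rather than the shapes. A minimal pure-Python sketch of the shape rule only; it deliberately ignores dtypes, and `broadcast_shape` is an illustrative name, not a library function:

```python
def broadcast_shape(a, b):
    """Return the broadcast shape of two shape tuples, or raise ValueError.

    Dimensions are compared from the right; the shorter shape is padded
    with 1s. A pair is compatible if the sizes match or one of them is 1.
    """
    result = []
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and 1 not in (da, db):
            raise ValueError(f"shapes {a} and {b} are not broadcast compatible")
        # A size-1 dimension stretches to match the other one.
        result.append(db if da == 1 else da)
    return tuple(reversed(result))


# The shapes from the MPS error broadcast without any problem:
print(broadcast_shape((34, 10237), (1,)))
```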

joe-jhou2 commented 1 year ago

I have been bothered by this problem for a while.

vitkl commented 1 year ago

'tensor<34x10237xf32>' and 'tensor<1xi64>'

This could suggest that the input variable type is incorrectly recognised somehow. I would make sure that priors are provided as floats and see what happens:

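import cell2location

# adata_vis and inf_aver are prepared earlier in the lymph node tutorial
mod = cell2location.models.Cell2location(
    adata_vis, cell_state_df=inf_aver,
    N_cells_per_location=30.0,
    detection_alpha=20.0,
)

or

mod = cell2location.models.Cell2location(
    adata_vis, cell_state_df=inf_aver,
    N_cells_per_location=float(30),
    detection_alpha=float(20),
)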

Which data is this? If it is the lymph node tutorial data, I struggle to understand where tensors with these shapes would occur. Still, I would try passing all priors as floats.
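Following vitkl's suggestion, one way to make sure no prior is passed as an integer is to coerce them all in one place. A minimal sketch; `as_float_priors` is a hypothetical helper, the commented model call assumes `adata_vis` and `inf_aver` from the tutorial, and the thread does not confirm that float priors resolve the MPS error:

```python
def as_float_priors(**priors):
    """Hypothetical helper: cast every prior to a Python float so that no
    integer value is handed to the model."""
    out = {}
    for name, value in priors.items():
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(value, bool):
            raise TypeError(f"prior {name!r} must be numeric, got bool")
        out[name] = float(value)
    return out


priors = as_float_priors(N_cells_per_location=30, detection_alpha=20)

# Assumed usage with the tutorial objects:
# mod = cell2location.models.Cell2location(
#     adata_vis, cell_state_df=inf_aver, **priors
# )
```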

adamgayoso commented 1 year ago

> Well, the thing is, I am not looking for a "faster" way of running it using MPS; I'm looking for a way to run it at all, since NVIDIA GPUs are out of the question on M2 Macs... As of now, it seems the tool can only run on the CPU on a Mac, and you can guess how much slower that is.
>
> And in this case, the error doesn't seem to come from a missing implementation: I have seen such errors before, and they are stated quite clearly when they happen.

However, here you are only seeing one error; many important operations are still not implemented on MPS. For example:

File ~/Software/scvi-tools/scvi/distributions/_negative_binomial.py:103, in log_nb_positive(x, mu, theta, eps, log_fn, lgamma_fn)
     98 lgamma = lgamma_fn
     99 log_theta_mu_eps = log(theta + mu + eps)
    100 res = (
    101     theta * (log(theta + eps) - log_theta_mu_eps)
    102     + x * (log(mu + eps) - log_theta_mu_eps)
--> 103     + lgamma(x + theta)
    104     - lgamma(theta)
    105     - lgamma(x + 1)
    106 )
    108 return res

NotImplementedError: The operator 'aten::lgamma.out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
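The missing `lgamma` op cannot simply be skipped: it appears three times in the negative binomial log-likelihood shown in the traceback. A scalar pure-Python sketch of the same expression using `math.lgamma`, with `mu` as the mean and `theta` as the inverse dispersion; it is an illustration of the formula, not the scvi-tools implementation, which works on tensors:

```python
import math


def log_nb_positive(x, mu, theta, eps=1e-8):
    """Scalar negative binomial log-likelihood of count x, mirroring the
    expression in the traceback (mean mu, inverse dispersion theta)."""
    log_theta_mu_eps = math.log(theta + mu + eps)
    return (
        theta * (math.log(theta + eps) - log_theta_mu_eps)
        + x * (math.log(mu + eps) - log_theta_mu_eps)
        # These three terms need the lgamma op that MPS did not implement.
        + math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
    )


# Sanity check: the probabilities over all counts should sum to about 1.
total = sum(math.exp(log_nb_positive(x, 5.0, 2.0)) for x in range(1000))
print(round(total, 6))
```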

And this is with the nightly PyTorch build, so I would not expect this to work for many months.
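For anyone who wants to try MPS anyway, the `NotImplementedError` above names a temporary workaround: setting `PYTORCH_ENABLE_MPS_FALLBACK=1` so unsupported ops run on the CPU. A minimal sketch, assuming it runs at the very top of the notebook, before torch or cell2location (which imports torch) is loaded, as is commonly recommended; the thread does not confirm that this gets the full model to train, and the error's own warning says the fallback ops will be slower:

```python
import os
import sys
import warnings

# Variable named in the PyTorch error message; set it before torch is imported.
if "torch" in sys.modules:
    warnings.warn("torch is already imported; set PYTORCH_ENABLE_MPS_FALLBACK before importing it")
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Import afterwards so torch starts with the fallback enabled:
# import cell2location
```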