gao-lab / Cell_BLAST

A BLAST-like toolkit for large-scale scRNA-seq data querying and annotation.
http://cblast.gao-lab.org
MIT License

Possible wasserstein_distance solution #25

Closed: yujcccc closed this issue 1 month ago

yujcccc commented 1 year ago

Hi Cell_BLAST team, I'm not a Python expert, and I ran into a problem with wasserstein_distance:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function wasserstein_distance at 0x7f4491cae700>) found for signature:

wasserstein_distance(array(float64, 1d, C), array(float64, 1d, C))

There are 2 candidate implementations:

During: resolving callee type: Function(<function wasserstein_distance at 0x7f4491cae700>)
During: typing of call at /home/yjc/anaconda3/envs/cellblast/lib/python3.9/site-packages/Cell_BLAST/blast.py (229)

File "../../../anaconda3/envs/cellblast/lib/python3.9/site-packages/Cell_BLAST/blast.py", line 229:
def npd_v1(

 return 0.5 * (
     scipy.stats.wasserstein_distance(
     ^

I asked ChatGPT for help, and it responded:

Given the code you've provided, I can see that you're trying to use the scipy.stats.wasserstein_distance function inside a function that Numba compiles with nopython=True. As I explained earlier, Numba doesn't support all Python libraries or functions in nopython mode, and scipy.stats.wasserstein_distance is one of the unsupported functions.

To resolve this, you will need to write your own implementation of wasserstein_distance that Numba can compile. The Wasserstein distance, also known as the earth mover's distance, can be computed from the cumulative distribution functions of the two distributions.
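For reference, the one-dimensional Wasserstein-1 distance is the area between the two cumulative distribution functions $U$ and $V$ (this is the definition documented for scipy.stats.wasserstein_distance). For two equally weighted samples of the same size $n$, it reduces to matching the sorted values, where $u_{(i)}$ denotes the $i$-th smallest value of $u$:

```latex
W_1(u, v) = \int_{-\infty}^{+\infty} \lvert U(t) - V(t) \rvert \, dt,
\qquad
W_1(u, v) = \frac{1}{n} \sum_{i=1}^{n} \lvert u_{(i)} - v_{(i)} \rvert
\quad \text{when } \lvert u \rvert = \lvert v \rvert = n.
```

For example, with $u = (0, 1)$ and $v = (1, 2)$ the sorted differences are $1$ and $1$, so $W_1 = 1$.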

Here is a basic implementation of the wasserstein_distance function that should work with Numba:

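import numba
import numpy as np
@numba.njit
def wasserstein_distance(u_values, v_values):
    # W1 for two equally weighted samples of the same size: match sorted values.
    if u_values.size != v_values.size:
        raise ValueError("u_values and v_values must have the same length")
    u_values_sorted = np.sort(u_values)
    v_values_sorted = np.sort(v_values)
    return np.mean(np.abs(u_values_sorted - v_values_sorted))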
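If the two posterior samples can differ in length, a version that works directly on the empirical CDFs is needed. The sketch below is one Numba-compatible way to do it, assuming unweighted samples; it follows the CDF-difference definition documented for scipy.stats.wasserstein_distance, but it is not the code that was merged into Cell_BLAST. The name wasserstein_distance_1d and the plain-Python fallback for njit are illustrative assumptions.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fallback so the sketch runs without Numba installed
    def njit(func):
        return func


@njit
def wasserstein_distance_1d(u_values, v_values):
    """1-D Wasserstein-1 distance between two unweighted samples (sizes may differ)."""
    u_sorted = np.sort(u_values)
    v_sorted = np.sort(v_values)
    # Pool and sort all observations; both empirical CDFs are constant
    # between consecutive pooled values.
    all_values = np.sort(np.concatenate((u_sorted, v_sorted)))
    deltas = np.diff(all_values)
    # Empirical CDF of each sample, evaluated at the left end of every interval.
    u_cdf = np.searchsorted(u_sorted, all_values[:-1], side="right") / u_sorted.size
    v_cdf = np.searchsorted(v_sorted, all_values[:-1], side="right") / v_sorted.size
    # W1 = integral of |U(t) - V(t)| dt, summed interval by interval.
    return np.sum(np.abs(u_cdf - v_cdf) * deltas)


print(wasserstein_distance_1d(np.array([0.0, 1.0, 3.0]), np.array([5.0, 6.0, 8.0])))
print(wasserstein_distance_1d(np.array([0.0]), np.array([1.0, 2.0])))
```

The first call compares a sample with a copy of itself shifted by 5, so it should give 5.0 up to floating-point rounding; the second compares samples of different sizes and gives 1.5.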

Then, replace the scipy.stats.wasserstein_distance calls in your npd_v1 function with this new wasserstein_distance function:

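import numba
import numpy as np

@numba.jit(nopython=True, nogil=True, cache=True)
def npd_v1(
    x: np.ndarray,
    y: np.ndarray,
    x_posterior: np.ndarray,
    y_posterior: np.ndarray,
    eps: float = 0.0,
) -> np.ndarray:
    ...  # (computation of xy_posterior1 and xy_posterior2 omitted)
    # wasserstein_distance must be the @numba.njit function defined above;
    # Numba cannot call scipy.stats.wasserstein_distance in nopython mode.
    return 0.5 * (
        wasserstein_distance(
            xy_posterior1[: len(x_posterior)], xy_posterior1[-len(y_posterior) :]
        )
        + wasserstein_distance(
            xy_posterior2[: len(x_posterior)], xy_posterior2[-len(y_posterior) :]
        )
    )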

Fortunately, I managed to get this code to run, but I'd like to check with you, as the experts, whether it is correct. I'd be happy if it is correct and helps you.
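One way to check a hand-written replacement without SciPy, sketched here and not part of Cell_BLAST: for two equally weighted samples of the same size, an optimal transport plan can always be taken to be a one-to-one matching, so on small inputs the true distance is the cheapest matching found by trying every permutation. A candidate formula should agree with that brute force on random inputs. The function names w1_sorted and w1_brute_force are hypothetical.

```python
import itertools

import numpy as np


def w1_sorted(u_values, v_values):
    """W1 for equal-size, equally weighted samples: mean |difference| of sorted values."""
    return float(np.mean(np.abs(np.sort(u_values) - np.sort(v_values))))


def w1_brute_force(u_values, v_values):
    """Reference W1: cheapest one-to-one matching, found by trying every permutation."""
    u_arr = np.asarray(u_values, dtype=float)
    v_arr = np.asarray(v_values, dtype=float)
    if u_arr.size != v_arr.size:
        raise ValueError("brute-force check needs samples of the same size")
    return min(
        float(np.mean(np.abs(u_arr - v_arr[list(perm)])))
        for perm in itertools.permutations(range(v_arr.size))
    )


# Small random samples: 5! = 120 permutations keeps the brute force cheap.
rng = np.random.default_rng(0)
u = rng.normal(size=5)
v = rng.normal(loc=1.0, size=5)
print(w1_brute_force(u, v), w1_sorted(u, v))
```

If the two printed numbers agree up to rounding, the candidate matches the definition on that input; a formula that disagrees on even one random draw is not computing the Wasserstein distance.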

shlin0415 commented 7 months ago

Thank you for your issue. I also encountered the same problem and solved it through the method you provided.

Jeff1995 commented 1 month ago

@yujcccc Thank you for the recommendation! A similar approach has been accepted in the latest 0.5.1 release.