scverse / scanpy

Single-cell analysis in Python. Scales to >1M cells.
https://scanpy.readthedocs.io
BSD 3-Clause "New" or "Revised" License

sc.pp.scale changes the adata.X values when run again. #2629

Closed alexandregrimaldi closed 1 year ago

alexandregrimaldi commented 1 year ago

What happened?

Hi! I am not sure if this is a bug. Every time I rescale the pbmc3k_processed matrix by passing it to scanpy's sc.pp.scale function, I get a very slightly different matrix as output, enough to generate a different UMAP on each run. But if I rewrite the scaling with numpy in a simple function called my_scale_function, it outputs exactly the same matrix as the input, generating the same UMAP down the line.

Could someone explain to me what is happening? (Note: The matrix is not sparse)

Minimal code sample

import scanpy as sc
import numpy as np
### Loading and preprocessing data
adata = sc.datasets.pbmc3k_processed()

### Defining scale function
def mean_var(X, axis=0):
    mean = np.mean(X, axis=axis, dtype=np.float64)
    mean_sq = np.multiply(X, X).mean(axis=axis, dtype=np.float64)
    var = mean_sq - mean**2
    # enforce R convention (unbiased estimator) for variance
    var *= X.shape[axis] / (X.shape[axis] - 1)
    return mean, var
def my_scale_function(X, clip=False):
    mean, var = mean_var(X, axis=0)
    X -= mean
    std = np.sqrt(var)
    std[std == 0] = 1
    X /= std
    if clip:
        X = np.clip(X, -10, 10)
    return np.matrix(X)

### Scanpy scale vs my_scale_function
mtx = adata.X
from scipy.sparse import issparse
print("mtx is parse=" + str(issparse(np.matrix(mtx))) + "\n")
print("Rescaled with my_scale_function:")
mtx_rescaled = my_scale_function(mtx)
print((mtx == mtx_rescaled).all())
print("Rescaled with scanpy:")
mtx_rescaled = sc.pp.scale(mtx, zero_center=True, max_value=None, copy=True)
print(str((np.matrix(mtx) == mtx_rescaled).all())  + "\n")
print("\nOriginal matrix:")
print(mtx)
print("\nMatrix rescaled with scanpy:")
print(mtx_rescaled)

Error output

mtx is parse=False

Rescaled with my_scale_function:
True
Rescaled with scanpy:
False

Original matrix:
[[-1.71469614e-01 -2.82757759e-01 -4.95753549e-02 ... -1.02923915e-01
  -2.09179729e-01 -5.31203270e-01]
 [-2.14582354e-01 -3.75530124e-01 -6.44599497e-02 ... -2.92909533e-01
  -3.13310266e-01 -5.96654296e-01]
 [-3.76887709e-01 -2.97174782e-01 -6.94468468e-02 ... -1.70980677e-01
  -1.70931697e-01  1.37899971e+00]
 ...
 [-2.07089618e-01 -2.52101928e-01 -4.90629673e-02 ... -4.98141423e-02
  -1.61111996e-01  2.04149699e+00]
 [-1.90328494e-01 -2.27726802e-01 -4.46720645e-02 ...  1.15651824e-03
  -1.35240912e-01 -4.82111037e-01]
 [-3.33789378e-01 -2.55257130e-01 -6.06345981e-02 ... -8.05590525e-02
  -1.30351290e-01 -4.71337825e-01]]

Matrix rescaled with scanpy:
[[-1.7146961e-01 -2.8275776e-01 -4.9575359e-02 ... -1.0292391e-01
  -2.0917973e-01 -5.3120327e-01]
 [-2.1458235e-01 -3.7553012e-01 -6.4459950e-02 ... -2.9290953e-01
  -3.1331027e-01 -5.9665430e-01]
 [-3.7688771e-01 -2.9717478e-01 -6.9446847e-02 ... -1.7098068e-01
  -1.7093170e-01  1.3789997e+00]
 ...
 [-2.0708962e-01 -2.5210193e-01 -4.9062971e-02 ... -4.9814139e-02
  -1.6111200e-01  2.0414970e+00]
 [-1.9032849e-01 -2.2772680e-01 -4.4672068e-02 ...  1.1565228e-03
  -1.3524091e-01 -4.8211104e-01]
 [-3.3378938e-01 -2.5525713e-01 -6.0634598e-02 ... -8.0559045e-02
  -1.3035129e-01 -4.7133783e-01]]

Versions

```
-----
anndata 0.9.2
scanpy 1.9.3
-----
PIL 9.5.0
anyio NA
arrow 1.2.3
asttokens NA
attr 23.1.0
attrs 23.1.0
babel 2.12.1
backcall 0.2.0
certifi 2023.07.22
cffi 1.15.1
charset_normalizer 3.2.0
cloudpickle 2.2.1
combat NA
comm 0.1.3
cycler 0.10.0
cython_runtime NA
dateutil 2.8.2
debugpy 1.6.7
decorator 5.1.1
decoupler 1.4.0
defusedxml 0.7.1
dill 0.3.5.1
executing 1.2.0
fastjsonschema NA
fqdn NA
gseapy 1.0.5
h5py 3.9.0
idna 3.4
igraph 0.10.6
ipykernel 6.25.0
isoduration NA
jedi 0.19.0
jinja2 3.1.2
joblib 1.3.1
json5 NA
jsonpointer 2.4
jsonschema 4.18.4
jsonschema_specifications NA
jupyter_events 0.7.0
jupyter_server 2.7.0
jupyterlab_server 2.24.0
kiwisolver 1.4.4
leidenalg 0.10.1
liana 0.1.9
llvmlite 0.40.1
markupsafe 2.1.3
matplotlib 3.6.3
matplotlib_inline 0.1.6
mizani 0.9.2
mpl_toolkits NA
mpmath 1.3.0
natsort 8.4.0
nbformat 5.9.2
numba 0.57.1
numpy 1.24.4
overrides NA
packaging 23.1
pandas 1.5.3
parso 0.8.3
patsy 0.5.3
pexpect 4.8.0
pickleshare 0.7.5
pkg_resources NA
platformdirs 3.10.0
plotnine 0.12.2
prometheus_client NA
prompt_toolkit 3.0.39
psutil 5.9.5
ptyprocess 0.7.0
pure_eval 0.2.2
pyarrow 12.0.1
pycparser 2.21
pydeseq2 0.3.5
pydev_ipython NA
pydevconsole NA
pydevd 2.9.5
pydevd_file_utils NA
pydevd_plugins NA
pydevd_tracing NA
pygments 2.15.1
pynndescent 0.5.10
pyparsing 3.1.1
pythonjsonlogger NA
pytz 2023.3
referencing NA
requests 2.31.0
rfc3339_validator 0.1.4
rfc3986_validator 0.1.1
rnaxplorer NA
rpds NA
scipy 1.11.1
seaborn 0.12.2
send2trash NA
session_info 1.0.0
sitecustomize NA
six 1.16.0
sklearn 1.3.0
sniffio 1.3.0
sphinxcontrib NA
stack_data 0.6.2
statsmodels 0.14.0
texttable 1.6.7
threadpoolctl 3.2.0
torch 1.13.1+cu117
tornado 6.3.2
tqdm 4.65.0
traitlets 5.9.0
typing_extensions NA
umap 0.5.3
uri_template NA
urllib3 2.0.4
wcwidth 0.2.6
webcolors 1.13
websocket 1.6.1
yaml 6.0.1
zmq 25.1.0
zoneinfo NA
-----
IPython 8.14.0
jupyter_client 8.3.0
jupyter_core 5.3.1
jupyterlab 4.0.3
notebook 7.0.1
-----
Python 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0]
Linux-5.15.0-1033-gke-x86_64-with-glibc2.35
-----
Session information updated at 2023-08-21 12:14
```

eroell commented 1 year ago

Hi, thanks for your interest in scanpy!

I’ll try to comment on your observations here with your code example:

import scanpy as sc
import numpy as np
### Loading and preprocessing data
adata = sc.datasets.pbmc3k_processed()

### Defining scale function
def mean_var(X, axis=0):
    mean = np.mean(X, axis=axis, dtype=np.float64)
    mean_sq = np.multiply(X, X).mean(axis=axis, dtype=np.float64)
    var = mean_sq - mean**2
    # enforce R convention (unbiased estimator) for variance
    var *= X.shape[axis] / (X.shape[axis] - 1)
    return mean, var

As a first note of caution: your function modifies the original data matrix of the scanpy object in place (X -= mean and X /= std write into adata.X), and that matrix is used again later in the snippet. → We should create a copy of X. Otherwise the code overwrites the input and ends up comparing the already-modified matrix with itself, merely under two names. This is what caused your == comparison to evaluate as True, but it is not what you intended to test.

def my_scale_function(X, clip=False):
    # need to make a copy of X
    Y = X.copy()
    mean, var = mean_var(Y, axis=0)
    Y -= mean
    std = np.sqrt(var)
    #std[std == 0] = 1
    Y /= std
    if clip:
        Y = np.clip(Y, -10, 10)
    return np.matrix(Y)
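
The aliasing effect described above can be reproduced without scanpy. The following is a minimal sketch on synthetic data; scale_in_place and scale_copy are hypothetical helper names used only for illustration, not part of scanpy.

```python
import numpy as np

def scale_in_place(X):
    # Hypothetical helper: mirrors the original pattern and mutates X directly.
    X -= X.mean(axis=0)
    X /= X.std(axis=0, ddof=1)
    return X

def scale_copy(X):
    # Hypothetical helper: works on a copy, leaving the caller's array untouched.
    Y = X.copy()
    Y -= Y.mean(axis=0)
    Y /= Y.std(axis=0, ddof=1)
    return Y

rng = np.random.default_rng(0)
original = rng.normal(loc=5.0, scale=2.0, size=(100, 3))

a = original.copy()
out_in_place = scale_in_place(a)
# `a` was overwritten, so comparing it with the result is trivially True.
in_place_same = bool((a == out_in_place).all())

b = original.copy()
out_copy = scale_copy(b)
# `b` still holds the raw values, so the comparison is meaningful.
copy_untouched = bool((b == original).all())
copy_changed = not np.allclose(b, out_copy)

print(in_place_same, copy_untouched, copy_changed)
```

With the in-place version, the "input equals output" check says nothing about whether scaling changed the data, because the input no longer exists in its original form.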

As a second note of caution, floating point numbers should not be compared with the == operator, because tiny rounding differences make exact equality unreliable.

→ A more common way is to use e.g. np.allclose() for this purpose.
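
To illustrate why == is fragile for floats, here is a small sketch. It uses a round trip through float32 as one common source of tiny differences; whether that is the exact cause in this issue is not established here.

```python
import numpy as np

x64 = np.array([0.1, 0.2, 0.3], dtype=np.float64)
x32 = x64.astype(np.float32)  # rounding to single precision

# Exact comparison fails because the float32 values differ in the last bits.
exact_equal = bool((x64 == x32).all())
# Tolerance-based comparison (default rtol=1e-05, atol=1e-08) succeeds.
close = bool(np.allclose(x64, x32))

# The same effect shows up in plain Python arithmetic.
sum_exact = (0.1 + 0.2 == 0.3)
sum_close = bool(np.isclose(0.1 + 0.2, 0.3))

print(exact_equal, close, sum_exact, sum_close)
```

np.allclose accepts rtol and atol arguments if the default tolerances are too loose or too strict for the data at hand.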

### Scanpy scale vs my_scale_function.

print("Rescaled with my_scale_function:")
mtx_rescaled = my_scale_function(adata.X)

print("Do a numpy check for closeness of floats:")
print(np.allclose(adata.X, mtx_rescaled))
Do a numpy check for closeness of floats:
False

You can see that this test actually fails. This is because not all genes in adata.X are scaled to unit variance to begin with, so your function, now that it works on a copy, really does rescale them.

adata.X.var(0)
array([0.9996213 , 0.97964925, 0.29805112, ..., 0.78701097, 0.9980862 ,
       0.9996219 ], dtype=float32)
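
A side note on reading these variances: adata.X.var(0) uses NumPy's default ddof=0, whereas the mean_var helper in this thread applies the unbiased (n - 1) correction. Assuming the dataset has 2638 cells (an assumption, not stated in this thread), a gene that is already at unit variance under the unbiased estimator would show up as 2637/2638 ≈ 0.99962 here, which is close to the first and last values printed. Only values well below that, such as 0.298, indicate genes that are genuinely not at unit variance. A minimal sketch on synthetic data:

```python
import numpy as np

rng = np.random.default_rng(0)
n_cells = 2638  # assumed cell count, used only for illustration
X = rng.normal(size=(n_cells, 4))

# Scale each column to unit variance under the unbiased estimator (ddof=1).
X_scaled = (X - X.mean(axis=0)) / X.std(axis=0, ddof=1)

var_unbiased = X_scaled.var(axis=0, ddof=1)  # 1 up to rounding
var_default = X_scaled.var(axis=0)           # NumPy default, ddof=0
expected = (n_cells - 1) / n_cells           # ratio between the two estimators

print(var_unbiased, var_default, expected)
```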

Genes that are not at unit variance could arise if, e.g., the gene expression was scaled using cells that were later discarded in quality control. So when calling my_scale_function or sc.pp.scale, we expect the cell-by-gene matrix to change on the first call:

mtx_rescaled_sc = sc.pp.scale(adata.X, copy=True)

print("Do a numpy check for closeness of floats:")
print(np.allclose(adata.X, mtx_rescaled_sc))
Do a numpy check for closeness of floats:
False

But it no longer changes if we call sc.pp.scale a second time:

mtx_rescaled_sc_II = sc.pp.scale(mtx_rescaled_sc, copy=True)

print("Do a numpy check for closeness of floats:")
print(np.allclose(mtx_rescaled_sc, mtx_rescaled_sc_II))
Do a numpy check for closeness of floats:
True

This is the behaviour we would expect. I would therefore also expect the UMAPs generated downstream to be reproducible. Hope this helps!
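
The same "changes once, then stays put" pattern can be checked without scanpy. Below is a minimal sketch on synthetic data; scale is a hypothetical stand-in for a zero-centering, unit-variance scaler and is not scanpy's implementation.

```python
import numpy as np

def scale(X):
    # Hypothetical stand-in scaler: zero mean, unit (unbiased) variance per column.
    Y = np.asarray(X, dtype=np.float64).copy()
    Y -= Y.mean(axis=0)
    std = Y.std(axis=0, ddof=1)
    std[std == 0] = 1  # avoid dividing constant columns by zero
    return Y / std

rng = np.random.default_rng(42)
raw = rng.gamma(shape=2.0, scale=3.0, size=(500, 10))

once = scale(raw)
twice = scale(once)

# The first pass changes unscaled data; the second pass is stable within tolerance.
first_pass_changes = not np.allclose(raw, once)
second_pass_stable = bool(np.allclose(once, twice))
# Bitwise equality is not guaranteed, which is why == is the wrong check here.
bitwise_identical = bool((once == twice).all())

print(first_pass_changes, second_pass_stable, bitwise_identical)
```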

eroell commented 1 year ago

Thank you again for bringing up this issue!

Based on the provided information and the discussion so far, it seems that the question has been addressed.

However, please don't hesitate to reopen this issue or create a new one if you have any more questions or run into any related problems in the future.

Thanks for being a part of our community! :)