snap-stanford / UCE

UCE is a zero-shot foundation model for single-cell gene expression data
MIT License

NaN error during embeddings #18

Closed: DanLesman closed this issue 4 months ago

DanLesman commented 5 months ago

I encountered the following error while generating embeddings. I'm embedding ~350k cells.

I'm also curious whether there is a way to save intermediate results, so that if a run on a large dataset crashes, the embedding does not have to be restarted from scratch.

Thank you!

    74%|███████▍ | 10476/14173 [8:42:49<2:59:51, 2.92s/it]
    74%|███████▍ | 10477/14173 [8:42:52<2:59:34, 2.92s/it]
    74%|███████▍ | 10477/14173 [8:42:52<3:04:27, 2.99s/it]
    Traceback (most recent call last):
      File "/UCE/eval_single_anndata.py", line 155, in <module>
        main(args, accelerator)
      File "/UCE/eval_single_anndata.py", line 85, in main
        processor.run_evaluation()
      File "/UCE/evaluate.py", line 145, in run_evaluation
        run_eval(self.adata, self.name, self.pe_idx_path, self.chroms_path,
      File "/UCE/evaluate.py", line 235, in run_eval
        for batch in pbar:
      File "/UCE_env/lib/python3.9/site-packages/tqdm/std.py", line 1182, in __iter__
        for obj in iterable:
      File "/UCE_env/lib/python3.9/site-packages/accelerate/data_loader.py", line 461, in __iter__
        next_batch = next(dataloader_iter)
      File "/UCE_env/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
        data = self._next_data()
      File "/UCE_env/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 677, in _next_data
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
      File "UCE_env/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "/UCE_env/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "/UCE/eval_data.py", line 62, in __getitem__
        sample_cell_sentences(counts, weights, dataset, self.args,
      File "/UCE/eval_data.py", line 126, in sample_cell_sentences
        choice_idx = np.random.choice(np.arange(len(weights)),
      File "numpy/random/mtrand.pyx", line 971, in numpy.random.mtrand.RandomState.choice
    ValueError: probabilities contain NaN

marcelroed commented 5 months ago

Hey Dan! Can you give us the result of pip freeze in your Python environment?

DanLesman commented 5 months ago

Hey Marcel! Here it is:

    accelerate==0.26.1
    accumulation-tree==0.6.2
    anndata==0.10.4
    anyio==4.2.0
    argon2-cffi==23.1.0
    argon2-cffi-bindings==21.2.0
    array-api-compat==1.4
    arrow==1.3.0
    asttokens==2.4.1
    async-lru==2.0.4
    attrs==23.2.0
    Babel==2.14.0
    backoff==2.2.1
    beautifulsoup4==4.12.3
    bleach==6.1.0
    blessed==1.20.0
    Brotli @ file:///tmp/abs_ecyw11_7ze/croots/recipe/brotli-split_1659616059936/work
    certifi @ file:///croot/certifi_1700501669400/work/certifi
    cffi @ file:///croot/cffi_1700254295673/work
    chardet==5.2.0
    charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
    click==8.1.7
    comm==0.2.1
    contextlib2==21.6.0
    contourpy==1.2.0
    cryptography @ file:///croot/cryptography_1702070282333/work
    cycler==0.12.1
    debugpy==1.8.0
    decorator==5.1.1
    defusedxml==0.7.1
    dm-tree==0.1.8
    docrep==0.3.2
    etils==1.5.2
    exceptiongroup==1.2.0
    executing==2.0.1
    fastjsonschema==2.19.1
    filelock @ file:///croot/filelock_1700591183607/work
    fonttools==4.47.2
    fqdn==1.5.1
    fsspec==2023.12.2
    get-annotations==0.1.2
    gmpy2 @ file:///tmp/build/80754af9/gmpy2_1645438755360/work
    h11==0.14.0
    h5py==3.10.0
    huggingface-hub==0.20.3
    idna @ file:///croot/idna_1666125576474/work
    importlib-metadata==7.0.1
    importlib-resources==6.1.1
    ipykernel==6.29.0
    ipython==8.18.1
    isoduration==20.11.0
    jedi==0.19.1
    Jinja2 @ file:///croot/jinja2_1666908132255/work
    joblib==1.3.2
    json5==0.9.14
    jsonpointer==2.4
    jsonschema==4.21.1
    jsonschema-specifications==2023.12.1
    jupyter-events==0.9.0
    jupyter-lsp==2.2.2
    jupyter_client==8.6.0
    jupyter_core==5.7.1
    jupyter_server==2.12.5
    jupyter_server_terminals==0.5.2
    jupyterlab==4.0.11
    jupyterlab_pygments==0.3.0
    jupyterlab_server==2.25.2
    kaggle==1.5.16
    kiwisolver==1.4.5
    llvmlite==0.41.1
    loompy==3.0.7
    MarkupSafe @ file:///croot/markupsafe_1704205993651/work
    matplotlib==3.8.2
    matplotlib-inline==0.1.6
    mdurl==0.1.2
    mistune==3.0.2
    mkl-fft @ file:///croot/mkl_fft_1695058164594/work
    mkl-random @ file:///croot/mkl_random_1695059800811/work
    mkl-service==2.4.0
    mpmath @ file:///croot/mpmath_1690848262763/work
    natsort==8.4.0
    nbclient==0.9.0
    nbconvert==7.14.2
    nbformat==5.9.2
    nest-asyncio==1.6.0
    networkx @ file:///croot/networkx_1690561992265/work
    notebook==7.0.7
    notebook_shim==0.2.3
    numba==0.58.1
    numpy @ file:///croot/numpy_and_numpy_base_1704311704800/work/dist/numpy-1.26.3-cp39-cp39-linux_x86_64.whl#sha256=93e7b9e5e2090dd03810e7c1b02cb077d3ef49fc713f0af531e0667375d9decb
    numpy-groupies==0.10.2
    nvidia-cublas-cu12==12.1.3.1
    nvidia-cuda-cupti-cu12==12.1.105
    nvidia-cuda-nvrtc-cu12==12.1.105
    nvidia-cuda-runtime-cu12==12.1.105
    nvidia-cudnn-cu12==8.9.2.26
    nvidia-cufft-cu12==11.0.2.54
    nvidia-curand-cu12==10.3.2.106
    nvidia-cusolver-cu12==11.4.5.107
    nvidia-cusparse-cu12==12.1.0.106
    nvidia-nccl-cu12==2.18.1
    nvidia-nvjitlink-cu12==12.3.101
    nvidia-nvtx-cu12==12.1.105
    ordered-set==4.1.0
    overrides==7.6.0
    packaging==23.2
    pandas==2.2.0
    pandocfilters==1.5.1
    parso==0.8.3
    patsy==0.5.6
    pexpect==4.9.0
    Pillow @ file:///croot/pillow_1696580024257/work
    platformdirs==4.1.0
    prometheus-client==0.19.0
    prompt-toolkit==3.0.43
    protobuf==4.25.1
    psutil==5.9.8
    ptyprocess==0.7.0
    pure-eval==0.2.2
    pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
    Pygments==2.17.2
    pynndescent==0.5.11
    pyOpenSSL @ file:///croot/pyopenssl_1690223430423/work
    pyparsing==3.1.1
    pyro-api==0.1.2
    PySocks @ file:///tmp/build/80754af9/pysocks_1605305812635/work
    python-dateutil==2.8.2
    python-editor==1.0.4
    python-json-logger==2.0.7
    python-multipart==0.0.6
    python-slugify==8.0.1
    pytz==2023.3.post1
    pyudorandom==1.0.0
    PyYAML==6.0.1
    pyzmq==25.1.2
    readchar==4.0.5
    referencing==0.32.1
    requests @ file:///croot/requests_1690400202158/work
    rfc3339-validator==0.1.4
    rfc3986-validator==0.1.1
    rpds-py==0.17.1
    safetensors==0.4.2
    scanpy==1.9.6
    scikit-learn==1.4.0
    scipy==1.12.0
    seaborn==0.13.1
    Send2Trash==1.8.2
    session-info==1.0.0
    six==1.16.0
    sniffio==1.3.0
    soupsieve==2.5
    stack-data==0.6.3
    statsmodels==0.14.1
    stdlib-list==0.10.0
    sympy @ file:///croot/sympy_1701397643339/work
    tdigest==0.5.2.2
    terminado==0.18.0
    text-unidecode==1.3
    threadpoolctl==3.2.0
    tinycss2==1.2.1
    tomli==2.0.1
    torch==2.0.1
    torchaudio==2.0.2
    torchvision==0.15.2
    tornado==6.4
    tqdm==4.66.1
    traitlets==5.14.0
    triton==2.0.0
    types-python-dateutil==2.8.19.20240106
    typing_extensions==4.8.0
    tzdata==2023.4
    umap-learn==0.5.5
    uri-template==1.3.0
    urllib3 @ file:///croot/urllib3_1698257533958/work
    wcwidth==0.2.13
    webcolors==1.13
    webencodings==0.5.1
    websocket-client==1.7.0
    websockets==12.0
    zipp==3.17.0

Yanay1 commented 5 months ago

Hi Dan,

The "probabilities contain NaN" error means that, for that specific cell, no genes had nonzero expression after filtering to the genes that have protein embeddings.
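To illustrate why a cell with no expressed genes leads to this error, here is a minimal sketch. It assumes the sampling weights are log1p-transformed counts divided by their sum; that normalization and the helper name `sampling_weights` are assumptions for illustration, not code taken from UCE.

```python
import numpy as np

def sampling_weights(counts):
    """Hypothetical helper: turn counts into sampling probabilities.

    Assumed scheme: log1p of the counts, divided by their sum.
    """
    w = np.log1p(np.asarray(counts, dtype=float))
    # Silence the 0/0 warning so the NaN result can be inspected.
    with np.errstate(invalid="ignore", divide="ignore"):
        return w / w.sum()

ok = sampling_weights([0, 3, 1])      # valid probabilities summing to 1
empty = sampling_weights([0, 0, 0])   # 0 / 0 for every gene, so all NaN

# Sampling works for the cell with expression...
sample = np.random.choice(np.arange(3), size=5, p=ok)

# ...but fails for the all-zero cell.
try:
    np.random.choice(np.arange(3), size=5, p=empty)
    raised = False
except ValueError as err:
    raised = True
    message = str(err)
    print(message)
```

Under this assumption, a single all-zero cell anywhere in the dataset is enough to stop the whole run when the data loader reaches it.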

To get around that, try filtering the dataset by the minimum number of genes expressed per cell. A value like 40 would probably work. You can use scanpy.pp.filter_cells: https://scanpy.readthedocs.io/en/stable/generated/scanpy.pp.filter_cells.html

After filtering, you will need to delete the intermediate files that UCE saved for that adata before rerunning.
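As a rough sketch of what such a threshold does, the snippet below counts the genes with nonzero expression in each cell of a dense count matrix and keeps the cells that meet the minimum. The function name `cells_with_min_genes` and the toy matrix are made up for illustration.

```python
import numpy as np

def cells_with_min_genes(X, min_genes=40):
    """Boolean mask of cells (rows) that express at least `min_genes` genes."""
    X = np.asarray(X)
    genes_per_cell = (X > 0).sum(axis=1)  # number of nonzero genes per cell
    return genes_per_cell >= min_genes

# Toy counts: 3 cells x 4 genes; the second cell expresses nothing.
X = np.array([[1, 0, 2, 5],
              [0, 0, 0, 0],
              [0, 3, 0, 0]])

keep = cells_with_min_genes(X, min_genes=1)
X_filtered = X[keep]
print(keep, X_filtered.shape)
```

On an AnnData object, the scanpy function linked above applies this kind of threshold directly: `sc.pp.filter_cells(adata, min_genes=40)`.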

Unfortunately, there isn't currently a way to save intermediate results, but you could split the adata in half, save the halves as two separate files, and then evaluate them separately.
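A possible way to do the split, sketched with plain index arithmetic: the cell indices are divided into contiguous chunks, and each chunk would then be written to its own file. The file names and the `adata` variable in the comment are hypothetical; only the chunking itself runs here.

```python
import numpy as np

def chunk_indices(n_cells, n_chunks):
    """Split cell indices 0..n_cells-1 into contiguous, near-equal chunks."""
    return np.array_split(np.arange(n_cells), n_chunks)

chunks = chunk_indices(350_000, 2)

for i, idx in enumerate(chunks):
    # With an AnnData object `adata` (assumed), each part could be saved as:
    #   adata[idx].copy().write_h5ad(f"part_{i}.h5ad")  # hypothetical file name
    print(i, idx[0], idx[-1], len(idx))
```

Using more than two chunks limits how much work is lost if one run crashes, at the cost of more files to evaluate and concatenate afterwards.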

DanLesman commented 5 months ago

Thank you! I will do the filtering on the subset of genes that UCE considers (which appears to be a rather small subset and may be the source of the problem).
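Filtering on the matched subset could look roughly like the sketch below. It assumes a dense count matrix, a list of the dataset's gene names, and a list of gene names the model has protein embeddings for; `model_genes`, the function name, and the toy data are all hypothetical stand-ins.

```python
import numpy as np

def keep_cells_on_model_genes(X, var_names, model_genes, min_genes=40):
    """Keep cells expressing at least `min_genes` of the matched genes.

    Returns the boolean cell mask and the number of matched genes.
    """
    model_genes = set(model_genes)
    gene_mask = np.array([g in model_genes for g in var_names], dtype=bool)
    # Count nonzero genes per cell, restricted to the matched columns.
    n_expressed = (np.asarray(X)[:, gene_mask] > 0).sum(axis=1)
    return n_expressed >= min_genes, int(gene_mask.sum())

var_names = ["GAPDH", "ACTB", "MYGENE1", "MYGENE2"]   # toy dataset genes
model_genes = ["GAPDH", "ACTB", "TP53"]               # toy model gene list
X = np.array([[4, 1, 0, 0],
              [0, 0, 7, 2],    # expressed, but only in unmatched genes
              [0, 2, 0, 9]])

keep, n_matched = keep_cells_on_model_genes(X, var_names, model_genes, min_genes=1)
print(keep, n_matched)
```

The second toy cell shows why this differs from filtering on all genes: it has expression, but none in the matched genes, so it would pass a plain per-cell filter and still end up with zero usable genes.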

Yanay1 commented 5 months ago

There might be an issue with gene-name matching; for most datasets, more than 10,000 genes should match.
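One quick way to look for a naming mismatch is to compare exact matches against case-insensitive matches and to count names that look like Ensembl IDs. This is only a diagnostic sketch; `match_report` and both gene lists are hypothetical, and the real reference list would come from the model's own gene files.

```python
def match_report(var_names, model_genes):
    """Count how many dataset gene names match a reference list, and how."""
    model = set(model_genes)
    model_upper = {g.upper() for g in model}
    return {
        # Names that match the reference list exactly.
        "exact": sum(g in model for g in var_names),
        # Names that match once case is ignored (e.g. "Gapdh" vs "GAPDH").
        "case_insensitive": sum(g.upper() in model_upper for g in var_names),
        # Names that look like Ensembl IDs rather than gene symbols.
        "ensembl_like": sum(g.upper().startswith("ENS") for g in var_names),
    }

report = match_report(
    ["Gapdh", "ACTB", "ENSG00000141510"],   # toy dataset gene names
    ["GAPDH", "ACTB", "TP53"],              # toy reference symbols
)
print(report)
```

A large gap between the exact and case-insensitive counts, or many Ensembl-like names, would suggest the dataset's gene identifiers need converting before embedding.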

DanLesman commented 5 months ago

Will look into it. Thank you!