Open viraj-rapolu opened 2 months ago
Hey Viraj -- sorry you ran into this bug! We haven't seen this before to my knowledge.

A couple thoughts:

scNym doesn't actually support a `device` kwarg in the config. It currently assumes that you want to use a CUDA device if you have one available and moves everything over there. The `device` kwarg won't throw an error, but it won't be used.

The `config` kwarg on the API call only accepts strings for pre-registered configs (see `scnym.api.CONFIGS.keys()`) or a full dictionary of configuration options that matches the format of those pre-registered options. Apologies, we never fully built out the config system for custom research, so everything outside the pre-registered configs is currently unsupported.

    import numpy as np
    import scnym
    import scanpy as sc
    import torch

    adata = sc.datasets.pbmc3k()
    # create some random class labels, the model should still hit high performance by overfitting
    adata.obs["annotations"] = np.random.randint(0, 3, size=adata.shape[0])
    # NOTE: scNym requires data to be log1p(CountsPerMillion), so we set `target_sum = 1e6` rather than the commonly used `1e4`
    sc.pp.normalize_total(adata, target_sum=int(1e6))
    sc.pp.log1p(adata)
    # cut the highly variable genes call -- scNym selects genes internally
    # without subsetting the AnnData like
    # `adata = adata[:, adata.var["highly_variable"]]`, this doesn't have the intended effect anyway
    # sc.pp.highly_variable_genes(adata, n_top_genes=3000)
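
The two accepted forms of `config` described in this comment (a pre-registered name, or a full dictionary in the same format) can be sketched as a small resolver. This is an illustration only and not scNym's code: `resolve_config` and `REGISTERED` are hypothetical, and the keys inside `REGISTERED` are placeholders apart from `model_kwargs`, which appears in the traceback. The real registry is `scnym.api.CONFIGS`.

```python
from typing import Mapping, Union

# Hypothetical stand-in for scnym.api.CONFIGS; the real keys and values may differ.
REGISTERED = {
    "no_new_identity": {"n_epochs": 100, "model_kwargs": {"n_hidden": 256}},
}


def resolve_config(config: Union[str, Mapping], registered: Mapping = REGISTERED) -> dict:
    """Return a full config dict from a registered name or a complete dictionary."""
    if isinstance(config, str):
        if config not in registered:
            raise KeyError(f"unknown config {config!r}; choose from {sorted(registered)}")
        return dict(registered[config])
    # A dictionary must provide every top-level key of a registered config.
    template = next(iter(registered.values()))
    missing = sorted(set(template) - set(config))
    if missing:
        raise ValueError(f"config dict is missing keys: {missing}")
    return dict(config)
```

Under this sketch, the string `'no_new_identity'` resolves to a full dictionary, while a partial dictionary such as `{'device': 'cuda', 'other_config': 'no_new_identity'}` is rejected because it lacks the expected top-level keys.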
Hi Jacob,
Thanks for looking into this so quickly! I modified the preprocessing steps and used the parameters you suggested for the scnym_api function, but I'm still encountering the same error. I'm having this issue even if I don't explicitly set the torch.device().
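
Since scNym is described above as picking a CUDA device whenever one is available, one way to check whether the error is specific to the GPU path is to hide the GPUs from the process. The sketch below is a diagnostic under that assumption, not a fix, and it is untested against scNym itself. `CUDA_VISIBLE_DEVICES` is the standard CUDA environment variable; it should be set before `torch` is imported so that CUDA has not been initialized yet.

```python
import os

# Hide all GPUs from this process so that CUDA-dependent libraries see none.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Import torch (and scnym) only after the variable is set, for example:
# import torch
# print(torch.cuda.is_available())  # should report that no CUDA device is available
```

If training completes on the CPU with the GPUs hidden, that points at device placement in the GPU code path rather than at the data or the config.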
I am encountering a RuntimeError related to mismatched tensor devices during training when using scnym_api with GPU. The error occurs when indexing tensors during the training process, and it seems like the tensors are not being moved to the same device (CPU vs GPU) internally.
I have ensured that my data and configuration are set correctly, and the issue persists even when explicitly specifying the device in the config. This appears to be related to device management within the scnym library itself.
An example of my code that produces the issue:

    import scnym
    import scanpy as sc
    import torch

    adata = sc.read_h5ad('/path/to/dataset.h5ad')

    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, n_top_genes=3000)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    scnym.scnym_api(
        adata=adata,
        task='train',
        groupby='annotations',
        domain_groupby='domain',
        out_path='/path/to/scnym_outs',
        config={'device': str(device), 'other_config': 'no_new_identity'},
    )

    Traceback (most recent call last):
      File "", line 102, in <module>
        config={'device': str(device), 'other_config': 'no_new_identity'})
      File "/path/.local/lib/python3.7/site-packages/scnym/", line 341, in scnym_api
        config=config,
      File "/path/.local/lib/python3.7/site-packages/scnym/", line 538, in scnym_train
        **config['model_kwargs'],
      File "/path/.local/lib/python3.7/site-packages/scnym/", line 519, in fit_model
        T.train()
      File "/path/.local/lib/python3.7/site-packages/scnym/", line 452, in train
        self.train_epoch()
      File "/path/.local/lib/python3.7/site-packages/scnym/", line 585, in train_epoch
        labeled_sample=data,
      File "/path/.local/lib/python3.7/site-packages/scnym/", line 1491, in __call__
        conf_unlabeled_sample[k] = unlabeled_sample[k][pseudolabel_confidence]
    RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
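
The last frame of the traceback indexes one tensor with another, and the trailing `(cpu)` in the message suggests that the indexed tensor is on the CPU while the index tensor is on a different device. The sketch below shows that general PyTorch pattern and the usual remedy of moving the index to the indexed tensor's device. The names `values` and `mask` are made up for illustration; this is not a patch to scNym.

```python
import torch

values = torch.arange(6.0)   # indexed tensor, left on the CPU
mask = values > 2.5          # boolean index derived from it

# On a machine with a GPU, moving only the index reproduces the mismatch:
#   values[mask.to("cuda")]  raises a RuntimeError like the one above.
# Aligning the index with the indexed tensor's device avoids it.
selected = values[mask.to(values.device)]
print(selected)  # tensor([3., 4., 5.])
```

The same alignment works in the other direction: if the indexed tensor lives on the GPU, `mask.to(values.device)` moves the index there instead.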
Environment information:

| Name | Version | Build | Channel |
|---|---|---|---|
| _libgcc_mutex | 0.1 | conda_forge | conda-forge |
| _openmp_mutex | 4.5 | 1_gnu | conda-forge |
| absl-py | 0.14.0 | pypi_0 | pypi |
| anndata | 0.7.6 | pypi_0 | pypi |
| ca-certificates | 2021.5.30 | ha878542_0 | conda-forge |
| cachetools | 4.2.4 | pypi_0 | pypi |
| certifi | 2021.5.30 | pypi_0 | pypi |
| charset-normalizer | 2.0.6 | pypi_0 | pypi |
| configargparse | 1.5.2 | pypi_0 | pypi |
| cycler | 0.10.0 | pypi_0 | pypi |
| decorator | 5.1.0 | pypi_0 | pypi |
| filelock | 3.15.4 | pypi_0 | pypi |
| fsspec | 2024.6.1 | pypi_0 | pypi |
| google-auth | 1.35.0 | pypi_0 | pypi |
| google-auth-oauthlib | 0.4.6 | pypi_0 | pypi |
| grpcio | 1.41.0 | pypi_0 | pypi |
| h5py | 3.4.0 | pypi_0 | pypi |
| idna | 3.2 | pypi_0 | pypi |
| jinja2 | 3.1.4 | pypi_0 | pypi |
| joblib | 1.0.1 | pypi_0 | pypi |
| kiwisolver | 1.3.2 | pypi_0 | pypi |
| ld_impl_linux-64 | 2.36.1 | hea4e1c9_2 | conda-forge |
| leidenalg | 0.8.7 | pypi_0 | pypi |
| libffi | 3.4.2 | h9c3ff4c_4 | conda-forge |
| libgcc-ng | 11.2.0 | h1d223b6_9 | conda-forge |
| libgomp | 11.2.0 | h1d223b6_9 | conda-forge |
| libstdcxx-ng | 11.2.0 | he4da1e4_9 | conda-forge |
| libzlib | 1.2.11 | h36c2ea0_1012 | conda-forge |
| llvmlite | 0.37.0 | pypi_0 | pypi |
| markdown | 3.3.4 | pypi_0 | pypi |
| markupsafe | 2.1.5 | pypi_0 | pypi |
| matplotlib | 3.4.3 | pypi_0 | pypi |
| mock | 4.0.3 | pypi_0 | pypi |
| mpmath | 1.3.0 | pypi_0 | pypi |
| natsort | 7.1.1 | pypi_0 | pypi |
| ncurses | 6.2 | h58526e2_4 | conda-forge |
| networkx | 2.6.3 | pypi_0 | pypi |
| numba | 0.54.0 | pypi_0 | pypi |
| numexpr | 2.7.3 | pypi_0 | pypi |
| numpy | 1.20.3 | pypi_0 | pypi |
| nvidia-cublas-cu12 | | pypi_0 | pypi |
| nvidia-cuda-cupti-cu12 | 12.1.105 | pypi_0 | pypi |
| nvidia-cuda-nvrtc-cu12 | 12.1.105 | pypi_0 | pypi |
| nvidia-cuda-runtime-cu12 | 12.1.105 | pypi_0 | pypi |
| nvidia-cudnn-cu12 | | pypi_0 | pypi |
| nvidia-cufft-cu12 | | pypi_0 | pypi |
| nvidia-curand-cu12 | | pypi_0 | pypi |
| nvidia-cusolver-cu12 | | pypi_0 | pypi |
| nvidia-cusparse-cu12 | | pypi_0 | pypi |
| nvidia-nccl-cu12 | 2.20.5 | pypi_0 | pypi |
| nvidia-nvjitlink-cu12 | 12.6.20 | pypi_0 | pypi |
| nvidia-nvtx-cu12 | 12.1.105 | pypi_0 | pypi |
| oauthlib | 3.1.1 | pypi_0 | pypi |
| openssl | 3.0.0 | h7f98852_1 | conda-forge |
| packaging | 21.0 | pypi_0 | pypi |
| pandas | 1.3.3 | pypi_0 | pypi |
| patsy | 0.5.2 | pypi_0 | pypi |
| pillow | 10.4.0 | pypi_0 | pypi |
| pip | 21.2.4 | pyhd8ed1ab_0 | conda-forge |
| protobuf | 3.18.0 | pypi_0 | pypi |
| pyasn1 | 0.4.8 | pypi_0 | pypi |
| pyasn1-modules | 0.2.8 | pypi_0 | pypi |
| pynndescent | 0.5.4 | pypi_0 | pypi |
| pyparsing | 2.4.7 | pypi_0 | pypi |
| python | 3.9.7 | hf930737_3_cpython | conda-forge |
| python-dateutil | 2.8.2 | pypi_0 | pypi |
| python-igraph | 0.9.6 | pypi_0 | pypi |
| python_abi | 3.9 | 2_cp39 | conda-forge |
| pytz | 2021.1 | pypi_0 | pypi |
| pyyaml | 5.4.1 | pypi_0 | pypi |
| readline | 8.1 | h46c0cb4_0 | conda-forge |
| requests | 2.26.0 | pypi_0 | pypi |
| requests-oauthlib | 1.3.0 | pypi_0 | pypi |
| rsa | 4.7.2 | pypi_0 | pypi |
| scanpy | 1.8.1 | pypi_0 | pypi |
| scikit-learn | 1.0 | pypi_0 | pypi |
| scipy | 1.7.1 | pypi_0 | pypi |
| scnym | 0.1.11 | pypi_0 | pypi |
| seaborn | 0.11.2 | pypi_0 | pypi |
| setuptools | 58.0.4 | py39hf3d152e_2 | conda-forge |
| sinfo | 0.3.4 | pypi_0 | pypi |
| six | 1.16.0 | pypi_0 | pypi |
| sqlite | 3.36.0 | h9cd32fc_2 | conda-forge |
| statsmodels | 0.13.0rc0 | pypi_0 | pypi |
| stdlib-list | 0.8.0 | pypi_0 | pypi |
| sympy | 1.13.1 | pypi_0 | pypi |
| tables | 3.6.1 | pypi_0 | pypi |
| tensorboard | 2.6.0 | pypi_0 | pypi |
| tensorboard-data-server | 0.6.1 | pypi_0 | pypi |
| tensorboard-plugin-wit | 1.8.0 | pypi_0 | pypi |
| texttable | 1.6.4 | pypi_0 | pypi |
| threadpoolctl | 3.0.0 | pypi_0 | pypi |
| tk | 8.6.11 | h27826a3_1 | conda-forge |
| torch | 2.4.0 | pypi_0 | pypi |
| torchvision | 0.19.0 | pypi_0 | pypi |
| tqdm | 4.62.3 | pypi_0 | pypi |
| triton | 3.0.0 | pypi_0 | pypi |
| typing-extensions | 4.12.2 | pypi_0 | pypi |
| tzdata | 2021a | he74cb21_1 | conda-forge |
| umap-learn | 0.5.1 | pypi_0 | pypi |
| urllib3 | 1.26.7 | pypi_0 | pypi |
| werkzeug | 2.0.1 | pypi_0 | pypi |
| wheel | 0.37.0 | pyhd8ed1ab_1 | conda-forge |
| xlrd | 1.2.0 | pypi_0 | pypi |
| xz | 5.2.5 | h516909a_1 | conda-forge |
| zlib | 1.2.11 | h36c2ea0_1012 | conda-forge |
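
The traceback paths point at a `python3.7` site-packages directory, while the environment listing shows Python 3.9.7, so it may be worth confirming which interpreter and package versions the failing script actually uses. A minimal sketch using only the standard library follows; `report_versions` is a hypothetical helper, and `importlib.metadata` requires Python 3.8 or newer.

```python
import sys
from importlib import metadata


def report_versions(packages=("scnym", "torch", "scanpy", "anndata")):
    """Map each package name to its installed version, or None if it is absent."""
    found = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


print(sys.executable)        # interpreter actually running the script
print(sys.version_info[:3])  # its Python version
print(report_versions())     # versions visible to that interpreter
```

Running this at the top of the failing script, rather than in a separate shell, shows the versions that the script itself sees.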