theislab / scarches

Reference mapping for single-cell genomics
https://docs.scarches.org/en/latest/
BSD 3-Clause "New" or "Revised" License

Allow pickling from GPU to CPU #223

Closed by maarten-devries 6 months ago

maarten-devries commented 6 months ago

Use case: I want to train a reference model on a GPU, then map a query onto the reference using only a CPU. I know the map_location parameter exists, but it doesn't fully resolve the issue.

Issue: Below is the stack trace for one error I encounter when trying to do the above.

Solution: I basically just grabbed a solution posted here and applied it to the scarches codebase.

    557     return scPoli.load(dir_path=dir_path, adata=adata_ref)
    558 else:
--> 559     return scPoli.load(dir_path=dir_path, adata=adata_ref, map_location=torch.device("cpu"))

File /srv/conda/envs/saturn/lib/python3.9/site-packages/scarches/models/base/_base.py:170, in BaseMixin.load(cls, dir_path, adata, map_location)
    167 elif not os.path.exists(adata_path) and load_adata:
    168     raise ValueError("Save path contains no saved anndata and no adata was passed.")
--> 170 attr_dict, model_state_dict, var_names = cls._load_params(dir_path, map_location)
    172 # Overwrite adata with new genes
    173 adata = _validate_var_names(adata, var_names)

File /srv/conda/envs/saturn/lib/python3.9/site-packages/scarches/models/base/_base.py:131, in BaseMixin._load_params(cls, dir_path, map_location)
    128 varnames_path = os.path.join(dir_path, "var_names.csv")
    130 with open(setup_dict_path, "rb") as handle:
--> 131     attr_dict = pickle.load(handle)
    133 model_state_dict = torch.load(model_path, map_location=map_location)
    135 var_names = np.genfromtxt(varnames_path, delimiter=",", dtype=str)

File /srv/conda/envs/saturn/lib/python3.9/site-packages/torch/storage.py:240, in _load_from_bytes(b)
    239 def _load_from_bytes(b):
--> 240     return torch.load(io.BytesIO(b))

File /srv/conda/envs/saturn/lib/python3.9/site-packages/torch/serialization.py:795, in load(f, map_location, pickle_module, weights_only, **pickle_load_args)
    793     except RuntimeError as e:
    794         raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
--> 795 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)

File /srv/conda/envs/saturn/lib/python3.9/site-packages/torch/serialization.py:1012, in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
   1010 unpickler = UnpicklerWrapper(f, **pickle_load_args)
   1011 unpickler.persistent_load = persistent_load
-> 1012 result = unpickler.load()
   1014 deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)
   1016 offset = f.tell() if f_should_read_directly else None

File /srv/conda/envs/saturn/lib/python3.9/site-packages/torch/serialization.py:958, in _legacy_load.<locals>.persistent_load(saved_id)
    954     obj._torch_load_uninitialized = True
    955     # TODO: Once we decide to break serialization FC, we can
    956     # stop wrapping with TypedStorage
    957     deserialized_objects[root_key] = torch.storage.TypedStorage(
--> 958         wrap_storage=restore_location(obj, location),
    959         dtype=dtype)
    961 typed_storage = deserialized_objects[root_key]
    962 if view_metadata is not None:

File /srv/conda/envs/saturn/lib/python3.9/site-packages/torch/serialization.py:215, in default_restore_location(storage, location)
    213 def default_restore_location(storage, location):
    214     for _, _, fn in _package_registry:
--> 215         result = fn(storage, location)
    216         if result is not None:
    217             return result

File /srv/conda/envs/saturn/lib/python3.9/site-packages/torch/serialization.py:182, in _cuda_deserialize(obj, location)
    180 def _cuda_deserialize(obj, location):
    181     if location.startswith('cuda'):
--> 182         device = validate_cuda_device(location)
    183         if getattr(obj, "_torch_load_uninitialized", False):
    184             with torch.cuda.device(device):

File /srv/conda/envs/saturn/lib/python3.9/site-packages/torch/serialization.py:166, in validate_cuda_device(location)
    163 device = torch.cuda._utils._get_device_index(location, True)
    165 if not torch.cuda.is_available():
--> 166     raise RuntimeError('Attempting to deserialize object on a CUDA '
    167                        'device but torch.cuda.is_available() is False. '
    168                        'If you are running on a CPU-only machine, '
    169                        'please use torch.load with map_location=torch.device(\'cpu\') '
    170                        'to map your storages to the CPU.')
    171 device_count = torch.cuda.device_count()
    172 if device >= device_count:

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
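
The traceback shows the failure coming from pickle.load on the attribute dictionary, which ends up in torch.storage._load_from_bytes and calls torch.load without a map_location. One widely shared workaround for this pattern is a custom unpickler that intercepts that function and forces the storages onto the CPU. The sketch below illustrates the idea; the names CPUUnpickler and load_pickle_on_cpu are hypothetical and are not taken from the scarches codebase or from the PR.

```python
import io
import pickle


class CPUUnpickler(pickle.Unpickler):
    """Unpickler that restores pickled torch storages on the CPU (illustrative)."""

    def find_class(self, module, name):
        # Tensors pickled with plain `pickle` are rebuilt through
        # torch.storage._load_from_bytes, as seen in the traceback above.
        if module == "torch.storage" and name == "_load_from_bytes":
            def load_from_bytes_on_cpu(b):
                # Imported lazily, only when a torch storage is encountered.
                import torch
                return torch.load(io.BytesIO(b), map_location="cpu")
            return load_from_bytes_on_cpu
        # Everything else is resolved the usual way.
        return super().find_class(module, name)


def load_pickle_on_cpu(handle):
    """Hypothetical stand-in for the `pickle.load(handle)` call in the traceback."""
    return CPUUnpickler(handle).load()
```

Under this assumption, the pickle.load(handle) call in _load_params would be swapped for load_pickle_on_cpu(handle) when loading on a CPU-only machine. Depending on the PyTorch version, torch.load may need further arguments (for example weights_only), so treat this as a starting point rather than a drop-in fix.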

Koncopd commented 6 months ago

Thank you. I will check the PR. However, it might be better to find where it actually pickles something PyTorch-related. I will try to figure it out soon.
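
One way to track down which saved attribute carries a PyTorch object is to walk the attribute dictionary before it is pickled and report every value whose type is defined in a torch module. This is only a rough sketch: find_torch_objects is a hypothetical helper, and it assumes the attributes are nested dicts, lists, and tuples.

```python
def find_torch_objects(obj, path="attr_dict"):
    """Yield (path, type name) for every value whose type lives in a torch module."""
    module = type(obj).__module__ or ""
    if module == "torch" or module.startswith("torch."):
        yield path, type(obj).__name__
    elif isinstance(obj, dict):
        for key, value in obj.items():
            yield from find_torch_objects(value, f"{path}[{key!r}]")
    elif isinstance(obj, (list, tuple)):
        for i, value in enumerate(obj):
            yield from find_torch_objects(value, f"{path}[{i}]")
```

Running list(find_torch_objects(attr_dict)) on the dictionary about to be saved would list the offending entries, which could then be moved to the CPU or converted to plain Python or NumPy values before pickling.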