BayraktarLab / cell2location

Comprehensive mapping of tissue cell architecture via integrated single cell and spatial transcriptomics (cell2location model)
https://cell2location.readthedocs.io/en/latest/
Apache License 2.0

Model loading fails on Apple Silicon Mac because of `device = 0` in `cpu` mode #332

Open LucaMarconato opened 11 months ago

LucaMarconato commented 11 months ago

Minimal code sample (that we can run without your data, using public data)

Sorry, I don't have a minimal working example right now because I am working in a notebook, so for the moment I am reporting the bug in case someone else also runs into it.

I trained and saved a model, and then tried to load it with:

import anndata as ad
import cell2location
adata_vis = ad.read_h5ad("mouse_brain.h5ad")
mod = cell2location.models.Cell2location.load("cell2location_model", adata_vis)

I am using an Apple Silicon Mac with the option

import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # fall back to CPU for ops MPS does not support

The load call fails inside PyTorch Lightning, which tries to use a `CPUAccelerator` with `devices=0`, whereas `devices` should be an int > 0.

Here is the full traceback.

TypeError                                 Traceback (most recent call last)
Cell In[31], line 1
----> 1 mod = RegressionModel.load("data/reference_signatures", adata_ref)

File ~/miniconda3/envs/day4-YML/lib/python3.9/site-packages/scvi/model/base/_base_model.py:721, in BaseModelClass.load(cls, dir_path, adata, use_gpu, accelerator, device, prefix, backup_url)
    716 getattr(cls, method_name)(
    717     adata, source_registry=registry, **registry[_SETUP_ARGS_KEY]
    718 )
    720 model = _initialize_model(cls, adata, attr_dict)
--> 721 model.module.on_load(model)
    722 model.module.load_state_dict(model_state_dict)
    724 model.to_device(device)

File ~/miniconda3/envs/day4-YML/lib/python3.9/site-packages/scvi/module/base/_base_module.py:394, in PyroBaseModuleClass.on_load(self, model)
    389 """Callback function run in :method:`~scvi.model.base.BaseModelClass.load` prior to loading module state dict.
    390 
    391 For some Pyro modules with AutoGuides, run one training step prior to loading state dict.
    392 """
    393 old_history = model.history_.copy()
--> 394 model.train(max_steps=1, **self.on_load_kwargs)
    395 model.history_ = old_history
    396 pyro.clear_param_store()

File ~/miniconda3/envs/day4-YML/lib/python3.9/site-packages/cell2location/models/reference/_reference_model.py:157, in RegressionModel.train(self, max_epochs, batch_size, train_size, lr, **kwargs)
    154 kwargs["train_size"] = train_size
    155 kwargs["lr"] = lr
--> 157 super().train(**kwargs)

File ~/miniconda3/envs/day4-YML/lib/python3.9/site-packages/scvi/model/base/_pyromixin.py:174, in PyroSviTrainMixin.train(self, max_epochs, use_gpu, accelerator, device, train_size, validation_size, shuffle_set_split, batch_size, early_stopping, lr, training_plan, plan_kwargs, **trainer_kwargs)
    171     trainer_kwargs["callbacks"] = []
    172 trainer_kwargs["callbacks"].append(PyroJitGuideWarmup())
--> 174 runner = self._train_runner_cls(
    175     self,
    176     training_plan=training_plan,
    177     data_splitter=data_splitter,
    178     max_epochs=max_epochs,
    179     use_gpu=use_gpu,
    180     accelerator=accelerator,
    181     devices=device,
    182     **trainer_kwargs,
    183 )
    184 return runner()

File ~/miniconda3/envs/day4-YML/lib/python3.9/site-packages/scvi/train/_trainrunner.py:85, in TrainRunner.__init__(self, model, training_plan, data_splitter, max_epochs, use_gpu, accelerator, devices, **trainer_kwargs)
     83 self.lightning_devices = lightning_devices
     84 self.device = device
---> 85 self.trainer = self._trainer_cls(
     86     max_epochs=max_epochs,
     87     accelerator=accelerator,
     88     devices=lightning_devices,
     89     **trainer_kwargs,
     90 )

File ~/miniconda3/envs/day4-YML/lib/python3.9/site-packages/scvi/train/_trainer.py:139, in Trainer.__init__(self, accelerator, devices, benchmark, check_val_every_n_epoch, max_epochs, default_root_dir, enable_checkpointing, num_sanity_val_steps, enable_model_summary, early_stopping, early_stopping_monitor, early_stopping_min_delta, early_stopping_patience, early_stopping_mode, enable_progress_bar, progress_bar_refresh_rate, simple_progress_bar, logger, log_every_n_steps, **kwargs)
    136 if logger is None:
    137     logger = SimpleLogger()
--> 139 super().__init__(
    140     accelerator=accelerator,
    141     devices=devices,
    142     benchmark=benchmark,
    143     check_val_every_n_epoch=check_val_every_n_epoch,
    144     max_epochs=max_epochs,
    145     default_root_dir=default_root_dir,
    146     enable_checkpointing=enable_checkpointing,
    147     num_sanity_val_steps=num_sanity_val_steps,
    148     enable_model_summary=enable_model_summary,
    149     logger=logger,
    150     log_every_n_steps=log_every_n_steps,
    151     enable_progress_bar=enable_progress_bar,
    152     **kwargs,
    153 )

File ~/miniconda3/envs/day4-YML/lib/python3.9/site-packages/lightning/pytorch/utilities/argparse.py:70, in _defaults_from_env_vars.<locals>.insert_env_defaults(self, *args, **kwargs)
     67 kwargs = dict(list(env_variables.items()) + list(kwargs.items()))
     69 # all args were already moved to kwargs
---> 70 return fn(self, **kwargs)

File ~/miniconda3/envs/day4-YML/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:399, in Trainer.__init__(self, accelerator, strategy, devices, num_nodes, precision, logger, callbacks, fast_dev_run, max_epochs, min_epochs, max_steps, min_steps, max_time, limit_train_batches, limit_val_batches, limit_test_batches, limit_predict_batches, overfit_batches, val_check_interval, check_val_every_n_epoch, num_sanity_val_steps, log_every_n_steps, enable_checkpointing, enable_progress_bar, enable_model_summary, accumulate_grad_batches, gradient_clip_val, gradient_clip_algorithm, deterministic, benchmark, inference_mode, use_distributed_sampler, profiler, detect_anomaly, barebones, plugins, sync_batchnorm, reload_dataloaders_every_n_epochs, default_root_dir)
    396 # init connectors
    397 self._data_connector = _DataConnector(self)
--> 399 self._accelerator_connector = _AcceleratorConnector(
    400     devices=devices,
    401     accelerator=accelerator,
    402     strategy=strategy,
    403     num_nodes=num_nodes,
    404     sync_batchnorm=sync_batchnorm,
    405     benchmark=benchmark,
    406     use_distributed_sampler=use_distributed_sampler,
    407     deterministic=deterministic,
    408     precision=precision,
    409     plugins=plugins,
    410 )
    411 self._logger_connector = _LoggerConnector(self)
    412 self._callback_connector = _CallbackConnector(self)

File ~/miniconda3/envs/day4-YML/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/accelerator_connector.py:157, in _AcceleratorConnector.__init__(self, devices, num_nodes, accelerator, strategy, plugins, precision, sync_batchnorm, benchmark, use_distributed_sampler, deterministic)
    154     self._accelerator_flag = self._choose_gpu_accelerator_backend()
    156 self._check_device_config_and_set_final_flags(devices=devices, num_nodes=num_nodes)
--> 157 self._set_parallel_devices_and_init_accelerator()
    159 # 3. Instantiate ClusterEnvironment
    160 self.cluster_environment: ClusterEnvironment = self._choose_and_init_cluster_environment()

File ~/miniconda3/envs/day4-YML/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/accelerator_connector.py:390, in _AcceleratorConnector._set_parallel_devices_and_init_accelerator(self)
    382     raise MisconfigurationException(
    383         f"`{accelerator_cls.__qualname__}` can not run on your system"
    384         " since the accelerator is not available. The following accelerator(s)"
    385         " is available and can be passed into `accelerator` argument of"
    386         f" `Trainer`: {available_accelerator}."
    387     )
    389 self._set_devices_flag_if_auto_passed()
--> 390 self._devices_flag = accelerator_cls.parse_devices(self._devices_flag)
    391 if not self._parallel_devices:
    392     self._parallel_devices = accelerator_cls.get_parallel_devices(self._devices_flag)

File ~/miniconda3/envs/day4-YML/lib/python3.9/site-packages/lightning/pytorch/accelerators/cpu.py:48, in CPUAccelerator.parse_devices(devices)
     45 @staticmethod
     46 def parse_devices(devices: Union[int, str, List[int]]) -> int:
     47     """Accelerator device parsing logic."""
---> 48     return _parse_cpu_cores(devices)

File ~/miniconda3/envs/day4-YML/lib/python3.9/site-packages/lightning/fabric/accelerators/cpu.py:85, in _parse_cpu_cores(cpu_cores)
     82     cpu_cores = int(cpu_cores)
     84 if not isinstance(cpu_cores, int) or cpu_cores <= 0:
---> 85     raise TypeError("`devices` selected with `CPUAccelerator` should be an int > 0.")
     87 return cpu_cores

TypeError: `devices` selected with `CPUAccelerator` should be an int > 0.
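
The check that raises this error is printed in the last frame of the traceback (`_parse_cpu_cores` in `lightning/fabric/accelerators/cpu.py`). Below is a standalone sketch that re-creates only the lines visible there instead of importing Lightning's private helper. The function name is hypothetical, and the string-to-int branch is an assumption because its condition is elided from the traceback. It shows that a list such as `[0]` is rejected just like `0`, because the check requires a positive `int`:

```python
from typing import List, Union


def parse_cpu_cores_sketch(cpu_cores: Union[int, str, List[int]]) -> int:
    """Hypothetical re-creation of the validation shown in the traceback.

    Only a positive int passes; lists of device indices such as [0] do not.
    """
    # Assumption: the elided condition before `cpu_cores = int(cpu_cores)`
    # converts digit strings such as "2" to int.
    if isinstance(cpu_cores, str) and cpu_cores.strip().isdigit():
        cpu_cores = int(cpu_cores)

    # This check is copied from the traceback frame.
    if not isinstance(cpu_cores, int) or cpu_cores <= 0:
        raise TypeError("`devices` selected with `CPUAccelerator` should be an int > 0.")
    return cpu_cores


# Try a few candidate `devices` values and report which ones are accepted.
for value in (1, "2", 0, [0]):
    try:
        print(repr(value), "->", parse_cpu_cores_sketch(value))
    except TypeError as err:
        print(repr(value), "-> TypeError:", err)
```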

Sorry that I can't make this bug reproducible: unfortunately I don't have time to build a minimal example, and it occurred as part of a notebook that I don't need. So I am just reporting it to make the developers aware of this bug on Apple Silicon Macs, in case someone else wants to follow up.

LucaMarconato commented 11 months ago

Actually, the error is raised in PyTorch Lightning code, but the bug seems to originate in scvi-tools: at some point, `parse_device_args()` from `scvi/model/_utils.py` returns the tuple

(_accelerator, _devices, device) == ('cpu', [0], torch.device('cpu'))

while I believe it should return

(_accelerator, _devices, device) == ('cpu', 1, torch.device('cpu'))
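
The mismatch described above can be illustrated with a hypothetical helper that turns a list of device indices into the positive int that Lightning's `CPUAccelerator` accepts. This is only a sketch of the idea, assuming the tuple layout shown above with the `torch.device` element left out; it is not the actual scvi-tools code or fix, and the name `normalize_cpu_devices` is made up:

```python
from typing import List, Union

# Shapes a Lightning-style `devices` argument can take.
DevicesArg = Union[int, str, List[int]]


def normalize_cpu_devices(accelerator: str, devices: DevicesArg) -> DevicesArg:
    """Hypothetical helper: make `devices` acceptable to a CPU accelerator.

    On "cpu", a list of device indices such as [0] is replaced by the number
    of entries it contains, because the CPU accelerator expects a positive
    int rather than indices. Other values and other accelerators are
    returned unchanged.
    """
    if accelerator == "cpu" and isinstance(devices, list):
        return len(devices)  # [0] -> 1
    return devices


# The pair reported above, without the torch.device element.
accelerator, devices = "cpu", [0]
print(normalize_cpu_devices(accelerator, devices))  # prints 1
```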

Scoott commented 5 months ago

I get the same error (I'm on an M2 Mac). @LucaMarconato, did you ever find a solution to this?

I trained the model without problems, but when I load it as suggested in the tutorial with:

mod = cell2location.models.Cell2location.load(f'{run_name}', adata_vis)

The same 'TypeError' is produced.

vitkl commented 5 months ago

Does this happen in a clean conda environment with the GitHub version of cell2location?

Scoott commented 5 months ago

@vitkl, thanks for the reply. I created a clean conda environment, installed cell2location there, and the problem was resolved.

Previously I had just pip-installed cell2location, and, as Luca reported,

from scvi.model._utils import parse_device_args
parse_device_args()

returned ('cpu', [0]).
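
Because the problem went away in a fresh conda environment, comparing package versions between the broken and the working environment may help whoever follows up. This is a minimal sketch using only the standard library's `importlib.metadata`. The distribution names in `PACKAGES` are assumptions about what is worth checking for this traceback, and a package that is not installed is reported as `None`:

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Assumed distribution names relevant to the traceback above.
PACKAGES = (
    "cell2location",
    "scvi-tools",
    "lightning",
    "pytorch-lightning",
    "torch",
    "pyro-ppl",
)


def installed_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Run this in each environment and compare the two printouts.
    for name, version in installed_versions().items():
        print(f"{name:20s} {version or 'not installed'}")
```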