BayraktarLab / cell2location

Comprehensive mapping of tissue cell architecture via integrated single cell and spatial transcriptomics (cell2location model)
https://cell2location.readthedocs.io/en/latest/
Apache License 2.0
324 stars 58 forks

Performance reduction when training the Cell2location model with CPU in the latest version #152

Closed. tomasgomes closed this issue 1 year ago.

tomasgomes commented 2 years ago

Hello! We (@AMaynard10 and I) have been successfully running cell2location with snRNA-seq and Visium data using only CPUs. The slowest steps in this setup are training the regression and Cell2location models, each taking ~1-3 h with our data, using cell2location version 0.6a0 and scvi-tools version 0.14.3.

Recently we tried re-running our code (very similar to the tutorial) with updated cell2location and scvi-tools (versions 0.8a0 and 0.16.1, respectively). On the same datasets, regression training took approximately the same time, but training the Cell2location model became up to 10 times slower (taking up to 40 h), on the same system, with both versions run at similar times (sometimes even simultaneously).
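
To put the reported timings on a per-epoch basis, here is a small worked calculation. It assumes both versions ran the max_epochs=15000 shown in the training code below and that the quoted hours cover the Cell2location training alone; `seconds_per_epoch` is a hypothetical helper, not part of cell2location.

```python
def seconds_per_epoch(total_hours: float, max_epochs: int) -> float:
    """Average wall-clock seconds per epoch for a run of max_epochs epochs."""
    if max_epochs <= 0:
        raise ValueError("max_epochs must be positive")
    return total_hours * 3600.0 / max_epochs


MAX_EPOCHS = 15000  # max_epochs passed to mod2.train in the code below

# Hours taken from the report above: ~1-3 h before the update, up to 40 h after.
for label, hours in [("old, fastest", 1), ("old, slowest", 3), ("new, slowest", 40)]:
    print(f"{label:13s} {seconds_per_epoch(hours, MAX_EPOCHS):.2f} s/epoch")
```

Under those assumptions, the older setup spent roughly 0.24 to 0.72 s per epoch, while the slowest new run spent about 9.6 s per epoch.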

The code below is what we used to train the Cell2location model in both versions; it is taken almost directly from the tutorial.

import cell2location
import matplotlib.pyplot as plt

# `vis` (the Visium AnnData) and `inf_aver` (reference cell-state signatures)
# are created in earlier steps of the tutorial.

# create and train the model
mod2 = cell2location.models.Cell2location(
    vis, cell_state_df=inf_aver,
    # the expected average cell abundance: tissue-dependent
    # hyper-prior which can be estimated from paired histology:
    N_cells_per_location=25,
    # hyperparameter controlling normalisation of
    # within-experiment variation in RNA detection (using default here):
    detection_alpha=200
)

mod2.train(max_epochs=15000,
           # train using full data (batch_size=None)
           batch_size=None,
           # use all data points in training because
           # we need to estimate cell abundance at all locations
           train_size=1,
           use_gpu=False)

# plot ELBO loss history during training, removing the first 1000 epochs from the plot
mod2.plot_history(1000)
plt.legend(labels=['full data training']);
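
When comparing versions, it can help to time only the training call, so that data loading and plotting do not blur the comparison. A minimal sketch using only the standard library; `wall_timer` is a hypothetical helper, and the `time.sleep` call stands in for `mod2.train(...)`.

```python
import time
from contextlib import contextmanager


@contextmanager
def wall_timer(label: str, results: dict):
    """Store the wall-clock duration of the enclosed block in results[label]."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Recorded even if the block is interrupted (e.g. KeyboardInterrupt).
        results[label] = time.perf_counter() - start


timings = {}
with wall_timer("cell2location_train", timings):
    # Placeholder workload; in a real session put the mod2.train(...) call here.
    time.sleep(0.01)

print(f"cell2location_train: {timings['cell2location_train'] / 3600:.6f} h")
```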

Are there any options that allow us to tune/improve CPU usage?

Thanks in advance, and thanks for the very useful tool!
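
The thread does not settle this question. One general CPU knob worth checking, offered here as an assumption rather than a confirmed cause of the slowdown, is how many threads the numerical libraries use: too many threads on a shared machine can slow training down, and too few can leave cores idle. The sketch sets the standard OpenMP/MKL/OpenBLAS environment variables and, if PyTorch is installed, calls `torch.set_num_threads`. `configure_cpu_threads` is hypothetical, and the value 8 is only an example.

```python
import os


def configure_cpu_threads(n_threads: int) -> dict:
    """Set common BLAS/OpenMP thread-count variables (and PyTorch's, if present)."""
    if n_threads < 1:
        raise ValueError("n_threads must be >= 1")
    names = ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS")
    for name in names:
        # Usually read when the libraries are first loaded, so set these early.
        os.environ[name] = str(n_threads)
    try:
        import torch  # optional: also resize PyTorch's intra-op thread pool
        torch.set_num_threads(n_threads)
    except ImportError:
        pass
    return {name: os.environ[name] for name in names}


print(configure_cpu_threads(8))
```

Because the environment variables are typically read when the libraries first load, the call belongs at the top of the notebook, before cell2location is imported. Comparing a few thread counts on the same dataset is the only reliable way to pick one.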

adamgayoso commented 2 years ago

This is interesting. @vitkl, did you add anything new?

@tomasgomes, can you try installing the latest version of scvi-tools (https://github.com/scverse/scvi-tools) from the master branch?

It should work with cell2location just fine; I'm wondering whether the issue is related to PyTorch Lightning.

vitkl commented 2 years ago

Hi @tomasgomes

Thanks for reporting here!

Indeed, I have not introduced any changes to the model or inference. The difference probably comes from recent changes to Pyro, scvi-tools, or PyTorch Lightning. Could you try what @adamgayoso suggested (and also reinstall the latest cell2location, v0.1)?
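
Since the suspected causes are dependency changes, recording the exact installed versions in each environment makes runs comparable. A minimal sketch using `importlib.metadata` from the standard library (Python 3.8+); the list of PyPI distribution names is an assumption about which packages matter here.

```python
from importlib import metadata

# PyPI distribution names (not import names) of the packages discussed above.
PACKAGES = ("cell2location", "scvi-tools", "pyro-ppl", "pytorch-lightning", "torch")


def package_versions(names=PACKAGES) -> dict:
    """Map each distribution name to its installed version, or 'not installed'."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


for name, version in package_versions().items():
    print(f"{name:20s} {version}")
```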

vitkl commented 2 years ago

Please try installing the just-released cell2location version (v0.1) as described in the installation instructions (https://github.com/BayraktarLab/cell2location#installation). Let's see whether the speed issue persists; meanwhile, I will look into Pyro as a potential cause.

medinaserpas commented 2 years ago

It may be worth adding to the discussion that I was in a similar position to @tomasgomes. Until last week I ran c2l v0.6a0; I have since updated to v0.1, running only on CPU. I found that if I train a new regression model and use it for cell2location modelling, I end up with reduced performance, but not to the extent you've experienced (maybe 2x longer). However, if I load an already trained regression model and use that for cell2location, I end up with similar performance to the pre-release version. Oddly enough, since updating to v0.1, I have not been able to load my cell2location models that were generated with the pre-release version. You may have similar luck if you're able to use a previously trained regression model to apply cell2location to your query data.

adamgayoso commented 2 years ago

However, if I load an already trained regression model and use that for cell2location I end with similar performance as the pre-release version.

@vitkl, this suggests to me that the cause is something other than scvi-tools.

Oddly enough, since updating to v0.1, I have not been able to load my cell2location models that were generated with the prerelease version.

@medinaserpas, what error are you facing? Nothing here should have been affected. cc @jjhong922

medinaserpas commented 2 years ago

@adamgayoso This is the error. It's informative enough; I just haven't had the time to fully figure it out. The 'missing' file is in the directory, and I haven't had time to work through the legacy conversion (`convert_legacy_save`) suggested at the end. So instead of fixing it, I reran the workflow, and everything works fine apart from the performance loss during model training. Worth noting: I work on an Apple M1 (Apple silicon) device, which has caused me issues with other packages, so it may be contributing here, although I doubt it in this instance.

Input:

##Reload the model you created with the following
adata_file = "./cell2location_map/sp.h5ad"
adata_vis = sc.read_h5ad(adata_file)
mod = cell2location.models.Cell2location.load(f"{run_name}", adata_vis)

Output:


FileNotFoundError                         Traceback (most recent call last)
File ~/opt/miniconda3/envs/spatialtools/lib/python3.9/site-packages/scvi/model/base/_utils.py:66, in _load_saved_files(dir_path, load_adata, prefix, map_location, backup_url)
     65     _download(backup_url, dir_path, model_file_name)
---> 66     model = torch.load(model_path, map_location=map_location)
     67 except FileNotFoundError as exc:

File ~/opt/miniconda3/envs/spatialtools/lib/python3.9/site-packages/torch/serialization.py:699, in load(f, map_location, pickle_module, **pickle_load_args)
    697     pickle_load_args['encoding'] = 'utf-8'
--> 699 with _open_file_like(f, 'rb') as opened_file:
    700     if _is_zipfile(opened_file):
    701         # The zipfile reader is going to advance the current file position.
    702         # If we want to actually tail call to torch.jit.load, we need to
    703         # reset back to the original position.

File ~/opt/miniconda3/envs/spatialtools/lib/python3.9/site-packages/torch/serialization.py:231, in _open_file_like(name_or_buffer, mode)
    230 if _is_path(name_or_buffer):
--> 231     return _open_file(name_or_buffer, mode)
    232 else:

File ~/opt/miniconda3/envs/spatialtools/lib/python3.9/site-packages/torch/serialization.py:212, in _open_file.__init__(self, name, mode)
    211 def __init__(self, name, mode):
--> 212     super(_open_file, self).__init__(open(name, mode))

FileNotFoundError: [Errno 2] No such file or directory: '/Users/mmedina/Documents/Work_Lab/Data/results/MM009_nPOD6560_pLN/cell2location_map/model.pt'

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Input In [3], in <cell line: 4>()
      2 adata_file = "./cell2location_map/sp.h5ad"
      3 adata_vis = sc.read_h5ad(adata_file)
----> 4 mod = cell2location.models.Cell2location.load(f"{run_name}", adata_vis)

File ~/opt/miniconda3/envs/spatialtools/lib/python3.9/site-packages/scvi/model/base/_base_model.py:586, in BaseModelClass.load(cls, dir_path, adata, use_gpu, prefix, backup_url)
    583 load_adata = adata is None
    584 use_gpu, device = parse_use_gpu_arg(use_gpu)
--> 586 (attr_dict, var_names, model_state_dict, new_adata,) = _load_saved_files(
    587     dir_path,
    588     load_adata,
    589     map_location=device,
    590     prefix=prefix,
    591     backup_url=backup_url,
    592 )
    593 adata = new_adata if new_adata is not None else adata
    594 _validate_var_names(adata, var_names)

File ~/opt/miniconda3/envs/spatialtools/lib/python3.9/site-packages/scvi/model/base/_utils.py:68, in _load_saved_files(dir_path, load_adata, prefix, map_location, backup_url)
     66     model = torch.load(model_path, map_location=map_location)
     67 except FileNotFoundError as exc:
---> 68     raise ValueError(
     69         f"Failed to load model file at {model_path}. "
     70         "If attempting to load a saved model from <v0.15.0, please use the util function "
     71         "`convert_legacy_save` to convert to an updated format."
     72     ) from exc
     74 model_state_dict = model["model_state_dict"]
     75 var_names = model["var_names"]

ValueError: Failed to load model file at /Users/mmedina/Documents/Work_Lab/Data/results/MM009_nPOD6560_pLN//cell2location_map/model.pt. If attempting to load a saved model from <v0.15.0, please use the util function `convert_legacy_save` to convert to an updated format.
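
The error message distinguishes the current save layout (a single `model.pt`, which the traceback looks for) from the pre-0.15 scvi-tools layout. A quick check of which files a saved-model directory contains can tell the two apart. The legacy file names below are an assumption about the older layout, not something shown in this thread, and `detect_save_format` is a hypothetical helper.

```python
import tempfile
from pathlib import Path

# File names are assumptions about scvi-tools' save layouts; check your directory.
CURRENT_FILES = ("model.pt",)  # the file the traceback looks for
LEGACY_FILES = ("model_params.pt", "attr.pkl", "var_names.csv")  # assumed pre-0.15 layout


def detect_save_format(dir_path) -> str:
    """Classify a saved-model directory as 'current', 'legacy' or 'unknown'."""
    path = Path(dir_path)
    present = {p.name for p in path.iterdir()} if path.is_dir() else set()
    if all(name in present for name in CURRENT_FILES):
        return "current"
    if all(name in present for name in LEGACY_FILES):
        return "legacy"
    return "unknown"


# Usage on a throwaway directory that mimics the assumed legacy layout.
with tempfile.TemporaryDirectory() as tmp:
    for name in LEGACY_FILES:
        (Path(tmp) / name).touch()
    print(detect_save_format(tmp))  # legacy
```

Calling it on the directory passed to `Cell2location.load` (here `run_name`) would show which case applies. A "legacy" result points to the `convert_legacy_save` conversion named in the error message; whether that conversion handles cell2location models was not confirmed in this thread.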

vitkl commented 2 years ago

I found that if I train a new regression model and use that for cell2location modelling I end up with reduced performance

Do you mean loading the regression model output, then creating a new instance of cell2location and training this new instance? The saved regression model class instance should be completely independent of the newly created cell2location class instance. Maybe loading the model affects which package versions are loaded?
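
One way to test the idea that loading the model changes which packages are picked up is to print where each library is imported from, once in a session that loads the regression model and once in one that does not. A minimal standard-library sketch; `module_origins` is a hypothetical helper, and the module names are the usual import names of the libraries discussed in this thread.

```python
import importlib.util
import sys


def module_origins(names) -> dict:
    """Report the file each top-level module was (or would be) imported from."""
    origins = {}
    for name in names:
        module = sys.modules.get(name)
        if module is not None:
            # Already imported: report the file actually in use.
            origins[name] = getattr(module, "__file__", None) or "built-in"
            continue
        # Not imported yet: find_spec locates a top-level module without importing it.
        spec = importlib.util.find_spec(name)
        origins[name] = "not found" if spec is None else (spec.origin or "namespace package")
    return origins


for name, origin in module_origins(["torch", "pytorch_lightning", "pyro", "scvi", "cell2location"]).items():
    print(f"{name:18s} {origin}")
```

Different paths between the two sessions (for example, a user-level site-packages directory shadowing the conda environment) would support the hypothesis; identical paths would rule it out.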

In principle, you don't actually need to load the regression model: you can just read the H5AD object with results (or simply save the reference expression signatures as a data frame). If you neither load nor train the regression model, do you still see the same reduced performance from the cell2location model?
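
A sketch of the decoupling suggested here: write the reference signatures to a file once, then start a fresh session that only reads that file and trains Cell2location, without loading or training the regression model. With pandas this is `inf_aver.to_csv(path)` and, later, `pd.read_csv(path, index_col=0)`, assuming `inf_aver` holds genes as rows and cell types as columns as in the tutorial. The standard-library version below shows the same round trip; the function and file names are hypothetical.

```python
import csv
import tempfile
from pathlib import Path


def write_signatures(path, genes, cell_types, rows):
    """Write a genes x cell-types matrix to CSV, one gene per row."""
    with open(path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(["gene", *cell_types])
        for gene, row in zip(genes, rows):
            writer.writerow([gene, *row])


def read_signatures(path):
    """Read the CSV back as (genes, cell_types, rows of floats)."""
    with open(path, newline="") as handle:
        reader = csv.reader(handle)
        header = next(reader)
        genes, rows = [], []
        for record in reader:
            genes.append(record[0])
            rows.append([float(value) for value in record[1:]])
    return genes, header[1:], rows


# Toy values for illustration only, not real expression signatures.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "reference_signatures.csv"
    write_signatures(path, ["GeneA", "GeneB"], ["TypeX", "TypeY"], [[0.5, 2.0], [1.25, 0.0]])
    print(read_signatures(path))
```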

vitkl commented 2 years ago

Hi @tomasgomes

Did you solve this?

vitkl commented 1 year ago

I am assuming that this is resolved.