scverse / scvi-tools

Deep probabilistic analysis of single-cell and spatial omics data
http://scvi-tools.org/
BSD 3-Clause "New" or "Revised" License

batch and max_kl_weight parameters ignored when mapping to scvi reference #2331

Closed LisaSikkema closed 1 month ago

LisaSikkema commented 11 months ago

Hi! I have been trying to map data to scvi-integrated embeddings using scArches, and have noticed that for the two scVI-based reference models I have used, setting the max_kl_weight parameter differently has no effect on the output. The same holds for changing the batch assignment of the query cells (e.g. using sample rather than dataset as batch covariate). I do not see the same problem with an scANVI-model I have used.

Here's a reproducible example of the first issue:

import scarches as sca
import scanpy as sc
ref_model_dir = "./ref_model"
adata_query_unprepped = sc.read_h5ad("./PeerMassague2020_subset_for_testing.h5ad")
mapped_embeddings = dict()
kl_weights_to_test = [0.1, 1, 2]
# map with three different max_kl_weights:
for max_kl_weight in kl_weights_to_test:
    adata_query = sca.models.SCVI.prepare_query_anndata(
        adata=adata_query_unprepped, reference_model=ref_model_dir, inplace=False
    )
    surgery_model = sca.models.SCVI.load_query_data(
        adata_query,
        ref_model_dir,
        freeze_dropout=True,
    )
    surgery_model.train(plan_kwargs={"weight_decay": 0.0, "max_kl_weight":max_kl_weight})
    mapped_embeddings[max_kl_weight] = sc.AnnData(surgery_model.get_latent_representation(adata_query))
# compare output embeddings:
print((mapped_embeddings[0.1].X == mapped_embeddings[1].X).all())
print((mapped_embeddings[0.1].X == mapped_embeddings[2].X).all())

This outputs the following (most importantly, the two "True" lines at the bottom):

INFO     File ./ref_model/model.pt already downloaded                                                              
INFO     Found 100.0% reference vars in query data.                                                                
INFO     File ./ref_model/model.pt already downloaded                                                              
/home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_env_new/lib/python3.8/site-packages/lightning_fabric/plugins/environments/slurm.py:166: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_ ...
  rank_zero_warn(
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_env_new/lib/python3.8/site-packages/lightning_fabric/plugins/environments/slurm.py:166: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_ ...
  rank_zero_warn(
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_env_new/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1609: PossibleUserWarning: The number of training batches (9) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
Epoch 400/400: 100%|██████████| 400/400 [00:44<00:00,  8.99it/s, loss=2.95e+03, v_num=1]
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=400` reached.
Epoch 400/400: 100%|██████████| 400/400 [00:44<00:00,  8.89it/s, loss=2.95e+03, v_num=1]
INFO     File ./ref_model/model.pt already downloaded                                                              
INFO     Found 100.0% reference vars in query data.                                                                
INFO     File ./ref_model/model.pt already downloaded                                                              
/home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_env_new/lib/python3.8/site-packages/lightning_fabric/plugins/environments/slurm.py:166: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_ ...
  rank_zero_warn(
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_env_new/lib/python3.8/site-packages/lightning_fabric/plugins/environments/slurm.py:166: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_ ...
  rank_zero_warn(
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_env_new/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1609: PossibleUserWarning: The number of training batches (9) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
Epoch 400/400: 100%|██████████| 400/400 [00:44<00:00,  9.08it/s, loss=2.96e+03, v_num=1]
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=400` reached.
Epoch 400/400: 100%|██████████| 400/400 [00:44<00:00,  9.06it/s, loss=2.96e+03, v_num=1]
INFO     File ./ref_model/model.pt already downloaded                                                              
INFO     Found 100.0% reference vars in query data.                                                                
INFO     File ./ref_model/model.pt already downloaded                                                              
/home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_env_new/lib/python3.8/site-packages/lightning_fabric/plugins/environments/slurm.py:166: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_ ...
  rank_zero_warn(
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_env_new/lib/python3.8/site-packages/lightning_fabric/plugins/environments/slurm.py:166: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_ ...
  rank_zero_warn(
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_env_new/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1609: PossibleUserWarning: The number of training batches (9) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
Epoch 400/400: 100%|██████████| 400/400 [00:43<00:00,  9.24it/s, loss=2.99e+03, v_num=1]
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=400` reached.
Epoch 400/400: 100%|██████████| 400/400 [00:43<00:00,  9.26it/s, loss=2.99e+03, v_num=1]
True
True

And here is an example of the second issue:

print("batch key for model:", surgery_model.registry_['setup_args']['batch_key'])

batch_keys = ['sample','dataset']
mapped_embeddings_batch = dict()
# map using two different batch covariates (sample vs dataset)
for batch_covariate in batch_keys:
    adata_query_unprepped.obs['batch'] = adata_query_unprepped.obs[batch_covariate]
    adata_query = sca.models.SCVI.prepare_query_anndata(
        adata=adata_query_unprepped, reference_model=ref_model_dir, inplace=False
    )
    surgery_model = sca.models.SCVI.load_query_data(
        adata_query,
        ref_model_dir,
        freeze_dropout=True,
    )
    surgery_model.train(plan_kwargs={"weight_decay": 0.0})
    mapped_embeddings_batch[batch_covariate] = sc.AnnData(surgery_model.get_latent_representation(adata_query))
# check whether the outcomes differ:
print((mapped_embeddings_batch['sample'].X == mapped_embeddings_batch['dataset'].X).all())

which outputs the following (again, most importantly, the "True" at the bottom):

batch key for model: batch
INFO     File ./ref_model/model.pt already downloaded                                                              
INFO     Found 100.0% reference vars in query data.                                                                
INFO     File ./ref_model/model.pt already downloaded                                                              
/home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_env_new/lib/python3.8/site-packages/lightning_fabric/plugins/environments/slurm.py:166: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_ ...
  rank_zero_warn(
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_env_new/lib/python3.8/site-packages/lightning_fabric/plugins/environments/slurm.py:166: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_ ...
  rank_zero_warn(
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_env_new/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1609: PossibleUserWarning: The number of training batches (9) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
Epoch 400/400: 100%|██████████| 400/400 [00:44<00:00,  8.96it/s, loss=2.97e+03, v_num=1]
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=400` reached.
Epoch 400/400: 100%|██████████| 400/400 [00:44<00:00,  9.00it/s, loss=2.97e+03, v_num=1]
INFO     File ./ref_model/model.pt already downloaded                                                              
INFO     Found 100.0% reference vars in query data.                                                                
INFO     File ./ref_model/model.pt already downloaded                                                              
/home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_env_new/lib/python3.8/site-packages/lightning_fabric/plugins/environments/slurm.py:166: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_ ...
  rank_zero_warn(
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_env_new/lib/python3.8/site-packages/lightning_fabric/plugins/environments/slurm.py:166: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_ ...
  rank_zero_warn(
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/icb/lisa.sikkema/miniconda3/envs/HLCA_mapping_env_new/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1609: PossibleUserWarning: The number of training batches (9) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
Epoch 400/400: 100%|██████████| 400/400 [00:44<00:00,  9.05it/s, loss=2.97e+03, v_num=1]
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=400` reached.
Epoch 400/400: 100%|██████████| 400/400 [00:44<00:00,  9.03it/s, loss=2.97e+03, v_num=1]
True

Versions:

scArches: '0.5.8' scvi-tools: '0.20.1'

The reference model I used is public, I downloaded it with this link: https://zenodo.org/records/10139343/files/pb_tissue_normal_ref_scvi.zip?download=1

This is the query dataset I use in the example: PeerMassague2020_subset_for_testing.h5ad.zip

Any idea why max_kl_weight and the batch covariate have zero effect on the output? This should not be the case, as far as I understand.

martinkim0 commented 11 months ago

Hi, thank you for your question. This is because, with our current defaults, the scVI encoder does not receive batch assignments (encode_covariates=False), so changing the batch covariate will not affect the latent representation.

In addition, scArches by default freezes pretrained model parameters (i.e. parameters not related to covariates), which is why training the reference model on the query data with different max KL weights does not change the latent representation (see the last section in our user guide).

These should both, however, have an effect on the decoder output (normalized expression).
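A toy illustration of the first point, written in numpy with hypothetical linear weights (it is not the scvi-tools implementation, and all names are made up): when encode_covariates is False, the batch one-hot is fed only to the decoder, so relabelling a cell's batch leaves its latent vector untouched while the decoded, normalized expression still changes.

```python
import numpy as np

rng = np.random.default_rng(0)
n_genes, n_latent, n_batches = 5, 2, 3

# Hypothetical toy weights standing in for a trained scVI-style model.
W_enc_x = rng.normal(size=(n_latent, n_genes))    # encoder: expression -> latent
W_enc_b = rng.normal(size=(n_latent, n_batches))  # encoder: batch one-hot -> latent
W_dec_z = rng.normal(size=(n_genes, n_latent))    # decoder: latent -> expression
W_dec_b = rng.normal(size=(n_genes, n_batches))   # decoder: batch one-hot -> expression


def one_hot(batch):
    v = np.zeros(n_batches)
    v[batch] = 1.0
    return v


def encode(x, batch, encode_covariates):
    """Latent mean; the batch only enters if encode_covariates is True."""
    z = W_enc_x @ x
    if encode_covariates:
        z = z + W_enc_b @ one_hot(batch)
    return z


def decode(z, batch):
    """Softmax over genes, i.e. a toy 'normalized expression'; always batch-aware."""
    logits = W_dec_z @ z + W_dec_b @ one_hot(batch)
    e = np.exp(logits - logits.max())
    return e / e.sum()


x = rng.poisson(3.0, size=n_genes).astype(float)

# encode_covariates=False: same latent whatever batch label the cell gets ...
z_batch0 = encode(x, 0, encode_covariates=False)
z_batch1 = encode(x, 1, encode_covariates=False)
print(np.allclose(z_batch0, z_batch1))
# ... but the decoder output for that latent still depends on the batch.
print(np.allclose(decode(z_batch0, 0), decode(z_batch0, 1)))
# encode_covariates=True: the batch label now moves the latent as well.
print(np.allclose(encode(x, 0, True), encode(x, 1, True)))
```

In this toy the script prints True, then False twice: only the decoder "sees" the batch unless covariates are encoded.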

LisaSikkema commented 11 months ago

Thanks a lot for the quick response!

I'm not sure I understand, though: I get that the model parameters for calculating the reference embedding aren't changed, but this should still allow changing parameters specific to the query, right? That's also how I read the user guide you linked to:

the training of the model with the query data is performed with respect to only the new query-category specific parameters.

As for the batch covariate not affecting the latent representation of the query: how could we still expect the mapping to perform any batch correction in that case?

Is there any way in which I can still change these settings at the stage of mapping, or should it have been done already when the reference model was trained?

martinkim0 commented 11 months ago

I get that the model parameters for calculating the reference embedding aren't changed, but this should still allow for changing parameters specific for the query right?

Right: by default, the only query-specific parameters added during transfer learning are those accommodating the new batch covariate categories, and these will be updated. Admittedly, this seems a little limited, since we end up updating only the decoder, so I'm not exactly sure why we can expect the model to batch-correct query data well. But this is our default, and we are not changing it for now because of backwards compatibility.
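Why max_kl_weight had exactly zero effect in the original report can be reproduced with a toy model (numpy only; the squared-error loss, the 0.5*||z||^2 penalty standing in for the KL term, and all names are simplifications, not scVI's actual ELBO or code). The KL term depends only on the latent z, so its gradient flows only into the encoder; if the encoder is frozen and takes no batch input, the KL weight cannot change the latent. Letting all parameters train makes the KL weight matter again.

```python
import numpy as np


def train_query(kl_weight, freeze_pretrained=True, n_steps=200, lr=0.005):
    """Toy query training; returns the query latent after training."""
    rng = np.random.default_rng(0)  # identical data and initial weights in every run
    n_cells, n_genes, n_latent = 20, 4, 2
    X = rng.normal(size=(n_cells, n_genes))  # toy query data
    params = {
        "W_enc": rng.normal(size=(n_latent, n_genes)),  # pretrained encoder, no batch input
        "W_dec": rng.normal(size=(n_genes, n_latent)),  # pretrained decoder
        "b_query": np.zeros(n_genes),  # new parameters for the query batch
    }
    trainable = {
        "W_enc": not freeze_pretrained,
        "W_dec": not freeze_pretrained,
        "b_query": True,
    }
    for _ in range(n_steps):
        Z = X @ params["W_enc"].T
        R = X - (Z @ params["W_dec"].T + params["b_query"])
        # loss = mean_i ||x_i - x_hat_i||^2 + kl_weight * mean_i 0.5 * ||z_i||^2
        d_xhat = -2.0 * R / n_cells
        # The KL term only contributes through z, i.e. only to the encoder gradient.
        d_z = d_xhat @ params["W_dec"] + kl_weight * Z / n_cells
        grads = {
            "W_enc": d_z.T @ X,
            "W_dec": d_xhat.T @ Z,
            "b_query": d_xhat.sum(axis=0),
        }
        for name, grad in grads.items():
            if trainable[name]:  # frozen parameters are never updated
                params[name] -= lr * grad
    return X @ params["W_enc"].T


frozen = [train_query(w) for w in (0.1, 1.0, 2.0)]
print(all(np.array_equal(frozen[0], z) for z in frozen[1:]))  # identical latents
unfrozen = [train_query(w, freeze_pretrained=False) for w in (0.1, 2.0)]
print(np.allclose(unfrozen[0], unfrozen[1]))  # KL weight now changes the latent
```

This prints True, then False, mirroring the two "True" lines in the report for the frozen case.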

Before reference model training, you can change this behavior by specifying encode_covariates=True during model initialization, so that part of the encoder (and thus the latent representation) is updated during additional query training.

After reference model training, load_query_data has several options that let you specify additional parameters to freeze or unfreeze, with unfrozen=True letting you train all model parameters. For more custom parameter freezing, you would have to manually add zero-grad hooks in the PyTorch module, sorry. Is there a particular part of the model you are interested in updating during query training?
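A toy sketch of the scArches-style "surgery" discussed here, assuming encode_covariates=True so the batch one-hot is part of the encoder input (numpy only; names, sizes and the squared-error training signal are hypothetical, not scvi-tools internals). The first encoder layer gets extra zero-initialised columns for the new query batch categories, and a gradient mask, playing the role of the zero-grad hooks mentioned above, keeps every pretrained weight fixed.

```python
import numpy as np

rng = np.random.default_rng(1)
n_genes, n_hidden, n_ref_batches, n_query_batches = 6, 3, 2, 2
n_fixed = n_genes + n_ref_batches  # inputs seen during reference training

# Pretrained first encoder layer; inputs are [genes | reference batch one-hot].
W_ref = rng.normal(size=(n_hidden, n_fixed))
# Surgery: append zero-initialised columns for the new query batch categories.
W = np.hstack([W_ref, np.zeros((n_hidden, n_query_batches))])

# Gradient mask: 1 for the new columns only, 0 for every pretrained weight.
grad_mask = np.zeros_like(W)
grad_mask[:, n_fixed:] = 1.0


def layer_input(x, batch):
    """Concatenate expression with a one-hot over reference + query batches."""
    cov = np.zeros(n_ref_batches + n_query_batches)
    cov[batch] = 1.0
    return np.concatenate([x, cov])


x = rng.poisson(2.0, size=n_genes).astype(float)
query_batch = n_ref_batches  # index of the first new (query) category
target = rng.normal(size=n_hidden)  # stand-in training signal

for _ in range(50):
    u = layer_input(x, query_batch)
    h = W @ u
    grad = np.outer(h - target, u)  # gradient of 0.5 * ||h - target||^2 w.r.t. W
    W -= 0.01 * grad * grad_mask  # masked update: only the new columns move

print(np.array_equal(W[:, :n_fixed], W_ref))  # pretrained weights untouched
# The trained query column now shifts this cell's hidden units, so the encoder
# output depends on which query batch the cell is assigned to.
print(np.allclose(W @ layer_input(x, query_batch), W @ layer_input(x, query_batch + 1)))
```

This prints True, then False. In this toy picture, unfrozen=True corresponds roughly to a mask of all ones, while with encode_covariates=False there would be no batch columns in the encoder to extend at all.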

LisaSikkema commented 11 months ago

Okay, interesting! It actually sounds like the current default defeats the purpose of scArches as a whole, since no batch correction is performed during the mapping (at least at the level of the latent embedding). Probably also relevant for @M0hammadL to be aware of. Maybe it would be good to at least stress this very clearly in the scVI tutorials, explaining that query-to-reference mapping is only really possible later on if encode_covariates=True was set during reference model training.

I'm not so much interested in retraining the reference model, but rather in being able to set how the query model is trained (i.e. including batch in the encoder and changing the max_kl_weight setting). If I understand the documentation correctly, that is not really possible with the freeze/unfreeze parameters in load_query_data, is it?

martinkim0 commented 11 months ago

Good point - probably a good idea to point this out in the tutorials. I wasn't involved in writing this part of scvi-tools, so I'll also double check that this is actually what is happening, but from what I recall, this is the case with our defaults.

Yeah, encoding batch information when the reference model has already been trained without it is not possible yet. For max KL weight, I do believe the changes should be reflected during query training, but let me double check.

canergen commented 11 months ago

Getting into the thread here: all scArches tutorials highlight that covariates are encoded. In our hands, encoding covariates or not doesn't make a difference for most datasets. Out-of-distribution (OOD) mapping shouldn't be done (e.g. taking a single-cell (sc) model and mapping a single-nucleus (sn) query onto it). However, when reusing a model it generally makes sense to check that integration worked and that posterior predictive checks look good (in short, that the reconstruction loss is similar in query and reference). This will be stressed in our publication encompassing scVI-criticism.
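The sanity check suggested here (reconstruction loss should be similar for query and reference) could be scripted roughly as follows. This is a toy: a plain Poisson negative log-likelihood with known rates stands in for scVI's actual reconstruction loss (with a trained model one would use the model's own reconstruction error instead, e.g. get_reconstruction_error in scvi-tools), and the 10% tolerance is a hypothetical choice, not a documented threshold.

```python
import math

import numpy as np


def poisson_nll(counts, rates):
    """Mean per-cell Poisson negative log-likelihood (toy reconstruction loss)."""
    log_factorial = np.vectorize(math.lgamma)(counts + 1.0)
    per_cell = (rates - counts * np.log(rates) + log_factorial).sum(axis=1)
    return per_cell.mean()


def relative_gap(ref_counts, ref_rates, query_counts, query_rates):
    """(query loss - reference loss) / reference loss."""
    ref_loss = poisson_nll(ref_counts, ref_rates)
    query_loss = poisson_nll(query_counts, query_rates)
    return (query_loss - ref_loss) / ref_loss


rng = np.random.default_rng(0)
rates = rng.gamma(2.0, 2.0, size=(200, 50))  # "model-predicted" rates, cells x genes
ref_counts = rng.poisson(rates)
query_in = rng.poisson(rates)  # query drawn from the same distribution
query_ood = rng.poisson(rates[:, ::-1])  # query whose gene profiles the model gets wrong

TOLERANCE = 0.1  # hypothetical threshold, not a documented scvi-tools value
for name, query in [("in-distribution", query_in), ("shifted", query_ood)]:
    gap = relative_gap(ref_counts, rates, query, rates)
    status = "ok" if abs(gap) < TOLERANCE else "check mapping"
    print(f"{name}: relative gap {gap:+.2f} -> {status}")
```

The in-distribution query passes, while the shifted query shows a clearly higher loss and gets flagged.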

LisaSikkema commented 11 months ago

Thanks for your responses!

Right, yes: in the scArches tutorials the covariates are set that way (although, as far as I remember, this is not explained, so users may not realize that it is important not to change these settings), but I think they are neither set that way nor discussed in the scVI/scANVI tutorials. Most people won't go through an scArches tutorial while they are still integrating the data, and will only realize that their model isn't really scArches-friendly once they have already done a lot of quality checks and downstream analysis on their integration and no longer want to redo it. I have already run into this with two large atlas integrations I worked with in the past weeks.

Also, for me it makes quite a large difference to set parameters differently in the models where this is possible, especially the batch covariate but also, e.g., the KL divergence weight, so I wouldn't say that whether covariates are encoded makes little difference.

canergen commented 1 month ago

We added a comment to the tutorial to highlight this.