BayraktarLab / cell2fate

Inference of RNA velocity modules for prediction of cell fates and integration with spatial and regulatory models.
Apache License 2.0

Batch export for posteriors #2

Open sezginerr opened 1 year ago

sezginerr commented 1 year ago

Hello,

I have been trying to export posteriors batch by batch with a defined batch size, rather than exporting them in a single batch. In the scvi library this is achieved by concatenating the local variables exported from each batch along the cell dimension. I believe the scvi library assumes that the variables defined within 'obs_plate' are the local variables, while the global variables (any other variable) are sampled from the last batch. However, for the cell2fate model, local and global variables seem to be the same, as the guide in the cell2fate model is an instance of 'poutine.messenger.Messenger'. In this case, posterior sampling does not use the 'return_sites' information and seems to return everything. Here is the function that returns one posterior sample in the scvi library:

def _get_one_posterior_sample(
    self,
    args,
    kwargs,
    return_sites: Optional[list] = None,
    return_observed: bool = False,
):
    if isinstance(self.module.guide, poutine.messenger.Messenger):
        # This already includes trace-replay behavior.
        sample = self.module.guide(*args, **kwargs)
    else:
        guide_trace = poutine.trace(self.module.guide).get_trace(*args, **kwargs)
        model_trace = poutine.trace(
            poutine.replay(self.module.model, guide_trace)
        ).get_trace(*args, **kwargs)
        sample = {
            name: site["value"]
            for name, site in model_trace.nodes.items()
            if (
                (site["type"] == "sample")  # sample statement
                and (
                    (return_sites is None) or (name in return_sites)
                )  # selected in return_sites list
                and (
                    (
                        (not site.get("is_observed", True)) or return_observed
                    )  # don't save observed unless requested
                    or (site.get("infer", False).get("_deterministic", False))
                )  # unless it is deterministic
                and not isinstance(
                    site.get("fn", None), poutine.subsample_messenger._Subsample
                )  # don't save plates
            )
        } 

To adapt batch sampling in the cell2fate model, I made a few changes to the function. Firstly, I removed the global sampling, which would sample everything in the last batch again. I'm not sure if it's correct to do so, as the RF/GO analysis slightly changes when I use single batch sampling with or without global variable sampling. Secondly, I defined the variables that have cell number dimensions and concatenated them in the cell number dimension after sampling. For the other variables, I simply applied element-wise averaging between batches. Here is the adapted function that I use:

```python
def _posterior_samples_minibatch(
    self, use_gpu: bool = None, batch_size: Optional[int] = None, **sample_kwargs
):
    samples = dict()

    _, device = parse_use_gpu_arg(use_gpu)

    batch_size = batch_size if batch_size is not None else settings.batch_size

    train_dl = AnnDataLoader(
        self.adata_manager, shuffle=False, batch_size=batch_size
    )
    # sample local parameters
    i = 0
    cell_specific = ['t_c', 'T_c', 'mu_expression', 'detection_y_c', 'mu', 'data_target']
    for tensor_dict in track(
        train_dl,
        style="tqdm",
        description="Sampling local variables, batch: ",
    ):
        args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
        args = [a.to(device) for a in args]
        kwargs = {k: v.to(device) for k, v in kwargs.items()}
        self.to_device(device)

        if i == 0:
            return_observed = getattr(sample_kwargs, "return_observed", False)
            obs_plate_sites = self._get_obs_plate_sites(
                args, kwargs, return_observed=return_observed
            )
            if len(obs_plate_sites) == 0:
                # if no local variables - don't sample
                break
            obs_plate_dim = list(obs_plate_sites.values())[0]

            sample_kwargs_obs_plate = sample_kwargs.copy()
            sample_kwargs_obs_plate[
                "return_sites"
            ] = self._get_obs_plate_return_sites(
                sample_kwargs["return_sites"], list(obs_plate_sites.keys())
            )
            sample_kwargs_obs_plate["show_progress"] = False
            samples = self._get_posterior_samples(
                args, kwargs, **sample_kwargs_obs_plate
            )
        else:
            samples_ = self._get_posterior_samples(
                args, kwargs, **sample_kwargs_obs_plate
            )
            num_cells_in_batch = samples_['t_c'].shape[1]
            for k in samples.keys():
                if samples_[k].ndim > 1:
                    if k in cell_specific:
                        samples[k] = np.concatenate([samples[k], samples_[k]], axis=1)
                    else:
                        ratio_cells = num_cells_in_batch / batch_size
                        samples[k] = (samples[k] * i + samples_[k] * ratio_cells) / (i + ratio_cells)
            i += 1
        i += 1
    self.module.to(device)
    return samples
```

This approach seems to work well, as the results with and without batch export are quite similar. However, there is again a slight difference between the RF/GO analyses. I believe the difference arises because, even though the seed is set to 0, a change in the sampling strategy alters the output of the random number generator. I also tried to sample the variables that are not cell-specific from only the last batch, but it does not seem to work well, even though it should work in theory, as the variables that are not cell-specific are sampled from already learned distributions. I wanted to ask whether my approach is in line with the cell2fate assumptions. Thank you very much for your help.

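
A minimal NumPy sketch of the seed effect described above: with the same seed, drawing everything in one call and drawing batch by batch consume the random stream in a different order, so the values differ. The array sizes and the extra "global" draw per batch are illustrative assumptions, not cell2fate or scvi-tools code.

```python
import numpy as np


def draw_single_pass(seed, n_cells):
    """Draw all cell-specific values in one call."""
    rng = np.random.default_rng(seed)
    return rng.normal(size=n_cells)


def draw_in_batches(seed, n_cells, batch_size, n_global=2):
    """Draw cell-specific values batch by batch.

    Each batch also consumes random numbers for (hypothetical) global
    variables, which shifts the stream seen by the next batch.
    """
    rng = np.random.default_rng(seed)
    chunks = []
    for start in range(0, n_cells, batch_size):
        n = min(batch_size, n_cells - start)
        chunks.append(rng.normal(size=n))
        rng.normal(size=n_global)  # global draws advance the stream
    return np.concatenate(chunks)


single = draw_single_pass(seed=0, n_cells=6)
batched = draw_in_batches(seed=0, n_cells=6, batch_size=3)
print(np.allclose(single, batched))
```

Both calls use seed 0 and return arrays of the same shape, yet the values do not match, so small differences in downstream analyses are expected even when both exports are statistically valid.
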
vitkl commented 1 year ago

Would be great to add this to scvi-tools directly. Sounds like this is solving an important issue.

AlexanderAivazidis commented 1 year ago

Great! Let's add this particular solution to cell2fate only, because it relies on specifying "cell-specific variables", which may be different for each model. You can make a branch that changes the export_posterior method of the cell2fate_DynamicalModel with your modifications and then make a pull request. After that we can indeed think about how to add a more general solution to scvi-tools directly!

vitkl commented 1 year ago

Why do you need this bit of code?

cell_specific=['t_c', 'T_c', 'mu_expression', 'detection_y_c', 'mu', 'data_target']

The _get_obs_plate_sites code below should already allow detecting which variables are cell-specific:

obs_plate_sites = self._get_obs_plate_sites(
    args, kwargs, return_observed=return_observed
)

element-wise averaging between batches

This is not correct to do for global variables. This breaks the independence of samples and probably yields very narrow distributions (distribution of posterior samples -> distributions of means). Global variables are sampled using a separate step exactly because we need to get them once using one data batch (data batch is irrelevant for these variables, their distribution is the same for all batches).
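
A small numerical illustration of this point, assuming a hypothetical global variable whose posterior is a standard normal and 16 data batches (both numbers are made up for the sketch): averaging independent posterior samples element-wise across batches shrinks the spread by roughly the square root of the number of batches.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_batches = 1000, 16

# One independent set of posterior samples of the same global variable
# is drawn for every data batch (hypothetical N(0, 1) posterior).
per_batch = rng.normal(loc=0.0, scale=1.0, size=(n_batches, n_samples))

single_batch_std = per_batch[-1].std()       # samples taken from one batch only
averaged_std = per_batch.mean(axis=0).std()  # element-wise average over batches

print(f"std from one batch:     {single_batch_std:.3f}")
print(f"std after averaging:    {averaged_std:.3f}")
print(f"expected shrink factor: {1 / np.sqrt(n_batches):.3f}")
```

The averaged array is a set of samples of the mean, not of the variable itself, which is why the global variables should be taken from a single batch.
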

In this case, posterior sampling does not use the 'return_sites' information and seems to return everything.

This is a known issue that needs to be solved. I initially thought that this is what your change addresses, @sezginerr.

However, for the cell2fate model, local and global variables seem to the same

Is this really the case? Why is this the case? Statements with obs_plate make variables local variables https://github.com/BayraktarLab/cell2fate/blob/22917f7502fe5fb57e87a72cff704f12eee893fa/cell2fate/_cell2fate_DynamicalModel_module.py#L427 and list_obs_plate_vars method https://github.com/BayraktarLab/cell2fate/blob/22917f7502fe5fb57e87a72cff704f12eee893fa/cell2fate/_cell2fate_DynamicalModel_module.py#L337 tells scvi-tools which plate is the minibatch plate.

Does the default posterior sampling code fail completely or simply "does not use the 'return_sites' information and seems to return everything"?

sezginerr commented 1 year ago

The _get_obs_plate_sites method currently returns only the variables defined within obs_plate and excludes deterministic ones, so only t_c and detection_y_c are returned. However, the remaining deterministic variables (T_c, mu_expression, mu, data_target) also have a cell dimension that should be concatenated across batches. Instead of listing them manually, I could check each variable in the first batch and collect the keys whose shape along axis=1 equals the batch size, which gives the set of cell-specific variables.
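
A minimal sketch of that shape check, assuming samples are NumPy arrays laid out as (n_samples, n_cells, ...) for cell-specific variables. The function name `find_cell_specific_keys` and the toy `global_rate` variable are hypothetical, not part of cell2fate or scvi-tools.

```python
import numpy as np


def find_cell_specific_keys(samples, n_cells_in_batch, cell_axis=1):
    """Return keys whose arrays have `n_cells_in_batch` entries along `cell_axis`."""
    return [
        name
        for name, value in samples.items()
        if value.ndim > cell_axis and value.shape[cell_axis] == n_cells_in_batch
    ]


# Toy samples from a first batch of 8 cells and 5 posterior samples.
first_batch = {
    "t_c": np.zeros((5, 8, 1)),
    "mu": np.zeros((5, 8, 3, 2)),
    "global_rate": np.zeros((5, 1, 3)),  # hypothetical global variable
}
cell_specific = find_cell_specific_keys(first_batch, n_cells_in_batch=8)
print(cell_specific)
```

One caveat of this heuristic: a global variable whose size along axis 1 happens to equal the batch size would be misclassified, so choosing a batch size that differs from the number of genes or modules is safer.
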

Yes, calculating the mean probably causes a narrow distribution. I checked the code again, and I believe there was a bug in my first version. Taking the globals from the last batch now seems to work. However, there are still slight differences between batch and non-batch exports.

While both local and global variables are defined in cell2fate, the problem lies in the posterior sampling process, which disregards the return_sites information coming from _get_obs_plate_sites and therefore returns all variables every time, as you mentioned. So it is not that the variables are the same, but that the code does not use that information.
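
The merge strategy discussed here (concatenate local variables across batches, take global variables from a single batch) can be sketched as follows. This is an illustration on toy NumPy arrays under the assumption that the cell dimension is axis 1; `merge_batch_samples` and `global_rate` are hypothetical names.

```python
import numpy as np


def merge_batch_samples(batch_samples, local_keys, cell_axis=1):
    """Concatenate local variables over batches; take globals from the last batch."""
    merged = {}
    for name in batch_samples[-1]:
        if name in local_keys:
            merged[name] = np.concatenate(
                [batch[name] for batch in batch_samples], axis=cell_axis
            )
        else:
            # Global variables do not depend on the data batch,
            # so one batch is enough and no averaging is done.
            merged[name] = batch_samples[-1][name]
    return merged


rng = np.random.default_rng(0)
batches = [
    {"t_c": rng.normal(size=(5, n, 1)), "global_rate": rng.normal(size=(5, 1, 3))}
    for n in (4, 4, 2)  # the last batch is smaller
]
merged = merge_batch_samples(batches, local_keys={"t_c"})
print(merged["t_c"].shape, merged["global_rate"].shape)
```

Because nothing is averaged, the global samples keep their original spread, and a smaller final batch needs no special weighting.
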

vitkl commented 1 year ago

Are the deterministic variables wrapped into with obs_plate in the model? They should be, otherwise, they will be ignored by plate detection.

vitkl commented 1 year ago

I have some code to solve these issues once and for all. It also handles multiple plates and arbitrary orders of items within plates. My code works for direct posterior median computation - but not for sampling. It needs some minor modifications and testing. Would you be able to help?

sezginerr commented 1 year ago

No, deterministic variables are not wrapped into with obs_plate in cell2fate. See this for example:

https://github.com/BayraktarLab/cell2fate/blob/22917f7502fe5fb57e87a72cff704f12eee893fa/cell2fate/_cell2fate_DynamicalModel_module.py#L427-L429

Yes, I would like to help with the code. I believe it would be great if there is a general solution for this.

AlexanderAivazidis commented 1 year ago

Thanks for your insights on this! So for a start I will wrap all cell specific deterministic variables in a plate, so we can detect them as local variables automatically.

Next we should definitely implement Vitalii's solution to do direct posterior median computation.

Finally, I think Sezgin's solution is probably still OK for the moment, because the only time we use something other than the posterior mean is for the time variable "T_c", where we are also interested in the variance, and that should be sampled accurately given that it is a local variable.

AlexanderAivazidis commented 1 year ago

I wrapped cell specific deterministic variables in plates now, but this has not solved it I think, because for some reason "data_target" is not detected as a local variable. So I think it would be great if @sezginerr adds his approximate solution for now and then we come up with something more general next.

vitkl commented 1 year ago

data_target is an observed variable, right? Not much reason to track it.

vitkl commented 1 year ago

Here is my code for computing the quantiles and medians. It should work for this project with no modifications, except for dataloader options.

This code needs to be modified with sampling statements instead of median/quantiles, including correctly handling the sample dimension. I would really appreciate having the posterior sampling solution for the cell2state model project in addition to this problem.

```python
import numpy as np
import torch
from pyro import poutine
from scvi.dataloaders import AnnDataLoader
from scvi.model._utils import parse_use_gpu_arg
from scvi.utils import track

from cell2state.dataloaders._per_gene_per_cell_chr_dataloader import (
    PerGeneChromatinDataSplitter,
)


def expand_zeros_along_dim(tensor, size, dim):
    shape = np.array(tensor.shape)
    shape[dim] = size
    return np.zeros(shape)


def complete_tensor_along_dim(tensor, indices, dim, value, mode="put"):
    shape = value.shape
    shape = np.ones(len(shape))
    shape[dim] = len(indices)
    shape = shape.astype(int)
    indices = indices.reshape(shape)
    if mode == "take":
        return np.take_along_axis(arr=tensor, indices=indices, axis=dim)
    np.put_along_axis(arr=tensor, indices=indices, values=value, axis=dim)
    return tensor


def _complete_full_tensors_using_plates(
    means_global, means, plate_dict, obs_plate_sites, plate_indices, plate_dim
):
    # complete full sized tensors with minibatch values given minibatch indices
    for k in means_global.keys():
        # find which and how many plates contain this tensor
        plates = [
            plate for plate in plate_dict.keys() if k in obs_plate_sites[plate].keys()
        ]
        if len(plates) == 1:
            # if only one plate contains this tensor, complete it using the plate indices
            means_global[k] = complete_tensor_along_dim(
                means_global[k],
                plate_indices[plates[0]],
                plate_dim[plates[0]],
                means[k],
            )
        elif len(plates) == 2:
            # subset data to index for plate 0 and fill index for plate 1
            means_global_k = complete_tensor_along_dim(
                means_global[k],
                plate_indices[plates[0]],
                plate_dim[plates[0]],
                means[k],
                mode="take",
            )
            means_global_k = complete_tensor_along_dim(
                means_global_k,
                plate_indices[plates[1]],
                plate_dim[plates[1]],
                means[k],
            )
            # fill index for plate 0 in the full data
            means_global[k] = complete_tensor_along_dim(
                means_global[k],
                plate_indices[plates[0]],
                plate_dim[plates[0]],
                means_global_k,
            )
            # TODO add a test - observed variables should be identical if this code works correctly
            # This code works correctly but the test needs to be added eventually
            # np.allclose(
            #     samples['data_chromatin'].squeeze(-1).T,
            #     mod_reg.adata_manager.get_from_registry('X')[
            #         :, ~mod_reg.adata_manager.get_from_registry('gene_bool').ravel()
            #     ].toarray()
            # )
        else:
            NotImplementedError(
                f"Posterior sampling/mean/median/quantile not supported for variables with > 2 plates: {k} has {len(plates)}"
            )
    return means_global


class QuantileMixin:
    """
    This mixin class provides methods for:

    - computing median and quantiles of the posterior distribution using both direct and amortised inference
    """

    def _get_obs_plate_sites_v2(
        self,
        args: list,
        kwargs: dict,
        plate_name: str = None,
        return_observed: bool = False,
        return_deterministic: bool = True,
    ):
        """
        Automatically guess which model sites belong to observation/minibatch plate.
        This function requires minibatch plate name specified in `self.module.list_obs_plate_vars["name"]`.

        Parameters
        ----------
        args
            Arguments to the model.
        kwargs
            Keyword arguments to the model.
        return_observed
            Record samples of observed variables.

        Returns
        -------
        Dictionary with keys corresponding to site names and values to plate dimension.
        """
        if plate_name is None:
            plate_name = self.module.list_obs_plate_vars["name"]

        def try_trace(args, kwargs):
            try:
                trace_ = poutine.trace(self.module.guide).get_trace(*args, **kwargs)
                trace_ = poutine.trace(
                    poutine.replay(self.module.model, trace_)
                ).get_trace(*args, **kwargs)
            except ValueError:
                # if sample is unsuccessful try again
                trace_ = try_trace(args, kwargs)
            return trace_

        trace = try_trace(args, kwargs)

        # find plate dimension
        obs_plate = {
            name: {
                fun.name: fun
                for fun in site["cond_indep_stack"]
                if (fun.name in plate_name) or (fun.name == plate_name)
            }
            for name, site in trace.nodes.items()
            if (
                (site["type"] == "sample")  # sample statement
                and (
                    (
                        (not site.get("is_observed", True)) or return_observed
                    )  # don't save observed unless requested
                    or (
                        site.get("infer", False).get("_deterministic", False)
                        and return_deterministic
                    )
                )  # unless it is deterministic
                and not isinstance(
                    site.get("fn", None), poutine.subsample_messenger._Subsample
                )  # don't save plates
            )
            if any(f.name == plate_name for f in site["cond_indep_stack"])
        }

        return obs_plate

    def _get_dataloader(
        self,
        batch_size,
        gene_batch_size,
        data_loader_indices,
        dna_region_batch_mode="windows",
        shuffle_training=False,
        dl_kwargs={},
    ):
        if self.minibatch_genes_:
            dl_kwargs_ = {
                "shuffle_training": shuffle_training,
                "train_size": 1.0,
                "batch_size": batch_size,
                "gene_batch_size": gene_batch_size,
                "cell_plate_inputs": self.module.model.cell_plate_inputs,
                "var_plate_inputs": self.module.model.var_plate_inputs,
                "square_var_plate_inputs": self.module.model.square_var_plate_inputs,
                "per_site_plate_inputs": self.module.model.per_site_plate_inputs,
                "filter_tensors": self.module.model.filter_tensors,
                "gene_bool": self.adata_manager.get_from_registry(
                    registry_key="gene_bool"
                ).flatten(),
                "drop_last": False,
                **dl_kwargs,
            }
            if dna_region_batch_mode == "gene_batch":
                dl_kwargs_["gene_region_coo"] = self._get_gene_region_coo(
                    use_all_genes=False
                )
            elif dna_region_batch_mode == "windows":
                dl_kwargs_["positions"] = self.adata_manager.get_from_registry(
                    registry_key="positions"
                ).flatten()
                dl_kwargs_["chromosomes"] = self.adata_manager.get_from_registry(
                    registry_key="chromosome"
                ).flatten()
            train_dl = PerGeneChromatinDataSplitter(
                self.adata_manager,
                **dl_kwargs_,
            )
            train_dl.setup()
            train_dl = train_dl.train_dataloader()
        else:
            train_dl = AnnDataLoader(
                self.adata_manager,
                shuffle=False,
                batch_size=batch_size,
                indices=data_loader_indices,
                **dl_kwargs,
            )
        return train_dl

    @torch.no_grad()
    def _posterior_quantile_minibatch(
        self,
        q: float = 0.5,
        batch_size: int = 128,
        gene_batch_size: int = 50,
        use_gpu: bool = None,
        use_median: bool = False,
        return_observed: bool = True,
        exclude_vars: list = None,
        data_loader_indices=None,
        show_progress: bool = True,
        dna_region_batch_mode: str = "gene_batch",  # ['gene_batch', 'windows'],
        n_vars_per_dna_region_batch: int = 420,
    ):
        """
        Compute median of the posterior distribution of each parameter, separating local (minibatch) variable
        and global variables, which is necessary when performing amortised inference.

        Note for developers: requires model class method which lists observation/minibatch plate
        variables (self.module.model.list_obs_plate_vars()).

        Parameters
        ----------
        q
            quantile to compute
        batch_size
            number of observations per batch
        use_gpu
            Bool, use gpu?
        use_median
            Bool, when q=0.5 use median rather than quantile method of the guide

        Returns
        -------
        dictionary {variable_name: posterior quantile}
        """
        _, _, device = parse_use_gpu_arg(use_gpu)

        self.module.eval()

        # Warmup run for one step - to eg cache TF effects
        dl = self._get_dataloader(
            batch_size=batch_size,
            gene_batch_size=gene_batch_size,
            data_loader_indices=data_loader_indices,
            dna_region_batch_mode=dna_region_batch_mode,
            dl_kwargs={
                "n_vars_per_dna_region_batch": n_vars_per_dna_region_batch,
            },
        )
        for tensor_dict in dl:
            args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
            args = [a.to(device) for a in args]
            kwargs = {k: v.to(device) for k, v in kwargs.items()}
            self.to_device(device)
            self.module.train()  # to cache TF effects TODO: remove this hack as soon as possible
            self.module.guide(*args, **kwargs)
            self.module.eval()
            break

        train_dl = self._get_dataloader(
            batch_size=batch_size,
            gene_batch_size=gene_batch_size,
            data_loader_indices=data_loader_indices,
            dna_region_batch_mode=dna_region_batch_mode,
            dl_kwargs={
                "n_vars_per_dna_region_batch": n_vars_per_dna_region_batch,
            },
        )

        i = 0
        for tensor_dict in track(
            train_dl,
            style="tqdm",
            description=f"Computing posterior quantile {q}, data batch: ",
            disable=not show_progress,
        ):
            args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
            args = [a.to(device) for a in args]
            kwargs = {k: v.to(device) for k, v in kwargs.items()}
            self.to_device(device)

            if i == 0:
                minibatch_plate_names = self.module.list_obs_plate_vars["name"]
                plates = self.module.model.create_plates(*args, **kwargs)
                if not isinstance(plates, list):
                    plates = [plates]
                # find plate indices & dim
                plate_dict = {
                    plate.name: plate
                    for plate in plates
                    if (
                        (plate.name in minibatch_plate_names)
                        or (plate.name == minibatch_plate_names)
                    )
                }
                plate_size = {name: plate.size for name, plate in plate_dict.items()}
                if data_loader_indices is not None:
                    # set total plate size to the number of indices in DL not total number of observations
                    # this option is not really used
                    plate_size = {
                        name: len(train_dl.indices)
                        for name, plate in plate_dict.items()
                        if plate.name == minibatch_plate_names
                    }
                plate_dim = {name: plate.dim for name, plate in plate_dict.items()}
                plate_indices = {
                    name: plate.indices.detach().cpu().numpy()
                    for name, plate in plate_dict.items()
                }
                # find plate sites
                obs_plate_sites = {
                    plate: self._get_obs_plate_sites_v2(
                        args, kwargs, plate_name=plate, return_observed=return_observed
                    )
                    for plate in plate_dict.keys()
                }
                if use_median and q == 0.5:
                    # use median rather than quantile method
                    def try_median(args, kwargs):
                        try:
                            means_ = self.module.guide.median(*args, **kwargs)
                        except ValueError:
                            # if sample is unsuccessful try again
                            means_ = try_median(args, kwargs)
                        return means_

                    means = try_median(args, kwargs)
                else:

                    def try_quantiles(args, kwargs):
                        try:
                            means_ = self.module.guide.quantiles([q], *args, **kwargs)
                        except ValueError:
                            # if sample is unsuccessful try again
                            means_ = try_quantiles(args, kwargs)
                        return means_

                    means = try_quantiles(args, kwargs)
                means = {
                    k: means[k].detach().cpu().numpy()
                    for k in means.keys()
                    if k not in exclude_vars
                }
                means_global = means.copy()
                for plate in plate_dict.keys():
                    # create full sized tensors according to plate size
                    means_global = {
                        k: (
                            expand_zeros_along_dim(
                                means_global[k], plate_size[plate], plate_dim[plate]
                            )
                            if k in obs_plate_sites[plate].keys()
                            else means_global[k]
                        )
                        for k in means_global.keys()
                    }
                # complete full sized tensors with minibatch values given minibatch indices
                means_global = _complete_full_tensors_using_plates(
                    means_global=means_global,
                    means=means,
                    plate_dict=plate_dict,
                    obs_plate_sites=obs_plate_sites,
                    plate_indices=plate_indices,
                    plate_dim=plate_dim,
                )
                if np.all([len(v) == 0 for v in obs_plate_sites.values()]):
                    # if no local variables - don't sample further - return results now
                    break
            else:
                if use_median and q == 0.5:

                    def try_median(args, kwargs):
                        try:
                            means_ = self.module.guide.median(*args, **kwargs)
                        except ValueError:
                            # if sample is unsuccessful try again
                            means_ = try_median(args, kwargs)
                        return means_

                    means = try_median(args, kwargs)
                else:

                    def try_quantiles(args, kwargs):
                        try:
                            means_ = self.module.guide.quantiles([q], *args, **kwargs)
                        except ValueError:
                            # if sample is unsuccessful try again
                            means_ = try_quantiles(args, kwargs)
                        return means_

                    means = try_quantiles(args, kwargs)
                means = {
                    k: means[k].detach().cpu().numpy()
                    for k in means.keys()
                    if k not in exclude_vars
                }
                # find plate indices & dim
                plate_dict = {
                    plate.name: plate
                    for plate in self.module.model.create_plates(*args, **kwargs)
                    if (
                        (plate.name in minibatch_plate_names)
                        or (plate.name == minibatch_plate_names)
                    )
                }
                plate_indices = {
                    name: plate.indices.detach().cpu().numpy()
                    for name, plate in plate_dict.items()
                }
                # TODO - is this correct to call this function again? find plate sites
                obs_plate_sites = {
                    plate: self._get_obs_plate_sites_v2(
                        args, kwargs, plate_name=plate, return_observed=return_observed
                    )
                    for plate in plate_dict.keys()
                }
                # complete full sized tensors with minibatch values given minibatch indices
                means_global = _complete_full_tensors_using_plates(
                    means_global=means_global,
                    means=means,
                    plate_dict=plate_dict,
                    obs_plate_sites=obs_plate_sites,
                    plate_indices=plate_indices,
                    plate_dim=plate_dim,
                )
            i += 1

        self.module.to(device)

        return means_global
```
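
For the sampling variant requested above, one possible way to handle the extra sample dimension is to fill a preallocated array at the minibatch indices, in the same spirit as the index-based filling in the quantile code. The sketch below is an assumption-laden illustration on toy NumPy data: it uses plain integer-array indexing rather than `np.put_along_axis`, assumes a layout of (n_samples, n_cells, n_genes), and `fill_samples_along_dim` is a hypothetical helper, not part of scvi-tools, cell2fate, or cell2state.

```python
import numpy as np


def fill_samples_along_dim(full, batch_values, indices, dim):
    """Write `batch_values` into `full` at positions `indices` along axis `dim`."""
    dim = dim % full.ndim  # allow negative, plate-style dimensions
    selector = (slice(None),) * dim + (indices,)
    full[selector] = batch_values
    return full


n_samples, n_cells, n_genes = 4, 10, 3
full = np.zeros((n_samples, n_cells, n_genes))

rng = np.random.default_rng(0)
order = rng.permutation(n_cells)  # minibatches arrive in arbitrary cell order

for start in range(0, n_cells, 4):
    idx = order[start:start + 4]
    # Toy "samples": every entry stores the index of the cell it belongs to.
    batch_values = np.broadcast_to(
        idx[None, :, None], (n_samples, len(idx), n_genes)
    ).astype(float)
    fill_samples_along_dim(full, batch_values, idx, dim=1)

print(full[0, :, 0])
```

After the loop, each cell's slice holds its own index regardless of the order in which batches arrived, which is the property that lets local variables be collected without relying on unshuffled data loaders.
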