Open sezginerr opened 1 year ago
Would be great to add this to scvi-tools directly. Sounds like this is solving an important issue.
Great! Let's add this particular solution to cell2fate only, because it relies on specifying "cell-specific variables", which may be different for each model. You can make a branch that changes the export_posterior method of the cell2fate_DynamicalModel with your modifications and then make a pull request. After that we can indeed think about how to add a more general solution to scvi-tools directly!
Why do you need this bit of code?
cell_specific=['t_c', 'T_c', 'mu_expression', 'detection_y_c', 'mu', 'data_target']
The _get_obs_plate_sites code below should already allow detecting which variables are cell-specific:
obs_plate_sites = self._get_obs_plate_sites(
args, kwargs, return_observed=return_observed
)
element-wise averaging between batches
This is not correct to do for global variables. This breaks the independence of samples and probably yields very narrow distributions (distribution of posterior samples -> distributions of means). Global variables are sampled using a separate step exactly because we need to get them once using one data batch (data batch is irrelevant for these variables, their distribution is the same for all batches).
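The narrowing effect can be illustrated with a small NumPy sketch. It assumes a hypothetical global variable whose posterior is a standard normal and that each data batch yields an independent set of posterior samples from that same distribution; neither the variable nor the numbers come from cell2fate.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_batches = 1000, 10

# Hypothetical global variable with posterior N(0, 1): every data batch
# gives an independent set of samples from the same distribution.
per_batch = rng.normal(loc=0.0, scale=1.0, size=(n_batches, n_samples))

# Element-wise averaging across batches turns posterior samples
# into a distribution of means.
averaged = per_batch.mean(axis=0)

# Taking the samples from a single batch keeps the posterior spread.
single = per_batch[-1]

print(f"std of single-batch samples:   {single.std():.3f}")
print(f"std of batch-averaged samples: {averaged.std():.3f}")
```

Under these assumptions the averaged samples have a standard deviation of roughly 1/sqrt(n_batches) of the true posterior spread, which is why global variables should be taken from one sampling pass rather than averaged.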
In this case, posterior sampling does not use the 'return_sites' information and seems to return everything.
This is a known issue that needs to be solved. I initially thought that this is what your change addresses, @sezginerr.
However, for the cell2fate model, local and global variables seem to be the same
Is this really the case? Why is this the case? Statements with obs_plate make variables local variables (https://github.com/BayraktarLab/cell2fate/blob/22917f7502fe5fb57e87a72cff704f12eee893fa/cell2fate/_cell2fate_DynamicalModel_module.py#L427), and the list_obs_plate_vars method (https://github.com/BayraktarLab/cell2fate/blob/22917f7502fe5fb57e87a72cff704f12eee893fa/cell2fate/_cell2fate_DynamicalModel_module.py#L337) tells scvi-tools which plate is the minibatch plate.
Does the default posterior sampling code fail completely or simply "does not use the 'return_sites' information and seems to return everything"?
The _get_obs_plate_sites method currently returns only the variables defined with obs_plate and excludes deterministic ones, so only t_c and detection_y_c are returned. However, the remaining deterministic variables (T_c, mu_expression, mu, data_target) also have a cell number dimension that should be concatenated across batches. Instead of listing them manually, I can also check each variable in the first batch and collect the keys whose shape along axis=1 equals the batch size, which gives a dictionary of cell-specific variables.
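The shape-based detection described above could look roughly like the sketch below. The helper name find_cell_specific_keys, the global_rate variable and the array shapes (samples, cells, genes) are hypothetical illustrations, not taken from cell2fate or scvi-tools.

```python
import numpy as np

def find_cell_specific_keys(samples, batch_size, cell_axis=1):
    """Return names of variables whose size along `cell_axis` equals the batch size."""
    return [
        name
        for name, value in samples.items()
        if np.ndim(value) > cell_axis and np.shape(value)[cell_axis] == batch_size
    ]

# Hypothetical output of the first batch: 5 posterior samples, 32 cells, 7 genes.
first_batch = {
    "t_c": np.zeros((5, 32, 1)),
    "T_c": np.zeros((5, 32, 1)),
    "mu": np.zeros((5, 32, 7)),
    "global_rate": np.zeros((5, 1, 7)),  # hypothetical global variable
}

cell_specific = find_cell_specific_keys(first_batch, batch_size=32)
print(cell_specific)
```

One caveat with this heuristic: a global variable whose size along that axis happens to equal the batch size would be misclassified as cell-specific, so the batch size should not coincide with another dimension such as the number of genes.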
Yes, calculating the mean probably causes a narrow distribution. I checked the code again, and I believe there was a bug in my first version. Taking the global variables from the last batch now seems to work. However, there are still slight differences between batched and single-batch exports.
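The combination step discussed in this thread (concatenate cell-specific variables along the cell dimension, take global variables from the last batch) can be sketched as follows. This is not the adapted function from the issue; combine_batches, global_rate and the array shapes are hypothetical.

```python
import numpy as np

def combine_batches(batch_samples, cell_specific, cell_axis=1):
    """Concatenate local variables over cells; take globals from the last batch."""
    combined = {}
    for name in batch_samples[-1]:
        if name in cell_specific:
            # Local variable: stitch the per-batch pieces together over cells.
            combined[name] = np.concatenate(
                [batch[name] for batch in batch_samples], axis=cell_axis
            )
        else:
            # Global variable: use one set of samples, no averaging.
            combined[name] = batch_samples[-1][name]
    return combined

# Two hypothetical batches with 32 and 18 cells, 5 posterior samples each.
rng = np.random.default_rng(0)
batches = [
    {"t_c": rng.normal(size=(5, n, 1)), "global_rate": rng.normal(size=(5, 1, 7))}
    for n in (32, 18)
]

posterior = combine_batches(batches, cell_specific={"t_c"})
print(posterior["t_c"].shape)
print(posterior["global_rate"].shape)
```

Because the last batch is usually smaller than the others, the set of cell-specific names has to be determined from a full-size batch (for example the first one) before this step.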
While both local and global variables are defined in cell2fate, the problem lies in the posterior sampling process, which disregards the return_sites information coming from _get_obs_plate_sites, causing all variables to be returned twice, as you mentioned. So it is not that the variables are the same, but that the code does not use that information.
Are the deterministic variables wrapped in a 'with obs_plate' block in the model? They should be; otherwise they will be ignored by plate detection.
I have some code to solve these issues once and for all. It also handles multiple plates and arbitrary orders of items within plates. My code works for direct posterior median computation - but not for sampling. It needs some minor modifications and testing. Would you be able to help?
No, deterministic variables are not wrapped in a 'with obs_plate' block in cell2fate. See this for example:
Yes, I would like to help with the code. I believe it would be great if there is a general solution for this.
Thanks for your insights on this! So for a start I will wrap all cell specific deterministic variables in a plate, so we can detect them as local variables automatically.
Next we should definitely implement Vitalii's solution to do direct posterior median computation.
Finally, I think Sezgin's solution is probably still OK for the moment, because the only time we use something other than the posterior mean is for the time variable "T_c", where we are also interested in the variance, and that should be sampled accurately given that it is a local variable.
I have now wrapped the cell-specific deterministic variables in plates, but I think this has not solved it, because for some reason "data_target" is not detected as a local variable. So I think it would be great if @sezginerr adds his approximate solution for now, and then we come up with something more general next.
data_target is an observed variable, right? Not much reason to track it.
Here is my code for computing the quantiles and medians. It should work for this project with no modifications, except for dataloader options.
This code needs to be modified with sampling statements instead of median/quantiles, including correctly handling the sample dimension. I would really appreciate having the posterior sampling solution for the cell2state model project in addition to this problem.
Hello,
I have been trying to export posteriors batch by batch with a defined batch size, rather than exporting them in a single batch. In the scvi library, this is achieved by concatenating the local variables exported from each batch along the cell dimension. I believe the scvi library assumes that the variables defined with 'obs_plate' are the local variables, while the global variables (any other variable) are sampled from the last batch. However, for the cell2fate model, local and global variables seem to be the same, as the guide method in the cell2fate model is an instance of 'poutine.messenger.Messenger'. In this case, posterior sampling does not use the 'return_sites' information and seems to return everything. Here is the function that returns one posterior sample in the scvi library:
To adapt batch sampling in the cell2fate model, I made a few changes to the function. Firstly, I removed the global sampling, which would sample everything in the last batch again. I'm not sure if it's correct to do so, as the RF/GO analysis slightly changes when I use single batch sampling with or without global variable sampling. Secondly, I defined the variables that have cell number dimensions and concatenated them in the cell number dimension after sampling. For the other variables, I simply applied element-wise averaging between batches. Here is the adapted function that I use: