cellarium-ai / cellarium-ml

Distributed single-cell data analysis.
BSD 3-Clause "New" or "Revised" License

scvi: ensure categorical covariates and continuous covariates work #205

Open sjfleming opened 4 months ago


We have never tested this, but I think we need some code changes to make scvi work with categorical covariates and continuous covariates. What's going on, in my understanding, is this:

In the scvi model we have a batch latent variable $s_n$.

Functionally, $s_n$ is an input that both the encoder and the decoder are conditioned on.

In the simple case, $s_n$ is a one-hot encoding of a single categorical "batch" variable. We have this implemented.

In the more complex case, the scvi codebase allows you to use additional categorical (or continuous) covariates, and I think the published CZI model does so. They all become part of $s_n$ and are used to condition the encoder and decoder, so $s_n$ is no longer necessarily one-hot.

For example, let's say we have 2 datasets ("dataset_id") and 2 suspension types ("suspension_type"), like this:

cell         dataset_id           suspension_type
0            A                    whole_cell
1            A                    nucleus
2            B                    nucleus
3            B                    whole_cell
4            A                    whole_cell
5            B                    nucleus
...

Then we could set up a model with `batch="dataset_id"` and `categorical_covariates=["suspension_type"]`, and the $s_n$ values might look like this:

cell       s_nb
0          [1, 0, 1, 0]
1          [1, 0, 0, 1]
2          [0, 1, 0, 1]
3          [0, 1, 1, 0]
4          [1, 0, 1, 0]
5          [0, 1, 0, 1]
...

where the first two entries are one-hot for "dataset_id" and the last two entries are one-hot for "suspension_type".
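
To make the construction concrete, here is a minimal pure-Python sketch that reproduces the table above by concatenating one one-hot block per categorical column. The helper names `one_hot` and `build_s_n` are made up for illustration and are not part of cellarium-ml or the scvi codebase; the category ordering (A before B, whole_cell before nucleus) is an assumption chosen to match the table.

```python
# Hypothetical helpers for illustration only; not cellarium-ml or scvi API.

def one_hot(values, categories):
    """Return one one-hot row per value, with columns ordered as in `categories`."""
    index = {cat: i for i, cat in enumerate(categories)}
    rows = []
    for value in values:
        row = [0] * len(categories)
        row[index[value]] = 1
        rows.append(row)
    return rows


def build_s_n(columns):
    """Concatenate one-hot blocks cell by cell.

    `columns` is a list of (values, categories) pairs, with batch first
    and any categorical covariates after it.
    """
    blocks = [one_hot(values, categories) for values, categories in columns]
    # zip(*blocks) yields, for each cell, its row from every block;
    # sum(..., []) joins those rows into a single flat vector.
    return [sum(rows, []) for rows in zip(*blocks)]


# The six cells from the example table.
dataset_id = ["A", "A", "B", "B", "A", "B"]
suspension_type = [
    "whole_cell", "nucleus", "nucleus", "whole_cell", "whole_cell", "nucleus",
]

s_n = build_s_n([
    (dataset_id, ["A", "B"]),                      # batch
    (suspension_type, ["whole_cell", "nucleus"]),  # categorical covariate
])

for cell, row in enumerate(s_n):
    print(cell, row)
```

A continuous covariate would presumably contribute one extra real-valued column per cell instead of a one-hot block, which is another reason the combined vector is no longer one-hot.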

The open question is how to handle the case where we want batch to be embedded rather than one-hot encoded. Would we embed only the "batch" (i.e. "dataset_id") and leave the rest one-hot? Would we embed the whole s_nb vector jointly? Or would we embed each part separately, with one embedding for "dataset_id" and a separate one for "suspension_type"?
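
As a starting point for discussion, the three options in the question above could look roughly like this in PyTorch. This is only a sketch under assumptions: the variable names, the embedding size of 3, and the integer coding of the categories are all made up here, and none of it reflects existing cellarium-ml code. In particular, "jointly" is read as a single bias-free linear map applied to the multi-hot vector, which is one possible interpretation among several.

```python
import torch
from torch import nn
from torch.nn import functional as F

# Assumed integer codes for the six example cells (A=0, B=1; whole_cell=0, nucleus=1).
n_datasets, n_suspension_types = 2, 2
dataset_idx = torch.tensor([0, 0, 1, 1, 0, 1])
suspension_idx = torch.tensor([0, 1, 1, 0, 0, 1])
embed_dim = 3  # arbitrary size chosen for the sketch

dataset_one_hot = F.one_hot(dataset_idx, n_datasets).float()
suspension_one_hot = F.one_hot(suspension_idx, n_suspension_types).float()

# Option 1: embed only the batch, keep the covariate one-hot.
batch_embedding = nn.Embedding(n_datasets, embed_dim)
s_option1 = torch.cat(
    [batch_embedding(dataset_idx), suspension_one_hot], dim=-1
)  # shape: (6, embed_dim + n_suspension_types)

# Option 2: embed the whole multi-hot vector jointly.
# A bias-free linear layer on a multi-hot input sums one learned vector
# per active category, so the output size does not grow with more covariates.
s_multi_hot = torch.cat([dataset_one_hot, suspension_one_hot], dim=-1)
joint_embedding = nn.Linear(n_datasets + n_suspension_types, embed_dim, bias=False)
s_option2 = joint_embedding(s_multi_hot)  # shape: (6, embed_dim)

# Option 3: a separate embedding per categorical variable, concatenated.
suspension_embedding = nn.Embedding(n_suspension_types, embed_dim)
s_option3 = torch.cat(
    [batch_embedding(dataset_idx), suspension_embedding(suspension_idx)], dim=-1
)  # shape: (6, 2 * embed_dim)

print(s_option1.shape, s_option2.shape, s_option3.shape)
```

The options differ in the width of the conditioning input the encoder and decoder would need to accept, so whichever is chosen would affect how those layers are sized.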