cellarium-ai / cellarium-ml

Distributed single-cell data analysis.
BSD 3-Clause "New" or "Revised" License

scvi - enable batch embeddings, sampling, and freezing of batch bias computations #169

Closed. sjfleming closed this pull request 4 months ago.

sjfleming commented 4 months ago

Closes #167

This is a rather large refactor that changes how batch-bias values are computed. It is closer to the way scvi-tools does things. However, instead of concatenating a one-hot batch label directly to a layer's input data, this implementation uses a separate neural network (which can be, but need not be, linear) to compute the batch biases. This makes things much more flexible. For instance, the batch input can be one-hot or some other representation of batch (for example, a sample drawn from a batch embedding space). It also lets us easily freeze the network that computes the batch-bias values, if we don't want to train it but would rather reuse values from a previous checkpoint.

sjfleming commented 4 months ago

This does not yet address #144, but hopefully it will make that easier to do.

sjfleming commented 4 months ago

Currently, there is a new notion of a "batch embedding", which can be one-hot or learned. If it is learned, it can be a point estimate, or it can be sampled from a posterior distribution. If it is sampled, you can optionally add a KL divergence term to the loss function.