sjfleming opened 5 months ago
Add the appropriate option to the config file to specify whether we are learning batch bias vectors during training, or whether we are providing them and holding them fixed.
(Can we infer this based on the presence of a checkpointed model we are using to initialize the batch bias vectors?)
See this example of how to use a checkpoint in a config file: https://github.com/cellarium-ai/cellarium-ml/blob/module-tutorial/configs/lr_config.yaml#L47-L52
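For illustration, a config fragment in the style of the linked example might look something like this (the key names `pretrained_batch_bias_ckpt` and `fix_batch_effect` are hypothetical, not the actual schema):

```yaml
# Hypothetical sketch -- key names are illustrative, not the real config schema.
model:
  model:
    init_args:
      # checkpoint from a previously-run scVI model supplying the batch bias vectors
      pretrained_batch_bias_ckpt: lightning_logs/version_0/checkpoints/last.ckpt
      # hold the batch bias vectors fixed rather than learning them
      fix_batch_effect: true
```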
This will be ready to be worked on only after #156 gets merged, which will be soon.
@sjfleming Just to make sure, this would be the affected part: https://github.com/cellarium-ai/cellarium-ml/blob/modified_scvi/cellarium/ml/models/common/nn.py#L29C1-L47C34
Right? Instead of computing them, use the ones given by the checkpoint.
Yeah, I think `LinearWithBatch` might be a good place to intervene to make this happen.

There are many ways, but maybe we can come up with a simple one... what about something like: add a flag to `LinearWithBatch.__init__()`, like maybe `fix_batch_effect: bool` or something, and then use `self.fix_batch_effect` in, maybe, `self.compute_bias()`? It could be something like: return `self.cached_biases` if `self.fix_batch_effect` is false, and if it's true then return `self.cached_biases.detach()`?

I think we can let the checkpoint loader be responsible for making sure that `self.bias_mean_layer` for `LinearWithBatch` gets loaded correctly, so it should already have the pre-trained batch biases. I think the idea here is just to be able to turn off further training of the batch biases.
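If it helps, here is a rough sketch of that idea. This is not the real `LinearWithBatch` from `cellarium/ml/models/common/nn.py` (which is more involved); the `fix_batch_effect` flag and the `nn.Embedding` bias table are just illustrative:

```python
import torch
import torch.nn as nn


class LinearWithBatch(nn.Module):
    """Simplified sketch: a linear layer plus a learnable per-batch bias.

    Hypothetical ``fix_batch_effect`` flag: when True, the batch biases are
    detached from the autograd graph, so gradients never update them.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        n_batch: int,
        fix_batch_effect: bool = False,
    ):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # one bias vector per batch; the checkpoint loader can overwrite these
        self.bias_mean_layer = nn.Embedding(n_batch, out_features)
        self.fix_batch_effect = fix_batch_effect

    def compute_bias(self, batch_index: torch.Tensor) -> torch.Tensor:
        bias = self.bias_mean_layer(batch_index)
        if self.fix_batch_effect:
            # stop gradients: pre-trained batch biases stay fixed during training
            return bias.detach()
        return bias

    def forward(self, x: torch.Tensor, batch_index: torch.Tensor) -> torch.Tensor:
        return self.linear(x) + self.compute_bias(batch_index)
```

With `fix_batch_effect=True`, a backward pass leaves `bias_mean_layer.weight.grad` as `None` while the linear weights still train, which is exactly the "turn off further training of the batch biases" behavior.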
We want the ability to initialize the model with pre-computed batch bias vectors and use those (without updating them) throughout training.
We should set things up so that those batch bias vectors can be pulled from a checkpoint of a previously-run scVI model. We should have checks in place to make sure that the checkpointed run is "compatible" with this run (i.e. same batches are present).
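One possible shape for that compatibility check, assuming (hypothetically) that the previous run's checkpoint records its batch labels under a `batch_labels` key; the real checkpoint layout may differ:

```python
import torch


def check_batch_compatibility(ckpt_path: str, current_batches: list[str]) -> None:
    """Sketch: verify that the batches recorded in a previous scVI run's
    checkpoint match the batches present in the current run.

    Assumes the checkpoint stores its batch labels under a ``batch_labels``
    key; this key is illustrative, not the actual checkpoint schema.
    """
    ckpt = torch.load(ckpt_path, map_location="cpu")
    ckpt_batches = ckpt.get("batch_labels")
    if ckpt_batches is None:
        raise ValueError(f"{ckpt_path} does not record batch labels")
    if set(ckpt_batches) != set(current_batches):
        missing = set(current_batches) - set(ckpt_batches)
        extra = set(ckpt_batches) - set(current_batches)
        raise ValueError(
            "Checkpointed run is incompatible with this run: "
            f"batches missing from checkpoint={sorted(missing)}, "
            f"extra batches in checkpoint={sorted(extra)}"
        )
```

The check could run right after the checkpoint is loaded, before the batch bias vectors are copied into the model.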