cellarium-ai / cellarium-ml

Distributed single-cell data analysis.
BSD 3-Clause "New" or "Revised" License

Implement "on_predict_epoch_end" in CellariumModule #226

Open sjfleming opened 3 months ago

sjfleming commented 3 months ago

We have a use case for the "on_predict_epoch_end" callback (https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html#on-predict-epoch-end) in NMF. We would like to save the learned gene factor matrix as output. This is a single matrix with shape [programs, genes], not a per-cell output, so it does not fit what a PredictionWriter handles.

We would like to add

    def on_predict_epoch_end(self) -> None:
        """
        Calls the ``on_predict_epoch_end`` method on the :attr:`model` attribute.
        If the :attr:`model` attribute has ``on_predict_epoch_end`` method defined, then
        ``on_predict_epoch_end`` must be called at the end of every epoch.
        """
        on_predict_epoch_end = getattr(self.model, "on_predict_epoch_end", None)
        if callable(on_predict_epoch_end):
            on_predict_epoch_end(self.trainer)

to CellariumModule.

Does this sound reasonable, @ordabayevy?
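For anyone who wants to try the delegation pattern outside of Lightning, here is a minimal sketch. `ModuleSketch`, `ModelWithHook`, and `ModelWithoutHook` are hypothetical stand-ins, not the real CellariumModule or CellariumModel classes; only the `getattr` / `callable` logic mirrors the snippet above.

```python
from typing import Any, Optional


class ModelWithHook:
    """Hypothetical stand-in for a model that defines the hook."""

    def __init__(self) -> None:
        self.calls = 0

    def on_predict_epoch_end(self, trainer: Any) -> None:
        # A real model would save or log something here.
        self.calls += 1


class ModelWithoutHook:
    """Hypothetical stand-in for a model that does not define the hook."""


class ModuleSketch:
    """Hypothetical stand-in for CellariumModule (not a LightningModule)."""

    def __init__(self, model: Any, trainer: Optional[Any] = None) -> None:
        self.model = model
        self.trainer = trainer

    def on_predict_epoch_end(self) -> None:
        # Forward to the model only if it actually defines the hook.
        hook = getattr(self.model, "on_predict_epoch_end", None)
        if callable(hook):
            hook(self.trainer)


with_hook = ModelWithHook()
ModuleSketch(with_hook).on_predict_epoch_end()
# A model without the hook is silently skipped rather than raising.
ModuleSketch(ModelWithoutHook()).on_predict_epoch_end()
```

The point of the `getattr` default is that models which do not care about this hook need no changes.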

ordabayevy commented 3 months ago

I don't know how NMF works, so I'm trying to understand. Is this called at the end of trainer.predict, using an already trained model?

sjfleming commented 3 months ago

As we discussed in the meantime, we were thinking of trying to do something with "predict" that's not really prediction in the usual sense.

Maybe we will create a new subcommand to do what we want instead. I'm not sure yet.

(But... I think people might still want access to "on_predict_epoch_end" in the future, even if we don't end up using it here. It would be a way to save anything additional that is computed by "predict", even if it does not follow the model of what's expected by a PredictionWriter where you're saving the output of the "predict" function per batch.)
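To make the "save something that is not per-batch" idea concrete, here is a rough sketch of what the model side might look like. `NMFSketch`, its `factors` attribute, and `output_path` are hypothetical names for illustration, not the actual NMF model in cellarium-ml, and the matrix is a plain nested list rather than a tensor so the example runs with the standard library alone.

```python
import csv
import os
import tempfile
from typing import Any, List


class NMFSketch:
    """Hypothetical stand-in for an NMF model; not the cellarium-ml class."""

    def __init__(self, factors: List[List[float]], output_path: str) -> None:
        self.factors = factors  # assumed shape: [programs, genes]
        self.output_path = output_path

    def on_predict_epoch_end(self, trainer: Any) -> None:
        # Write the whole matrix once, one row per program.
        with open(self.output_path, "w", newline="") as f:
            csv.writer(f).writerows(self.factors)


output_path = os.path.join(tempfile.mkdtemp(), "gene_factors.csv")
model = NMFSketch(
    factors=[[0.1, 0.9, 0.0], [0.7, 0.0, 0.3]],  # 2 programs x 3 genes
    output_path=output_path,
)
model.on_predict_epoch_end(trainer=None)

# Read the file back to check what was saved.
with open(output_path, newline="") as f:
    saved = [[float(value) for value in row] for row in csv.reader(f)]
```

In a multi-process run, one would probably want to guard the write so only one rank does it (for example by checking `trainer.is_global_zero`); that is left out here because the stand-in trainer is `None`.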

ordabayevy commented 3 months ago

> But... I think people might still want access to "on_predict_epoch_end" in the future, even if we don't end up using it here.

Yes, I think it is fine to add it. I will also mention here that on_epoch_end and on_batch_end hooks in CellariumModels should be renamed to on_train_epoch_end and on_train_batch_end, respectively. This will help to avoid any ambiguity.
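If a transition period were wanted for the rename (nothing in this thread says it is), one possible approach is to look up the new name first and fall back to the old name with a warning. `resolve_hook`, `LegacyModel`, and `RenamedModel` below are hypothetical helpers for illustration and do not exist in cellarium-ml.

```python
import warnings
from typing import Any, Callable, Optional


def resolve_hook(model: Any, new_name: str, old_name: str) -> Optional[Callable[..., Any]]:
    """Return the hook under its new name, or the legacy one with a warning."""
    hook = getattr(model, new_name, None)
    if callable(hook):
        return hook
    legacy = getattr(model, old_name, None)
    if callable(legacy):
        warnings.warn(
            f"{old_name} is deprecated; rename it to {new_name}.",
            DeprecationWarning,
            stacklevel=2,
        )
        return legacy
    return None


class LegacyModel:
    """Hypothetical model still using the old hook name."""

    def on_epoch_end(self, trainer: Any) -> str:
        return "legacy"


class RenamedModel:
    """Hypothetical model already using the new hook name."""

    def on_train_epoch_end(self, trainer: Any) -> str:
        return "renamed"


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy_hook = resolve_hook(LegacyModel(), "on_train_epoch_end", "on_epoch_end")
    renamed_hook = resolve_hook(RenamedModel(), "on_train_epoch_end", "on_epoch_end")
```

Only the legacy lookup emits a warning, so models that have already been renamed are unaffected. A plain rename without any fallback would be simpler if there are no external users of the old names.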

sjfleming commented 3 months ago

Oh good point. I very much agree with you!

I was wondering if there was some slick way to implement all of these lightning callbacks in CellariumModule without a ton of copy and paste, so that they're all just redirects to the CellariumModel instance's methods...

sjfleming commented 3 months ago

Kind of like this, but maybe there is something a little neater: https://stackoverflow.com/questions/58232546/override-all-class-methods-at-once-to-do-the-same-thing
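One way this could look, as a sketch rather than a proposal for the actual implementation: build each forwarding method with a small factory and attach it to the class in a loop. `ModuleSketch`, `RecordingModel`, `FORWARDED_HOOKS`, and `_make_forwarder` are hypothetical names, and the sketch only covers hooks that take no extra arguments.

```python
from typing import Any, Callable, List, Optional

# Hook names to forward; this list is an assumption for illustration.
FORWARDED_HOOKS = ("on_train_epoch_end", "on_predict_epoch_end")


def _make_forwarder(name: str) -> Callable[[Any], None]:
    """Build a method that forwards ``name`` to ``self.model`` if defined."""

    def forwarder(self: Any) -> None:
        hook = getattr(self.model, name, None)
        if callable(hook):
            hook(self.trainer)

    forwarder.__name__ = name
    forwarder.__doc__ = f"Calls ``{name}`` on the model attribute if it is defined."
    return forwarder


class ModuleSketch:
    """Hypothetical stand-in for CellariumModule (not a LightningModule)."""

    def __init__(self, model: Any, trainer: Optional[Any] = None) -> None:
        self.model = model
        self.trainer = trainer


# Attach one forwarding method per hook name.
for _name in FORWARDED_HOOKS:
    setattr(ModuleSketch, _name, _make_forwarder(_name))


class RecordingModel:
    """Hypothetical model that defines only one of the hooks."""

    def __init__(self) -> None:
        self.seen: List[str] = []

    def on_predict_epoch_end(self, trainer: Any) -> None:
        self.seen.append("on_predict_epoch_end")


recorder = RecordingModel()
module = ModuleSketch(recorder)
module.on_predict_epoch_end()  # forwarded to the model
module.on_train_epoch_end()  # model lacks this hook, so nothing happens
```

The factory function matters: defining `forwarder` directly inside the `for` loop would make every method close over the same loop variable and forward to the last hook name. Batch-level hooks take extra arguments, so they would need a variant that passes `*args` and `**kwargs` through. The trade-off is that generated methods are harder to find with grep and less friendly to type checkers than the copy-and-paste version.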