cellarium-ai / cellarium-ml

Distributed single-cell data analysis.
BSD 3-Clause "New" or "Revised" License

YAML improvement - do not specify model if using a checkpoint #260

Open sjfleming opened 4 weeks ago

sjfleming commented 4 weeks ago

@ordabayevy we have been noticing that (we think) we need to specify the full model architecture in the YAML config file (and get it exactly right) even when we are using ckpt_path. First of all, is that correct?

If it is correct, it is very tough for users and quite error-prone.

Can we figure out a way around this? Ideally, we would specify only the ckpt_path and be able to, for example, run a predict step without having to specify anything at all about the model (since it comes from the checkpoint).

ordabayevy commented 4 weeks ago

That's correct. The CLI uses the config file to instantiate the module and datamodule, and then uses the checkpoint to load the state. Yes, it is kind of annoying, although I don't know why PyTorch Lightning designed it this way. (Maybe because it mirrors how you would do it in Python: define the module and datamodule first and then pass them to the trainer, as in trainer.fit(module, datamodule=datamodule, ckpt_path=ckpt_path).)

Note that CellariumModule.load_from_checkpoint works differently: it instantiates objects from the hyperparameters stored in the checkpoint and then loads the state.
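To make the difference between the two flows concrete, here is a toy sketch in plain Python. ToyModule, save_checkpoint, cli_style_load and from_checkpoint_load are hypothetical names; this is not Lightning or cellarium-ml code. The CLI-style path builds the module from the user's config and only then restores the weights, so a config that disagrees with the checkpoint fails at load time. The load_from_checkpoint-style path rebuilds the module from the hyperparameters saved next to the weights, so no config is needed.

```python
import copy


class ToyModule:
    """Stand-in for a LightningModule whose weight shape depends on a hyperparameter."""

    def __init__(self, n_components: int):
        self.hparams = {"n_components": n_components}
        self.state = [0.0] * n_components  # "weights" sized by the config

    def state_dict(self) -> list:
        return list(self.state)

    def load_state_dict(self, state: list) -> None:
        # Mimics torch rejecting a checkpoint whose tensors have the wrong size.
        if len(state) != len(self.state):
            raise RuntimeError(
                f"size mismatch: checkpoint has {len(state)}, module expects {len(self.state)}"
            )
        self.state = list(state)


def save_checkpoint(module: ToyModule) -> dict:
    # Like a Lightning checkpoint, keep the hyperparameters next to the weights.
    return {"state_dict": module.state_dict(), "hyper_parameters": copy.deepcopy(module.hparams)}


def cli_style_load(config: dict, checkpoint: dict) -> ToyModule:
    # CLI flow: instantiate from the user's YAML config, then restore the weights.
    module = ToyModule(**config["model"])
    module.load_state_dict(checkpoint["state_dict"])
    return module


def from_checkpoint_load(checkpoint: dict) -> ToyModule:
    # load_from_checkpoint flow: instantiate from the stored hyperparameters instead.
    module = ToyModule(**checkpoint["hyper_parameters"])
    module.load_state_dict(checkpoint["state_dict"])
    return module
```

In the toy, a wrong n_components in the config raises a size-mismatch error. With a real model, a wrong hyperparameter that does not change any tensor shape could load silently and still behave differently, which is part of why hand-writing the architecture is error-prone.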

It could potentially be done, since we store all the hyperparameters in the checkpoint. I'll think about how to accomplish this and how many code changes it would require.
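One possible direction, sketched under several assumptions: before the CLI instantiates anything, fill in the config's model section from the hyperparameters saved in the checkpoint. fill_model_config is a hypothetical helper. It assumes the checkpoint has already been loaded into a dict (for example via torch.load), that the hyperparameters sit under Lightning's "hyper_parameters" key, and that they map directly onto the model's init arguments, which may not match how CellariumModule actually saves them. Top-level keys the user does set under model are kept. Where it would be called from (possibly a LightningCLI subclass's before_instantiate_classes hook) is left open.

```python
import copy
from typing import Any


def fill_model_config(config: dict[str, Any], checkpoint: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of `config` whose `model` section is taken from the checkpoint.

    Hypothetical helper: top-level keys the user specified under `model`
    override the stored hyperparameters; everything else comes from the checkpoint.
    """
    stored = checkpoint.get("hyper_parameters")
    if stored is None:
        raise KeyError("checkpoint has no 'hyper_parameters'; cannot infer the model config")
    filled = copy.deepcopy(config)  # never mutate the user's config in place
    model_cfg = copy.deepcopy(stored)
    model_cfg.update(filled.get("model") or {})  # user-specified keys win
    filled["model"] = model_cfg
    return filled
```

With this, a config containing only trainer settings and ckpt_path would still produce a complete model section.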

sjfleming commented 3 weeks ago

Awesome, thanks! Yeah, there's no rush, but I think it could be a huge usability improvement. I'll think about it too.

sjfleming commented 2 weeks ago

Hmm, I did notice something ...

You can use a checkpointed model as a "transform" without specifying anything about the model configuration. Is there a way to leverage this? Maybe not.

sjfleming commented 2 weeks ago

The LightningCLI docstring says this about the __init__ argument model_class:

    model_class: An optional :class:`~lightning.pytorch.core.LightningModule` class to train on or a
        callable which returns a :class:`~lightning.pytorch.core.LightningModule` instance when
        called. If ``None``, you can pass a registered model with ``--model=MyModel``.

I wonder... could we pass in a callable that returns an instance, i.e. something like

from cellarium.ml import CellariumModule

def my_callable():
    return CellariumModule.load_from_checkpoint(ckpt_path)  # ckpt_path: path to an existing .ckpt file

?
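Building on that idea, a hypothetical variant takes the module class and returns a callable whose only argument is the checkpoint path, so the path could come from the config rather than being hard-coded. make_checkpoint_loader and checkpoint_path are illustrative names, not cellarium-ml API. The sketch also sets a concrete return annotation on the assumption that LightningCLI (via jsonargparse) inspects the callable's signature and return type to build the model arguments; that assumption is untested here. It only requires the class to have a load_from_checkpoint classmethod, so it runs without cellarium installed.

```python
from typing import Any, Callable


def make_checkpoint_loader(module_cls: type) -> Callable[[str], Any]:
    """Return a callable that builds `module_cls` purely from a checkpoint file.

    Hypothetical helper: `module_cls` only needs a `load_from_checkpoint`
    classmethod, e.g. CellariumModule.
    """

    def load_model(checkpoint_path: str):
        # Architecture and hyperparameters come from the checkpoint, not the YAML.
        return module_cls.load_from_checkpoint(checkpoint_path)

    # Give the callable a concrete return annotation, in case the CLI parser
    # needs to know which class the callable produces (assumption).
    load_model.__annotations__["return"] = module_cls
    return load_model
```

If that assumption holds, it would be wired up as LightningCLI(model_class=make_checkpoint_loader(CellariumModule), ...), and the YAML model section would shrink to a single checkpoint_path entry. The predict subcommand's own ckpt_path would then reload the same weights a second time, which is redundant but should be harmless. How this interacts with cellarium-ml's own CLI setup has not been checked.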