Open sjfleming opened 4 weeks ago
That's correct. The CLI uses the config file to instantiate the module and datamodule, and then uses the checkpoint to load the state. Yes, it is kind of annoying, and I don't know why PyTorch Lightning designed it this way. (Maybe because it follows how you would do it in Python: define the module and datamodule first and then pass them to the trainer, i.e. trainer.fit(module, datamodule=datamodule, ckpt_path=ckpt_path).)
Note that CellariumModule.load_from_checkpoint works differently: it instantiates objects from the hyperparameters stored in the checkpoint and then loads the state.
It can potentially be done, since we store all the hyperparameters in the checkpoint. I'll think about how this can be accomplished and how many code changes it requires.
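To make the difference between the two loading paths concrete, here is a toy sketch in plain Python. It does not use Lightning or Cellarium at all: ToyModel, load_cli_style, and load_from_checkpoint_style are made-up names for illustration, and the only detail borrowed from Lightning is that a checkpoint can carry both hyper_parameters and state_dict entries.

```python
from dataclasses import dataclass, field


@dataclass
class ToyModel:
    """Hypothetical stand-in for a module whose shape depends on its hyperparameters."""

    n_hidden: int
    weights: list = field(default_factory=list)

    def load_state(self, state):
        # The stored state only fits if the architecture matches.
        if len(state["weights"]) != self.n_hidden:
            raise ValueError("state does not match the model architecture")
        self.weights = list(state["weights"])


# Toy checkpoint: hyperparameters and state saved side by side.
checkpoint = {
    "hyper_parameters": {"n_hidden": 3},
    "state_dict": {"weights": [0.1, 0.2, 0.3]},
}


def load_cli_style(config, ckpt):
    # CLI-style: the architecture comes from the user's config,
    # the checkpoint is only used for the state.
    model = ToyModel(**config)
    model.load_state(ckpt["state_dict"])
    return model


def load_from_checkpoint_style(ckpt):
    # load_from_checkpoint-style: the architecture is rebuilt from the
    # hyperparameters stored in the checkpoint, then the state is loaded.
    model = ToyModel(**ckpt["hyper_parameters"])
    model.load_state(ckpt["state_dict"])
    return model
```

With a matching config both paths produce the same model. If the config disagrees with the checkpoint (say n_hidden is 4), the CLI-style path fails when the state is loaded, which is the error-prone part discussed here.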
Awesome thanks! Yeah there's no rush, but I think it could be a huge usability improvement. I'll think about it too.
Hmm, I did notice something ...
You can use a checkpointed model as a "transform" without specifying anything about the model configuration ... is there a way to leverage this? Maybe not
LightningCLI says in its docstring on the __init__ argument model_class:
model_class: An optional :class:`~lightning.pytorch.core.LightningModule` class to train on or a
callable which returns a :class:`~lightning.pytorch.core.LightningModule` instance when
called. If ``None``, you can pass a registered model with ``--model=MyModel``.
I wonder... can we pass in a callable which returns an instance... i.e. something like
    def my_callable():
        return CellariumModule.load_from_checkpoint(ckpt_path)
?
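For what it's worth, here is a rough sketch of what such a callable could look like with the checkpoint path bound up front. I have not checked whether LightningCLI actually accepts this alongside the rest of its config handling, so the loader below is only a stand-in: fake_load_from_checkpoint is a hypothetical placeholder for CellariumModule.load_from_checkpoint, and the path is made up.

```python
from functools import partial


def fake_load_from_checkpoint(ckpt_path):
    # Hypothetical placeholder for CellariumModule.load_from_checkpoint;
    # it just records which path it was asked to load.
    return {"loaded_from": ckpt_path}


def make_model_factory(ckpt_path):
    # Bind the path now so the result is a zero-argument callable,
    # which is the shape the model_class docstring describes.
    return partial(fake_load_from_checkpoint, ckpt_path)


model_factory = make_model_factory("path/to/model.ckpt")  # made-up path
model = model_factory()  # called later with no arguments
```

The open question is still where ckpt_path would come from, since the callable has to be built before the CLI parses its arguments.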
@ordabayevy we have been noticing that (we think) we need to specify the full model architecture in the YAML config file (and get it correct) even if we are using ckpt_path. First of all, is that correct? If it is correct, it is very tough for users and quite error-prone.
Can we figure out a way around this? The ideal would be to specify only the ckpt_path and be able to do something like run a predict step, without having to specify anything at all about the model (since it's coming from the checkpoint).