LarsKue / lightning-trainable

A default trainable module for pytorch lightning.
MIT License

load_from_checkpoint loads hparams as dict instead of class HParams #18

Closed by thelostscout 11 months ago

thelostscout commented 11 months ago
import lightning_trainable as lt

class MyLTModel(lt.Trainable):
    hparams: lt.HParams

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        print(type(hparams))

model = MyLTModel.load_from_checkpoint("path/to/checkpoint.ckpt")

Expected output: <class lt.HParams>
Actual output: dict

LarsKue commented 11 months ago

This is not a bug; it is just how lightning handles hparams saved via self.save_hyperparameters().
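
As a quick way to see what actually gets stored (a hedged illustration, not the library's internals), you can inspect the checkpoint file directly; hyperparameters saved via self.save_hyperparameters() are typically kept under the "hyper_parameters" key as a plain mapping:

import torch

ckpt = torch.load("path/to/checkpoint.ckpt", map_location="cpu")
# Only the key/value pairs survive serialization, not the lt.HParams class,
# so load_from_checkpoint can only pass a plain dict-like object to __init__.
print(type(ckpt["hyper_parameters"]))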

thelostscout commented 11 months ago

What would be the suggested way to handle the hparams in this case? Override the load_from_checkpoint function?

LarsKue commented 11 months ago

I am going to answer this question with another question: Why do you need hparams to be an instance of the HParams class? When you use HParams, you generally want to answer 3[^1] questions:

  1. Are all required parameters specified?
  2. Are all parameters of a valid type?
  3. Are all parameters within valid value ranges?

The purpose of the HParams class is to check these at instantiation time, i.e. before you train your module. This avoids problems like finding out that you forgot to specify some parameter only after your model has already trained for a day of a planned two weeks on a cluster.

After instantiation, HParams' job is done, and it can serve simply as a container with no additional functionality. When you load from a checkpoint, you are already guaranteed that the hparams it contains were valid when you originally trained, so there is no need to answer the above questions again.
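
To make that concrete, here is a rough sketch of the kind of check HParams does at instantiation time; the subclass name, the parameter names, and the annotation-with-default declaration style are illustrative assumptions, not taken from this issue:

import lightning_trainable as lt

class MyHParams(lt.HParams):
    hidden_size: int             # required, no default
    learning_rate: float = 1e-3  # optional, has a default

# Assuming HParams validates on construction, this fails immediately,
# long before any training starts, because `hidden_size` is missing.
hparams = MyHParams(learning_rate=1e-2)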

[^1]: Technically, there is also a fourth one: "Are there no extra parameters specified?" This prevents you from leaving unused parameters behind when refactoring, or from accidentally invoking the wrong class with all defaults.

thelostscout commented 11 months ago

Well, it breaks at least the use of the AttributeDict features (not too big of a problem, but inconvenient) and maybe other custom features included in the hparams class (not sure if this is intended anyway, though).
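
For example (illustrative key name), attribute-style access is exactly what breaks once the restored hparams are a plain dict:

hparams = {"learning_rate": 1e-3}  # what load_from_checkpoint hands back

hparams["learning_rate"]  # works on a plain dict
hparams.learning_rate     # AttributeError: 'dict' object has no attribute 'learning_rate'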

fdraxler commented 11 months ago

Another solution might be to do the following:

class MyLTModel(lt.Trainable):
    hparams: lt.HParams

    def __init__(self, hparams, **kwargs):
        if not isinstance(hparams, lt.HParams):
            hparams = lt.HParams(**hparams)
        super().__init__(hparams, **kwargs)
        print(type(hparams))

thelostscout commented 11 months ago

> Well, it breaks at least the use of the AttributeDict features (not too big of a problem, but inconvenient) and maybe other custom features included in the hparams class (not sure if this is intended anyway, though).

Turns out that if you use self.hparams instead of hparams in your __init__, everything is fine :man_facepalming:
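
In code, that resolution looks roughly like this (a sketch, assuming the parent Trainable.__init__ converts the incoming dict and stores the result as self.hparams):

import lightning_trainable as lt

class MyLTModel(lt.Trainable):
    hparams: lt.HParams

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        # The local `hparams` argument may still be the raw dict restored
        # from the checkpoint, but `self.hparams`, as set up by the parent
        # __init__, is the HParams instance, so use that from here on.
        print(type(self.hparams))

model = MyLTModel.load_from_checkpoint("path/to/checkpoint.ckpt")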