thelostscout closed this issue 11 months ago.
This is not a bug; this is just how Lightning handles hparams saved by `self.save_hyperparameters()`.
What would be the suggested way to handle the hparams in this case? Override the `load_from_checkpoint` function?
I am going to answer this question with another question: why do you need `hparams` to be an instance of the `HParams` class? When you use `HParams`, you generally want to answer three[^1] questions:
The purpose of the `HParams` class is to check these at instantiation time, i.e. before you train your module. This avoids problems such as discovering that you forgot to specify some parameter only after your model has already been training on a cluster for a day, when it was supposed to run for two weeks.
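To illustrate what checking at instantiation time buys you, here is a minimal, hypothetical validating container. It is not lightning_trainable's actual `HParams` implementation; the class names `HParams`/`MyHParams` and the specific checks (missing parameters, a simple `isinstance` type check, and the extra-parameter check mentioned in the footnote) are assumptions chosen for the sketch.

```python
from typing import Any, get_type_hints


class HParams:
    """Hypothetical validating container; NOT lightning_trainable's implementation.

    Subclasses declare parameters as annotated class attributes; an assigned
    class-level value acts as the default.
    """

    def __init__(self, **kwargs: Any) -> None:
        declared = get_type_hints(type(self))

        # Extra parameters: reject names the class does not declare.
        extra = set(kwargs) - set(declared)
        if extra:
            raise TypeError(f"unexpected parameters: {sorted(extra)}")

        for name, expected in declared.items():
            if name in kwargs:
                value = kwargs[name]
            elif hasattr(type(self), name):
                value = getattr(type(self), name)  # use the declared default
            else:
                # Missing parameters: no value given and no default declared.
                raise TypeError(f"missing required parameter: {name!r}")

            # Wrong types: plain isinstance check (only works for simple classes).
            if not isinstance(value, expected):
                raise TypeError(
                    f"{name!r} must be {expected.__name__}, got {type(value).__name__}"
                )
            setattr(self, name, value)


class MyHParams(HParams):
    lr: float
    epochs: int = 10
```

With this sketch, `MyHParams(lr=3e-4)` succeeds with `epochs` defaulting to 10, while `MyHParams()`, `MyHParams(lr="fast")` and `MyHParams(lr=3e-4, momentum=0.9)` all raise `TypeError` before any training starts. Once the object exists, it carries nothing beyond its attribute values.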
After instantiation, the `HParams` class's job is done, and the object can now serve simply as a container with no additional functionality. When you load from a checkpoint, you are already guaranteed that the hparams within it were valid when you initially trained, so there is no need to answer the above questions again.
[^1]: Technically, there is also a fourth one: "Are there no extra parameters specified?" This prevents you from leaving unused parameters in your code when refactoring, or from accidentally invoking the wrong class with all defaults.
Well, it breaks at least the use of `AttributeDict` features (not too big of a problem, but inconvenient), and maybe other custom features included in the hparams class (though I'm not sure whether those are intended anyway).
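To make the complaint concrete: a plain `dict` supports only key access, so code written as `hparams.lr` fails once the hparams come back from a checkpoint as a dict. The `AttrDict` below is a hypothetical minimal stand-in written for illustration, similar in spirit to Lightning's `AttributeDict` but not its actual code, and the parameter names are made up.

```python
from typing import Any


class AttrDict(dict):
    """Hypothetical minimal dict with attribute access (illustration only)."""

    def __getattr__(self, key: str) -> Any:
        # Only called when normal attribute lookup fails, so dict methods still work.
        try:
            return self[key]
        except KeyError as e:
            raise AttributeError(key) from e

    def __setattr__(self, key: str, value: Any) -> None:
        # Route attribute assignment into the dict itself.
        self[key] = value


# Hypothetical hparams as they might look after being restored as a plain dict.
saved = {"lr": 3e-4, "epochs": 10}

plain = dict(saved)
wrapped = AttrDict(saved)

print(wrapped.lr)            # attribute access works on the wrapper
print(hasattr(plain, "lr"))  # a plain dict has no attribute access
```

Wrapping the restored dict again (`AttrDict(plain)`) brings attribute access back, which is essentially what the workaround below does with the real `HParams` class.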
Another solution might be to do the following:
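```python
import lightning_trainable as lt


class MyLTModel(lt.Trainable):
    hparams: lt.HParams

    def __init__(self, hparams, **kwargs):
        # A checkpoint restores hparams as a plain dict, so convert it back
        # to an HParams instance before handing it to the base class.
        if not isinstance(hparams, lt.HParams):
            hparams = lt.HParams(**hparams)
        super().__init__(hparams, **kwargs)
        print(type(hparams))
```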
Turns out that if you use `self.hparams` instead of `hparams` in your `__init__`, everything is fine :man_facepalming:
Expected output: `<class lt.HParams>`

Actual output: `dict`
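Why `self.hparams` works while the local `hparams` does not can be sketched with framework-free stand-ins. Everything here is hypothetical: `HParams`, `Trainable` and `MyModel` only mimic the names in this thread, and the key assumption (consistent with the observation above, but not checked against lightning_trainable's source) is that the base class converts the incoming dict and stores the result on `self.hparams`, while the subclass's local `hparams` name keeps pointing at the original dict.

```python
from typing import Any, Mapping


class HParams(dict):
    """Hypothetical stand-in for lt.HParams: a dict with attribute access."""

    def __getattr__(self, key: str) -> Any:
        try:
            return self[key]
        except KeyError as e:
            raise AttributeError(key) from e


class Trainable:
    """Hypothetical stand-in for lt.Trainable.

    Assumption: the base class converts whatever it receives into HParams
    and stores the converted object on self.hparams.
    """

    def __init__(self, hparams: Mapping[str, Any]) -> None:
        if not isinstance(hparams, HParams):
            hparams = HParams(**hparams)
        self.hparams = hparams


class MyModel(Trainable):
    def __init__(self, hparams: Mapping[str, Any]) -> None:
        super().__init__(hparams)
        # The local name still refers to the object that was passed in ...
        self.local_type = type(hparams)
        # ... while the attribute holds the converted instance.
        self.attr_type = type(self.hparams)


# Loading from a checkpoint hands __init__ a plain dict.
model = MyModel({"lr": 3e-4})
print(model.local_type.__name__)  # dict
print(model.attr_type.__name__)   # HParams
```

Rebinding `hparams` inside the base class's `__init__` cannot change what the subclass's local variable refers to, so only `self.hparams` reflects the conversion.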