cnellington / Contextualized

An SKLearn-style toolbox for estimating and analyzing models, distributions, and functions with context-specific parameters.
http://contextualized.ml/
GNU General Public License v3.0

Checkpointing with LightningModule subclasses does not save subclass hyperparameters #119

Status: Open · opened by cnellington 2 years ago

cnellington commented 2 years ago

To reproduce:

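  from pytorch_lightning.callbacks import ModelCheckpoint
  from contextualized.regression.lightning_modules import ContextualizedRegression
  from contextualized.regression.trainers import RegressionTrainer

  # C, X, Y: context, predictor, and target arrays of shape (n_samples, n_features)
  # subtype_kwargs: keyword arguments specific to this LightningModule subclass
  contextualized_model = ContextualizedRegression(
      C.shape[-1],
      X.shape[-1],
      Y.shape[-1],
      **subtype_kwargs
  )
  train_dataset = contextualized_model.dataloader(C, X, Y, batch_size=10)
  # monitor="val_loss" needs a validation loader; the training data is reused here
  val_dataset = contextualized_model.dataloader(C, X, Y, batch_size=10)
  checkpoint_callback = ModelCheckpoint(
      monitor="val_loss",
  )
  trainer = RegressionTrainer(
      max_epochs=10,
      callbacks=[checkpoint_callback]
  )
  trainer.fit(contextualized_model, train_dataset, val_dataset)
  # Does not restore the subclass hyperparameters: they were never saved in the checkpoint
  contextualized_model = ContextualizedRegression.load_from_checkpoint(checkpoint_callback.best_model_path)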
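Per the issue title, the checkpoint does not carry the subclass's own hyperparameters, so rebuilding the model from it cannot supply them. The pure-Python sketch below illustrates that failure mode under stated assumptions: Base, Sub, FixedSub, rebuild, and subtype_param are all hypothetical names, rebuild only stands in for load_from_checkpoint, and this is not how Lightning's save_hyperparameters is actually implemented.

```python
import inspect


class Base:
    """Hypothetical base class that records its constructor arguments as hparams."""

    def __init__(self, context_dim, x_dim, y_dim):
        # Capture only the arguments in Base.__init__'s own signature, so any
        # argument a subclass consumes itself never reaches self.hparams.
        arg_info = inspect.getargvalues(inspect.currentframe())
        self.hparams = {
            name: arg_info.locals[name] for name in arg_info.args if name != "self"
        }

    @classmethod
    def rebuild(cls, hparams):
        # Hypothetical stand-in for restoring from a checkpoint:
        # call the constructor with only the saved hparams.
        return cls(**hparams)


class Sub(Base):
    """Hypothetical subclass with one extra required argument."""

    def __init__(self, context_dim, x_dim, y_dim, subtype_param):
        super().__init__(context_dim, x_dim, y_dim)
        self.subtype_param = subtype_param


class FixedSub(Base):
    """Same subclass, but it also records its own extra argument."""

    def __init__(self, context_dim, x_dim, y_dim, subtype_param):
        super().__init__(context_dim, x_dim, y_dim)
        self.subtype_param = subtype_param
        self.hparams["subtype_param"] = subtype_param


model = Sub(4, 3, 2, subtype_param=8)
print(model.hparams)  # {'context_dim': 4, 'x_dim': 3, 'y_dim': 2}
try:
    Sub.rebuild(model.hparams)
    rebuild_failed = False
except TypeError as err:
    # subtype_param was never saved, so the constructor call is incomplete
    print("rebuild failed:", err)
    rebuild_failed = True

fixed = FixedSub(4, 3, 2, subtype_param=8)
restored = FixedSub.rebuild(fixed.hparams)
print(restored.subtype_param)  # 8
```

FixedSub is the analogue of having the subclass record its own constructor arguments; in Lightning that is commonly done by calling self.save_hyperparameters() in the subclass's __init__, though whether that alone resolves this issue is an assumption, not something confirmed here.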

Workaround: re-instantiate the model and load the checkpoint's state_dict with torch:

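  import torch

  # Re-instantiate the model with the same args and kwargs used for training
  contextualized_model = ContextualizedRegression(C.shape[-1], X.shape[-1], Y.shape[-1], **subtype_kwargs)
  best_checkpoint = torch.load(checkpoint_callback.best_model_path)
  contextualized_model.load_state_dict(best_checkpoint['state_dict'])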