KevinMusgrave / pytorch-adapt

Domain adaptation made easy. Fully featured, modular, and customizable.
https://kevinmusgrave.github.io/pytorch-adapt/
MIT License

Saving and Restoring a Trained Model #92

Closed r0f1 closed 1 year ago

r0f1 commented 1 year ago

Hi, this is roughly the code that I am using for training my models:

models = Models({"G": G, "C": C, "D": D})
adapter = DANN(models=models)
validator = IMValidator()
dataloaders = dc(**filter_datasets(datasets, validator))
train_loader = dataloaders.pop("train")

L_adapter = Lightning(adapter, validator=validator)
trainer = pl.Trainer(gpus=1, 
                     max_epochs=1,
                     default_root_dir="saved_models",
                     enable_checkpointing=True)
trainer.fit(L_adapter, train_loader, list(dataloaders.values()))

which causes the latest checkpoint to be saved under saved_models/lightning_logs/version_0/checkpoints/epoch=1-step=2832.ckpt.

Question 1) Is it possible to restore all three models, G, C and D, from this checkpoint, and if yes, how? I know that Lightning provides the function load_from_checkpoint() but I can't get it to work.

Question 2) If it is not possible to restore these models from the Lightning checkpoint, should I instead just manually save the state_dicts of G, C and D and then manually restore these, or is there a more elegant way?
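For reference, the manual route in Question 2 can be sketched with plain PyTorch, independent of Lightning. This is a minimal sketch: the nn.Linear modules are hypothetical stand-ins for the real G, C and D networks, and the file name is made up.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the real G, C and D networks.
G, C, D = nn.Linear(4, 4), nn.Linear(4, 2), nn.Linear(4, 1)

# Save all three state_dicts into a single checkpoint file.
torch.save(
    {"G": G.state_dict(), "C": C.state_dict(), "D": D.state_dict()},
    "gcd_checkpoint.pt",
)

# Restore: re-create the architectures, then load the saved weights.
G2, C2, D2 = nn.Linear(4, 4), nn.Linear(4, 2), nn.Linear(4, 1)
ckpt = torch.load("gcd_checkpoint.pt")
G2.load_state_dict(ckpt["G"])
C2.load_state_dict(ckpt["C"])
D2.load_state_dict(ckpt["D"])
```

The key point either way is that state_dicts hold only weights, so the model architectures have to be re-created in code before loading.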

KevinMusgrave commented 1 year ago

I think you have to pass in the __init__ args again when you load:

# re-create the same __init__ args (adapter and validator),
# then pass them as keyword arguments to load_from_checkpoint
Lightning.load_from_checkpoint(path, adapter=adapter, validator=validator)
r0f1 commented 1 year ago

Thanks, I'll give it a try.