AntixK / PyTorch-VAE

A Collection of Variational Autoencoders (VAE) in PyTorch.
Apache License 2.0

Recommended way to load the model after training? #67

Open peacej opened 2 years ago

peacej commented 2 years ago

For example, I guess this is one way?

import torch
import yaml
from experiment import VAEXperiment

config = yaml.safe_load(open('configs/vae.yaml'))
ckpt = torch.load('logs/VanillaVAE/version_1/checkpoints/last.ckpt')
# `model` must already be an instance of the VAE class that was trained.
experiment = VAEXperiment(model, config['exp_params'])
experiment.load_state_dict(ckpt['state_dict'])

Then one can access the model via experiment.model

It took me a while to figure this out. Maybe add such instructions to the README?

tudorjnu commented 1 year ago

Hello! What model did you pass to the experiment?

tudorjnu commented 1 year ago

Never mind, it works with:

import torch
import yaml

from experiment import VAEXperiment
from models import vae_models

# Load the config that was used for training.
with open('./configs/bbvae.yaml') as f:
    config = yaml.safe_load(f)

# Rebuild the model from the config and wrap it in the experiment.
model = vae_models[config['model_params']['name']](**config['model_params'])
experiment = VAEXperiment(model, config['exp_params'])

# Restore the trained weights from the checkpoint.
ckpt = torch.load('./logs/BetaVAE/version_0/checkpoints/last.ckpt')
experiment.load_state_dict(ckpt['state_dict'])

where I used the BetaVAE model.
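
The snippets in this thread load the weights through the experiment wrapper rather than into the model directly. A likely reason, assuming the checkpoint holds the state dict of the whole `VAEXperiment` module and the network is stored as `experiment.model`, is that every key in `ckpt['state_dict']` carries a `model.` prefix. The sketch below shows how such keys could be stripped so they match the bare model. The helper name `strip_prefix` and the example key names are hypothetical, and plain strings stand in for the tensors.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="model."):
    """Keep only the keys that start with `prefix` and drop the prefix from them."""
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )


# Stand-in for ckpt['state_dict']; key names are hypothetical and
# the values would be tensors in a real checkpoint.
example_state = {
    "model.encoder.0.weight": "w0",
    "model.fc_mu.bias": "b_mu",
}

model_state = strip_prefix(example_state)
print(list(model_state))
```

With a real checkpoint, `model.load_state_dict(strip_prefix(ckpt['state_dict']))` might then load the weights without constructing the experiment, provided all of the model's parameters really do sit under the `model.` prefix.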