advimman / lama

🦙 LaMa Image Inpainting, Resolution-robust Large Mask Inpainting with Fourier Convolutions, WACV 2022
https://advimman.github.io/lama-project/
Apache License 2.0

A simple ckpt to pt model convertor #306

Open jiaxue-ai opened 7 months ago

jiaxue-ai commented 7 months ago

Hi,

I wrote a simple .ckpt to .pt (TorchScript) model converter, in case anyone needs it:


import os

import torch  # needed below for no_grad, zeros and jit
import yaml
from omegaconf import OmegaConf

from saicinpainting.training.trainers import load_checkpoint

lama_model_path = '/LaMa_models/lama-places/lama-fourier/'

train_config_path = os.path.join(lama_model_path, 'config.yaml')
with open(train_config_path, 'r') as f:
    train_config = OmegaConf.create(yaml.safe_load(f))

# Inference-only setup: skip the training-only parts and use a no-op visualizer.
train_config.training_model.predict_only = True
train_config.visualizer.kind = 'noop'

checkpoint_path = os.path.join(lama_model_path,
                                'models',
                                'best.ckpt')
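model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu')
# freeze() disables gradients and puts the module in eval mode.
model.freeze()
with torch.no_grad():
    # Dummy input used only for tracing: batch of 1, 4 channels, 512x512.
    typical_input = torch.zeros([1, 4, 512, 512])
    # print(model.generator(typical_input).shape)
    traced_cell = torch.jit.trace(model.generator, typical_input)
# Save the traced generator next to the original config and checkpoint.
torch.jit.save(traced_cell, os.path.join(lama_model_path, 'lama-model-best.pt'))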
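
For anyone who wants to check the export, here is a rough sketch of the round trip: trace, save, load with torch.jit.load, and run on a 4-channel input. It does not use the real LaMa weights. StandInGenerator and make_generator_input are hypothetical names I made up for illustration, and the input layout (masked RGB image concatenated with a 1-channel mask) and the final blending step are my assumptions about how the generator is normally fed, not something the script above confirms.

```python
import os
import tempfile

import torch
import torch.nn as nn


class StandInGenerator(nn.Module):
    """Hypothetical placeholder for model.generator: 4 channels in, 3 out."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(4, 3, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.sigmoid(self.conv(x))


def make_generator_input(image, mask):
    """Assumed layout: masked RGB image concatenated with the mask (dim=1)."""
    masked_image = image * (1 - mask)
    return torch.cat([masked_image, mask], dim=1)


with tempfile.TemporaryDirectory() as tmp_dir:
    pt_path = os.path.join(tmp_dir, 'stand-in.pt')

    # Export step, mirroring the converter: trace on a dummy input and save.
    generator = StandInGenerator().eval()
    with torch.no_grad():
        traced = torch.jit.trace(generator, torch.zeros([1, 4, 512, 512]))
    torch.jit.save(traced, pt_path)

    # Load step: the .pt file no longer needs the Python class definition.
    loaded = torch.jit.load(pt_path, map_location='cpu')
    loaded.eval()

    # Random image in [0, 1] and a mask that is 1 inside the hole to fill.
    image = torch.rand([1, 3, 512, 512])
    mask = torch.zeros([1, 1, 512, 512])
    mask[:, :, 200:300, 200:300] = 1.0

    with torch.no_grad():
        predicted = loaded(make_generator_input(image, mask))

    # Keep original pixels outside the mask, predicted pixels inside it.
    result = mask * predicted + (1 - mask) * image

output_shape = tuple(result.shape)
```

With the real export, the torch.jit.load call would point at lama-model-best.pt instead of the temporary stand-in file, and the image and mask would come from your own data.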