ZHKKKe / Harmonizer

High-Resolution Image/Video Harmonization [ECCV 2022]

training result #9

Open xyxiaoAk opened 10 months ago

xyxiaoAk commented 10 months ago

Hi, I trained the model with the training code you provided, but I can't find a trained-weights file similar to the harmonizer.pth you released. Where is this file stored after training? Or which part of the code needs to be changed to produce it?

xyxiaoAk commented 10 months ago

I also converted the final checkpoint_60.ckpt to checkpoint_60.pth, keeping only the "model" entry of checkpoint_60.ckpt, but loading it for testing raises an error: "RuntimeError: Error(s) in loading state_dict for Harmonizer: Missing key(s) in state_dict: "backbone._blocks.0._depthwise_conv.weight",..."

Thanks in advance for your help! My code is as follows:

import torch

# Path to the checkpoint file
checkpoint_path = 'checkpoint_60.ckpt'

checkpoint = torch.load(checkpoint_path, map_location='cpu')
# Extract the model parameters
model_state_dict = checkpoint['model']

torch.save(model_state_dict, "checkpoint_60.pth")
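The "Missing key(s)" error means the names stored in the checkpoint differ from the names the Harmonizer module expects. A minimal sketch for spotting that difference, assuming you compare the keys of the checkpoint's "model" entry with the keys of the model's own state_dict; the helper infer_extra_prefix is hypothetical and not part of Harmonizer or PyTorch:

```python
from collections import Counter
from typing import Iterable, Optional


def infer_extra_prefix(ckpt_keys: Iterable[str], model_keys: Iterable[str]) -> Optional[str]:
    """Guess the prefix that checkpoint keys carry but the model's own keys lack.

    Returns "" if the keys already match, or None if no checkpoint key
    matches any model key.
    """
    # Try longer model keys first so "a.b.weight" wins over a shorter "weight".
    candidates = sorted(set(model_keys), key=len, reverse=True)
    exact = set(candidates)
    counts: Counter = Counter()
    for ck in ckpt_keys:
        if ck in exact:
            counts[""] += 1
            continue
        for mk in candidates:
            # Require a "." boundary so "xhead.bias" is not matched by "head.bias".
            if ck.endswith("." + mk):
                counts[ck[: len(ck) - len(mk)]] += 1
                break
    if not counts:
        return None
    # The most frequent prefix is the one to strip.
    return counts.most_common(1)[0][0]
```

With PyTorch you would call it as infer_extra_prefix(checkpoint['model'].keys(), model.state_dict().keys()), where model is a Harmonizer instance; a non-empty result such as "module.model." is the prefix to remove before loading.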
wangyuze18 commented 2 months ago

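Just convert the keys to the proper format. The keys in the checkpoint's "model" entry start with a "module.model." prefix that Harmonizer does not expect, so strip it before loading, like this: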

import torch

# "model" is assumed to be a Harmonizer instance created beforehand
state_dict = torch.load('checkpoint_60.ckpt', map_location='cpu')['model']

prefix = 'module.model.'  # note the trailing dot, so the remaining key path is correct
new_state_dict = {}
for key, value in state_dict.items():
    # Remove the "module.model." prefix
    if key.startswith(prefix):
        new_state_dict[key[len(prefix):]] = value
    else:
        new_state_dict[key] = value

model.load_state_dict(new_state_dict)
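To get a standalone weights file like the released harmonizer.pth, the same key cleanup can be wrapped into a one-off conversion. This is a sketch under two assumptions taken from this thread: the training checkpoint stores the weights under "model", and the keys carry the "module.model." prefix. The function names to_release_state_dict and convert_checkpoint are hypothetical.

```python
from collections import OrderedDict
from typing import Any, Mapping

# Assumed prefix on the training-checkpoint keys, as described in the reply.
DEFAULT_PREFIX = "module.model."


def to_release_state_dict(
    ckpt: Mapping[str, Any], prefix: str = DEFAULT_PREFIX
) -> "OrderedDict[str, Any]":
    """Return the checkpoint's "model" weights with `prefix` removed from each key."""
    release: "OrderedDict[str, Any]" = OrderedDict()
    for key, value in ckpt["model"].items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        release[new_key] = value
    return release


def convert_checkpoint(ckpt_path: str, pth_path: str, prefix: str = DEFAULT_PREFIX) -> None:
    """Load a training .ckpt and save a bare state_dict as .pth (hypothetical helper)."""
    import torch  # imported here so to_release_state_dict works without PyTorch

    ckpt = torch.load(ckpt_path, map_location="cpu")
    torch.save(to_release_state_dict(ckpt, prefix), pth_path)
```

Calling convert_checkpoint('checkpoint_60.ckpt', 'checkpoint_60.pth') should then give a file that loads in place of harmonizer.pth, assuming no other keys differ. PyTorch also provides torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, prefix), which strips a prefix in place and could replace the loop.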