Closed yzk-lab closed 3 years ago
那么这个.pth.tar文件我怎么加载呢?麻烦告知,谢谢。
![Uploading image.png…]()
不需要解压,采用和.pth同样的格式加载就可以了,可以参考main.py。这里提供一个标准的加载函数
import os import torch from collections import OrderedDict
def load_state_dict(checkpoint_path): if checkpoint_path and os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path, map_location='cpu') state_dict_key = 'state_dict' if state_dict_key in checkpoint: new_state_dict = OrderedDict() for k, v in checkpoint[state_dict_key].items(): name = k[7:] if k.startswith('module') else k new_state_dict[name] = v state_dict = new_state_dict else: state_dict = checkpoint print("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) return state_dict else: print("No checkpoint found at '{}'".format(checkpoint_path)) raise FileNotFoundError()
def load_checkpoint(model, checkpoint_path, strict=True): state_dict = load_state_dict(checkpoint_path) model.load_state_dict(state_dict, strict=strict)
谢谢回复,请问你有.pth模型吗?我想在mmdetection检测任务中load这个checkpoint 我尝试改了之后好像还是没法读取这个文件。
这个不是预训练模型格式的问题。mmdetection的backbone需要简单改写一下模型就可以了,首先要register_module, 去掉avg_pool和之后的fc操作,将四个layers的output给以tuple的形式作为输出即可。 可以参考https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/backbones/resnet.py的格式进行修改。
这个预训练格式是什么?我解压下来看不懂,不应该是.pth文件吗?谢谢指教