wofmanaf / SA-Net

Code for our ICASSP 2021 paper: SA-Net: Shuffle Attention for Deep Convolutional Neural Networks

Pretrained model format #7

yzk-lab closed this issue 3 years ago

yzk-lab commented 3 years ago

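What format is this pretrained model in? I extracted it but couldn't make sense of the contents. Shouldn't it be a .pth file? Thanks for any pointers.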

yzk-lab commented 3 years ago

So how do I load this .pth.tar file? Please let me know, thanks.

wofmanaf commented 3 years ago

There's no need to extract it; load it exactly the way you would load a .pth file (see main.py for reference). Here is a standard loading function:


import os
import torch
from collections import OrderedDict


def load_state_dict(checkpoint_path):
    if checkpoint_path and os.path.isfile(checkpoint_path):
        # A .pth.tar checkpoint is read directly with torch.load; do not extract it.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        state_dict_key = 'state_dict'
        if isinstance(checkpoint, dict) and state_dict_key in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint[state_dict_key].items():
                # Strip the 'module.' prefix added by nn.DataParallel / DistributedDataParallel.
                name = k[len('module.'):] if k.startswith('module.') else k
                new_state_dict[name] = v
            state_dict = new_state_dict
        else:
            state_dict = checkpoint
        print("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
        return state_dict
    else:
        print("No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError(checkpoint_path)


def load_checkpoint(model, checkpoint_path, strict=True):
    state_dict = load_state_dict(checkpoint_path)
    model.load_state_dict(state_dict, strict=strict)
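To illustrate why the file does not need to be extracted, here is a minimal round trip in plain PyTorch. The `.pth.tar` suffix is only a naming convention for checkpoints; the file is an ordinary `torch.save()` output. Since PyTorch 1.6 that output is a zip archive by default, which would plausibly explain why unpacking it yields a pickle file and raw tensor blobs instead of anything readable. The tiny model, the `"epoch"` entry, and the helper names (`make_model`, `strip_module_prefix`, `save_and_reload`) are hypothetical; only the `"state_dict"` key and the `module.` prefix handling mirror the function above.

```python
import os
import tempfile
import zipfile
from collections import OrderedDict

import torch
import torch.nn as nn


def make_model():
    # Hypothetical tiny model standing in for an SA-Net network.
    return nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4))


def strip_module_prefix(state_dict, prefix="module."):
    """Drop the prefix that nn.DataParallel / DistributedDataParallel add to keys."""
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


def save_and_reload(path):
    model = make_model()
    # Mimic a checkpoint written from a DataParallel-wrapped model.
    wrapped_sd = OrderedDict(("module." + k, v) for k, v in model.state_dict().items())
    torch.save({"epoch": 1, "state_dict": wrapped_sd}, path)

    # The .pth.tar file is read directly with torch.load; nothing is extracted.
    checkpoint = torch.load(path, map_location="cpu")
    restored = make_model()
    restored.load_state_dict(strip_module_prefix(checkpoint["state_dict"]), strict=True)
    return model, restored, checkpoint


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "checkpoint.pth.tar")
        model, restored, checkpoint = save_and_reload(path)
        # Expected to be True with the default (PyTorch >= 1.6) serialization format.
        print("zip archive:", zipfile.is_zipfile(path))
        print("first keys:", list(checkpoint["state_dict"])[:2])
```

Because the weights come back through `torch.load`, loading with `strict=True` succeeds as soon as the `module.` prefix is removed; the archive's internal layout never needs to be inspected by hand.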

yzk-lab commented 3 years ago

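Thanks for the reply. Do you have the model as a .pth file? I want to load this checkpoint for a detection task in mmdetection, but even after trying the changes it still doesn't seem able to read the file.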

wofmanaf commented 3 years ago

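This isn't a problem with the pretrained model format. To use the model as an mmdetection backbone you only need a few small changes: register it with register_module, remove avg_pool and the fc layer after it, and return the outputs of the four layers as a tuple. You can use https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/backbones/resnet.py as a template for the changes.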
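A minimal sketch of those changes in plain PyTorch, assuming a ResNet-style layout. `ToyClassifier` and `ToyBackbone` are hypothetical stand-ins that use the usual ResNet attribute names (conv1, bn1, relu, maxpool, layer1 to layer4, avgpool, fc); they are not the SA-Net code. In a real mmdetection setup the backbone class would also be decorated with the registry's `register_module()` (for example `@BACKBONES.register_module()` in mmdetection 2.x); the decorator is left out here so the sketch runs without mmdetection installed.

```python
import torch
import torch.nn as nn

STAGES = ("layer1", "layer2", "layer3", "layer4")


class ToyClassifier(nn.Module):
    """Hypothetical ResNet-style classifier: stem -> layer1..layer4 -> avgpool -> fc.
    Not the actual SA-Net code."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(8)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        chans = [8, 16, 32, 64, 128]
        for i, name in enumerate(STAGES):
            stride = 1 if i == 0 else 2
            setattr(self, name, nn.Sequential(
                nn.Conv2d(chans[i], chans[i + 1], 3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(chans[i + 1]),
                nn.ReLU(inplace=True),
            ))
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(chans[-1], num_classes)

    def stem(self, x):
        return self.maxpool(self.relu(self.bn1(self.conv1(x))))

    def forward(self, x):
        x = self.stem(x)
        for name in STAGES:
            x = getattr(self, name)(x)
        return self.fc(torch.flatten(self.avgpool(x), 1))


# In mmdetection this class would also carry the registry decorator,
# e.g. @BACKBONES.register_module() in mmdetection 2.x (omitted here).
class ToyBackbone(ToyClassifier):
    """Detection backbone: no avgpool/fc, returns the four stage outputs as a tuple."""

    def __init__(self):
        super().__init__()
        # Remove the classification head; the neck (e.g. FPN) consumes the stage outputs.
        del self.avgpool
        del self.fc

    def forward(self, x):
        x = self.stem(x)
        outs = []
        for name in STAGES:
            x = getattr(self, name)(x)
            outs.append(x)
        return tuple(outs)


if __name__ == "__main__":
    classifier = ToyClassifier()
    backbone = ToyBackbone()
    # strict=False: the classifier's fc.* weights have no counterpart in the backbone.
    result = backbone.load_state_dict(classifier.state_dict(), strict=False)
    print("missing:", result.missing_keys, "unexpected:", result.unexpected_keys)
    print([tuple(f.shape) for f in backbone(torch.randn(2, 3, 64, 64))])
```

Because the backbone keeps the classifier's parameter names, the ImageNet checkpoint can be loaded with `strict=False`: only `fc.weight` and `fc.bias` are reported as unexpected, and every backbone parameter is initialized from the checkpoint.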