open-mmlab / mmrazor

OpenMMLab Model Compression Toolbox and Benchmark.
https://mmrazor.readthedocs.io/en/latest/
Apache License 2.0

How to get the student model checkpoint after finishing knowledge distillation? #76

Open ChenDirk opened 2 years ago

ChenDirk commented 2 years ago

Hi, I found that the checkpoint file I get after knowledge distillation is very large, so I think it includes both the student model and the teacher model. But for deployment we only need the student model, and I don't know how to split it out of the checkpoint file with mmrazor.
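One way to confirm that the teacher is inside the checkpoint is to add up element counts per key prefix. The sketch below assumes the checkpoint's `state_dict` maps names to tensors (anything with a `.numel()` method); the helper name `summarize_by_prefix` is hypothetical. The scripts later in this thread assume the student is stored under `architecture.model.`; the teacher prefix (e.g. `distiller.teacher.` in mmrazor 0.x) is an assumption to check against your own output.

```python
from collections import defaultdict


def summarize_by_prefix(state_dict, depth=2):
    """Sum element counts per key prefix made of the first `depth` dotted parts.

    `state_dict` maps parameter/buffer names to tensor-like objects with
    `.numel()`, e.g. the 'state_dict' entry of a checkpoint from torch.load().
    """
    totals = defaultdict(int)
    for name, value in state_dict.items():
        prefix = '.'.join(name.split('.')[:depth])
        totals[prefix] += value.numel()
    # Largest groups first, so a large teacher stands out immediately.
    return sorted(totals.items(), key=lambda item: item[1], reverse=True)


# Hypothetical usage (the checkpoint path is a placeholder):
# import torch
# ckpt = torch.load('distill_checkpoint.pth', map_location='cpu')
# for prefix, count in summarize_by_prefix(ckpt['state_dict']):
#     print(f'{prefix}: {count / 1e6:.2f}M elements')
```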

pppppM commented 2 years ago

It would help to write a conversion script for the codebase you use; you can refer to https://github.com/open-mmlab/mmclassification/tree/master/tools/convert_models
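The core of such a conversion script is a key filter. A minimal sketch, assuming (as the scripts later in this thread do) that the student weights sit under the `architecture.model.` prefix; the function name is hypothetical. It uses `str.startswith` rather than `str.replace`, so the prefix is only stripped from the start of a key, and every key outside the student (such as the teacher's) is dropped.

```python
def extract_student_state_dict(state_dict, prefix='architecture.model.'):
    """Return only the student weights, with `prefix` stripped from each key.

    Keys that do not start with `prefix` (e.g. teacher weights) are dropped.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
```

The returned dict could then be saved with `torch.save({'state_dict': student}, path)` and loaded into a model built from the plain (non-distillation) config.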

maxrumi commented 2 years ago

Try this:

```python
import torch
import mmcv
from mmcls.models import build_classifier


def split_student_model(cls_cfg_path, cls_model_path, device='cuda', save_path=None):
    """
    :param cls_cfg_path: your normal classifier config file path (not the distillation config)
    :param cls_model_path: your distillation checkpoint path
    :param save_path: student model save path
    """
    cfg = mmcv.Config.fromfile(cls_cfg_path)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.data.test.test_mode = True
    # Build only the student architecture from the plain (non-distillation) config.
    model = build_classifier(cfg.model)
    # Load on CPU so the conversion does not need a GPU.
    model_ckpt = torch.load(cls_model_path, map_location='cpu')
    pretrained_dict = model_ckpt['state_dict']
    model_dict = model.state_dict()
    # Student weights are stored under the 'architecture.model.' prefix. Strip it
    # and keep only keys the student has, which drops the teacher weights.
    new_dict = {k.replace('architecture.model.', ''): v for k, v in pretrained_dict.items()
                if k.replace('architecture.model.', '') in model_dict}
    model_dict.update(new_dict)
    model.load_state_dict(model_dict)
    torch.save({'state_dict': model.state_dict(), 'meta': model_ckpt['meta'],
                'optimizer': model_ckpt['optimizer']}, save_path)
```

After generating the new checkpoint, use a normal config file (e.g. a plain ResNet-50 config), not the distillation config, and build the model with `build_classifier` from `mmcls.models` before loading the new checkpoint.

pppppM commented 2 years ago

Would you like to make a PR? @maxrumi

ChenDirk commented 2 years ago

> try this
>
> ```python
> import torch
> import mmcv
> from mmrazor.models.builder import build_algorithm
>
>
> def split_student_model(cls_cfg_path, cls_model_path, device='cuda', save_path=None):
>     """
>     :param cls_cfg_path: your distillation config path
>     :param cls_model_path: your distillation checkpoint path
>     :param save_path: student model save path
>     """
>     cfg = mmcv.Config.fromfile(cls_cfg_path)
>     # set cudnn_benchmark
>     if cfg.get('cudnn_benchmark', False):
>         torch.backends.cudnn.benchmark = True
>     cfg.data.test.test_mode = True
>     # NOTE: build_algorithm() builds the whole distillation algorithm
>     # (student and teacher), so the state_dict saved below still
>     # contains the teacher weights.
>     model = build_algorithm(cfg.algorithm)
>     model_ckpt = torch.load(cls_model_path)
>     print(model_ckpt.keys())
>     pretrained_dict = model_ckpt['state_dict']
>     model_dict = model.state_dict()
>     new_dict = {k.replace('architecture.model.', ''): v for k, v in pretrained_dict.items()
>                 if k.replace('architecture.model.', '') in model_dict}
>     model_dict.update(new_dict)
>     model.load_state_dict(model_dict)
>     torch.save({'state_dict': model.state_dict(), 'meta': model_ckpt['meta'],
>                 'optimizer': model_ckpt['optimizer']}, save_path)
> ```
>
> After generating the new checkpoint, use a normal config file (e.g. a plain ResNet-50 config), not the distillation config, and build the model with `build_classifier` from `mmcls.models` before loading the new checkpoint.

Hi, I tried your code, but the new checkpoint file is almost as large as the original one, so I think it still includes both the teacher model and the student model.

maxrumi commented 2 years ago

Sorry, I didn't check my code last time. Please try this again; this code will solve your problem:

```python
import torch
import mmcv
from mmcls.models import build_classifier


def split_student_model(cls_cfg_path, cls_model_path, device='cuda', save_path=None):
    """
    :param cls_cfg_path: your normal classifier config file path (not the distillation config)
    :param cls_model_path: your distillation checkpoint path
    :param save_path: student model save path
    """
    cfg = mmcv.Config.fromfile(cls_cfg_path)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.data.test.test_mode = True
    # Build only the student architecture from the plain (non-distillation) config.
    model = build_classifier(cfg.model)
    # Load on CPU so the conversion does not need a GPU.
    model_ckpt = torch.load(cls_model_path, map_location='cpu')
    pretrained_dict = model_ckpt['state_dict']
    model_dict = model.state_dict()
    # Student weights are stored under the 'architecture.model.' prefix. Strip it
    # and keep only keys the student has, which drops the teacher weights.
    new_dict = {k.replace('architecture.model.', ''): v for k, v in pretrained_dict.items()
                if k.replace('architecture.model.', '') in model_dict}
    model_dict.update(new_dict)
    model.load_state_dict(model_dict)
    torch.save({'state_dict': model.state_dict(), 'meta': model_ckpt['meta'],
                'optimizer': model_ckpt['optimizer']}, save_path)
```

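One caveat with the script above: `model_dict.update(new_dict)` silently skips keys that were not found, so a wrong prefix would save a randomly initialised student without any error. A stricter variant of that merge step is sketched below; the function name is hypothetical, and the `architecture.model.` prefix is the one used in this thread.

```python
def merge_student_weights(model_state, ckpt_state, prefix='architecture.model.'):
    """Pick the student weights for `model_state` out of a distillation checkpoint.

    Unlike a plain dict.update(), this raises if any key the target model
    expects is missing, so a wrong prefix cannot silently leave layers at
    their random initialisation.
    """
    student = {
        key[len(prefix):]: value
        for key, value in ckpt_state.items()
        if key.startswith(prefix)
    }
    missing = sorted(set(model_state) - set(student))
    if missing:
        raise KeyError(f'{len(missing)} model keys not found under '
                       f'{prefix!r}, e.g. {missing[:3]}')
    # Keep only the keys the target model knows about (drops teacher keys).
    return {key: student[key] for key in model_state}
```

In the script above it could replace the `new_dict`/`update` lines, e.g. `model.load_state_dict(merge_student_weights(model.state_dict(), pretrained_dict))`. If the checkpoint is only needed for inference, the `optimizer` entry can probably be left out of `torch.save` to shrink the file further.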
maxrumi commented 2 years ago

> Would you like to make a PR? @maxrumi

ok, I will make a PR.

pppppM commented 2 years ago

Great! @maxrumi

ChenDirk commented 2 years ago

Hi, thank you for your code. I tried your method on classification and segmentation models, and it split out the student model successfully in both cases. BTW, I suggest adding a parameter to distinguish mmcls, mmdet and mmseg models and removing the redundant `device` parameter; it would make a very cool tool in mmrazor. Thank you again.
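The suggested codebase parameter could be a small dispatch table. A sketch, assuming the mmcv 1.x-era builders `build_classifier` (mmcls 0.x), `build_detector` (mmdet 2.x) and `build_segmentor` (mmseg 0.x); the `BUILDERS` table and `get_model_builder` are hypothetical names, and the import is done lazily so only the selected codebase has to be installed.

```python
import importlib

# Assumed module paths for the mmcv 1.x-era codebases discussed in this thread.
BUILDERS = {
    'mmcls': ('mmcls.models', 'build_classifier'),
    'mmdet': ('mmdet.models', 'build_detector'),
    'mmseg': ('mmseg.models', 'build_segmentor'),
}


def get_model_builder(codebase):
    """Return the model-building function for `codebase` ('mmcls', 'mmdet' or 'mmseg')."""
    try:
        module_name, func_name = BUILDERS[codebase]
    except KeyError:
        raise ValueError(f'unsupported codebase {codebase!r}; '
                         f'expected one of {sorted(BUILDERS)}') from None
    # Import only now, so the other codebases need not be installed.
    return getattr(importlib.import_module(module_name), func_name)
```

`model = get_model_builder('mmseg')(cfg.model)` would then replace the hard-coded `build_classifier(cfg.model)` call.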

maxrumi commented 2 years ago

You're welcome. I'll take your advice.

tanghy2016 commented 2 years ago

@maxrumi How do I handle object detection? I tried changing

```python
model = build_classifier(cfg.model)
```

to

```python
model = build_detector(cfg.model)
```

but when loading the checkpoint, the weights don't match the model.
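When the weights don't match, it helps to see exactly which keys differ. The hypothetical helper below compares the student keys in the distillation checkpoint (again assuming the `architecture.model.` prefix used in this thread) with the keys of the model built from the plain config. If almost everything is missing, the prefix is probably wrong; if only a few neck or head keys differ, the plain config likely does not describe the same student architecture as the distillation config.

```python
def diff_student_keys(ckpt_keys, model_keys, prefix='architecture.model.'):
    """Compare student keys in a distillation checkpoint with a model's keys.

    Returns (missing, unexpected), both sorted:
      missing    - keys the model expects but the checkpoint's student lacks
      unexpected - student keys in the checkpoint that the model does not have
    """
    student = {key[len(prefix):] for key in ckpt_keys if key.startswith(prefix)}
    model = set(model_keys)
    return sorted(model - student), sorted(student - model)
```

Hypothetical usage: `missing, unexpected = diff_student_keys(ckpt['state_dict'], model.state_dict())` (iterating a dict yields its keys).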

HePengguang commented 1 year ago

> Would you like to make a PR? @maxrumi
>
> ok, I will make a PR.

Hi, how can we extract the student checkpoint from an mmseg model? We need your help. Thank you so much!