irfanICMLL / structure_knowledge_distillation

The official code for the paper 'Structured Knowledge Distillation for Semantic Segmentation' (CVPR 2019 oral), with extensions to other tasks.
BSD 2-Clause "Simplified" License
708 stars 103 forks source link

about the code #54

Open 15757170756 opened 4 years ago

15757170756 commented 4 years ago
import argparse
import logging
import os
import pdb
from torch.autograd import Variable
import os.path as osp
import torch
from torch.optim.lr_scheduler import StepLR, MultiStepLR
import numpy as np
import resource
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from utils.utils import *
import torch.backends.cudnn as cudnn

from utils.criterion import CriterionDSN, CriterionOhemDSN, CriterionPixelWise, \
    CriterionAdv, CriterionAdvForG, CriterionAdditionalGP, CriterionPairWiseforWholeFeatAfterPool
import utils.parallel as parallel_old
from networks.pspnet_combine import Res_pspnet, BasicBlock, Bottleneck
from networks.sagan_models import Discriminator
from networks.evaluate import evaluate_main

# Installed PyTorch version as a "major.minor" string (e.g. "1.13" for "1.13.1+cu117").
# Slicing the first three characters truncates two-digit minor versions
# ("1.10.0" -> "1.1"), so split on the dots instead.
torch_ver = '.'.join(str(torch.__version__).split('.')[:2])

class NetModel:
    """Model wrapper identified as 'kd_seg'.

    Only the data-parallel helper is visible in this excerpt.
    """

    def name(self):
        """Return this model's identifier, 'kd_seg'."""
        return 'kd_seg'

    def DataParallelModelProcess(self, model, ParallelModelType = 1, is_eval = 'train', device = 'cuda'):
        """Wrap ``model`` in ``DataParallelModel`` and prepare it for use.

        Args:
            model: the network to wrap.
            ParallelModelType: 1 or 2. Both select ``DataParallelModel`` from
                ``utils/parallel.py`` (imported here as ``parallel_old``).
                Both values are kept for backward compatibility.
            is_eval: 'eval' to put the wrapper in eval mode, 'train' for train mode.
            device: target passed to ``.to()``.

        Returns:
            The wrapped model, after ``.eval()``/``.train()``, ``.float()`` and ``.to(device)``.

        Raises:
            ValueError: if ``ParallelModelType`` is not 1 or 2, or ``is_eval`` is
                not 'eval' or 'train'. Both are checked before the wrapper is built.
        """
        # Validate everything up front so a bad call fails before any wrapping work.
        if ParallelModelType not in (1, 2):
            raise ValueError('ParallelModelType should be 1 or 2')
        if is_eval not in ('eval', 'train'):
            raise ValueError('is_eval should be eval or train')

        # The bare name ``DataParallelModel`` was only reachable through the
        # ``from utils.utils import *`` star import. Its definition lives in
        # utils/parallel.py, so resolve it explicitly via that module for both types.
        parallel_model = parallel_old.DataParallelModel(model)

        if is_eval == 'eval':
            parallel_model.eval()
        else:
            parallel_model.train()
        parallel_model.float()
        parallel_model.to(device)
        return parallel_model
if ParallelModelType == 1:
    parallel_model = DataParallelModel(model)
elif ParallelModelType == 2:
    parallel_model = parallel_old.DataParallelModel(model)

I can't find "DataParallelModel" anywhere. Where is it defined?

irfanICMLL commented 3 years ago

Here it is. https://github.com/irfanICMLL/structure_knowledge_distillation/blob/ce208e1e5ba9177ecfc42519a2c64148d396fb71/utils/parallel.py

15757170756 commented 3 years ago

But this is the same as "parallel_old.DataParallelModel", which comes from "import utils.parallel as parallel_old".