yikaiw / RS-Nets

[ECCV 2020] Code release for "Resolution Switchable Networks for Runtime Efficient Image Recognition"
https://arxiv.org/pdf/2007.09558.pdf
MIT License

How many and which GPUs are needed for training #1

Open Liuyang829 opened 3 years ago

Liuyang829 commented 3 years ago

Thanks for your great work! When I tried to reproduce your results, I ran into CUDA out-of-memory errors. Which GPUs, and how many of them, did you use to train ResNet50 and ResNet18?

yikaiw commented 3 years ago

Thanks for your recognition. We use 8 V100 GPUs for training ResNet50 and ResNet18. Actually, 4 GTX1080 GPUs are enough for ResNet18.

Liuyang829 commented 3 years ago

I tried training ResNet18 on 2 RTX3090 GPUs. One epoch takes nearly 1.5 hours, which comes to over 7 days for 120 epochs, far longer than our group can afford. How long did training take for ResNet18 and ResNet50 on your setup? Thank you!
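
The 7-day figure is simple arithmetic: 1.5 hours per epoch times 120 epochs is 180 hours, or 7.5 days. A minimal sketch of that estimate, assuming a constant time per epoch (the function name is illustrative only):

```python
def training_days(hours_per_epoch: float, epochs: int) -> float:
    """Wall-clock training time in days, assuming every epoch takes the same time."""
    return hours_per_epoch * epochs / 24.0


if __name__ == "__main__":
    # ResNet18 on 2 RTX3090 GPUs, as reported above
    print(training_days(1.5, 120))  # 7.5
```
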

yikaiw commented 3 years ago

On 8 V100 GPUs, ResNet50 takes 4 days and ResNet18 only 1 day. On 4 GTX1080 GPUs, ResNet18 takes about 2 days. Note that the ImageNet data should be stored on a solid-state drive (SSD), which speeds up training considerably (roughly twice as fast).
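
One rough way to check whether disk I/O, rather than the GPUs, is limiting training speed is to time the data loader on its own, with no forward or backward pass. The helper below is a hypothetical sketch that works on any iterable of batches (for example a PyTorch `DataLoader`); if the batches per second it reports are close to the overall training speed, the input pipeline is the likely bottleneck and faster storage such as an SSD should help.

```python
import time
from typing import Iterable


def loader_throughput(batches: Iterable, max_batches: int = 100) -> float:
    """Return batches per second yielded by `batches`, reading at most `max_batches`.

    Only iteration is timed, so this measures data loading in isolation.
    """
    n = 0
    start = time.perf_counter()
    for _ in batches:
        n += 1
        if n >= max_batches:
            break
    elapsed = time.perf_counter() - start
    return n / elapsed if elapsed > 0 else float("inf")


if __name__ == "__main__":
    # Stand-in for a real loader: a generator that sleeps 10 ms per "batch".
    def fake_loader():
        while True:
            time.sleep(0.01)
            yield None

    print(f"{loader_throughput(fake_loader(), max_batches=50):.1f} batches/s")
```

With a real multi-worker loader, the first batches include worker start-up time, so a reasonably large `max_batches` gives a more representative number.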

Liuyang829 commented 3 years ago

Thank you very much!

Liuyang829 commented 3 years ago

I still have some questions about the data loader. Why don't you apply transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), as is usual in other ImageNet classification tasks? While modifying the RS-Net code, I found that the cross-entropy loss turned into NaN around the 8th or 10th epoch. I wonder whether the missing data normalization causes this.

Could you also release the code for the "Tested at New Resolutions" experiment in your ablation study?

Thank you very much!
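
For reference, `transforms.Normalize(mean, std)` standardizes each channel of a tensor image as `(x - mean[c]) / std[c]`; in the usual pipeline it runs after `ToTensor()` has scaled pixels to [0, 1]. A NumPy sketch of the same per-channel operation on a channels-first (C, H, W) array, using the ImageNet statistics quoted above (the function name is illustrative):

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_chw(img: np.ndarray,
                  mean: np.ndarray = IMAGENET_MEAN,
                  std: np.ndarray = IMAGENET_STD) -> np.ndarray:
    """Per-channel standardization of a (C, H, W) float image in [0, 1].

    Computes out[c] = (img[c] - mean[c]) / std[c], the same formula
    torchvision's transforms.Normalize applies.
    """
    # Reshape the statistics to (C, 1, 1) so they broadcast over H and W.
    return (img - mean[:, None, None]) / std[:, None, None]


if __name__ == "__main__":
    x = np.random.rand(3, 4, 4).astype(np.float32)
    print(normalize_chw(x).shape)  # (3, 4, 4)
```
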

yikaiw commented 3 years ago

Hi, thanks for your interest.

Since we found that performance without normalization already reaches state-of-the-art (SOTA) results, we do not apply normalization. The released code does not produce a NaN cross-entropy loss. If you modify the code and the loss becomes NaN during training, try reducing the initial learning rate.
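
The script's flags suggest a step decay: the learning rate starts at `--lr` (default 0.1) and is multiplied by `--gamma` (default 0.1) at each epoch listed in `--schedule` (default 30, 60, 85, 95, 105). How the training script actually combines `--lr-decay`, `--step` and `--schedule` is not shown in this thread, so the sketch below is only one plausible reading of those flags. It pairs a multi-step schedule with a check that stops at the first non-finite loss, so a lower initial learning rate can be tried. Both function names are hypothetical.

```python
import math
from typing import Sequence


def multistep_lr(base_lr: float, epoch: int,
                 milestones: Sequence[int] = (30, 60, 85, 95, 105),
                 gamma: float = 0.1) -> float:
    """Learning rate after multiplying base_lr by gamma at every milestone <= epoch."""
    n_decays = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** n_decays


def check_finite(loss_value: float, epoch: int, step: int) -> None:
    """Raise as soon as the loss becomes NaN or inf instead of training on."""
    if not math.isfinite(loss_value):
        raise FloatingPointError(
            f"non-finite loss {loss_value} at epoch {epoch}, step {step}; "
            "consider lowering the initial learning rate")


if __name__ == "__main__":
    for base_lr in (0.1, 0.05):  # the default --lr and a halved alternative
        print(base_lr, [multistep_lr(base_lr, e) for e in (0, 29, 30, 60)])
```
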

The code for testing at new resolutions is provided below. It essentially calibrates the BN (batch normalization) layers for the new resolution:

from __future__ import print_function

import os, sys, argparse
import warnings, random, shutil, time
from tqdm import tqdm
import pickle
import numpy as np  # used by np.save, np.mean and np.array below

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

import models.imagenet as customized_models
from utils import AverageMeter, mkdir_p
from utils.dataloaders import *
from tensorboardX import SummaryWriter

warnings.filterwarnings('ignore')

default_model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))

customized_models_names = sorted(name for name in customized_models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(customized_models.__dict__[name]))

for name in customized_models.__dict__:
    if name.islower() and not name.startswith("__") and callable(customized_models.__dict__[name]):
        models.__dict__[name] = customized_models.__dict__[name]

model_names = default_model_names + customized_models_names

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('-d', '--data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')

parser.add_argument('--lr-decay', type=str, default='step',
                    help='mode for learning rate decay')
parser.add_argument('--step', type=int, default=30,
                    help='interval for learning rate decay in step mode')
parser.add_argument('--schedule', type=int, nargs='+', default=[30, 60, 85, 95, 105],
                    help='decrease learning rate at these epochs.')
parser.add_argument('--gamma', type=float, default=0.1,
                    help='LR is multiplied by gamma on schedule.')
parser.add_argument('--warmup', action='store_true',
                    help='set lower initial learning rate to warm up the training')

parser.add_argument('--cardinality', type=int, default=32, help='ResNeXt model cardinality (group).')
parser.add_argument('--base-width', type=int, default=4, help='ResNeXt model base width (number of channels in each group).')
parser.add_argument('--groups', type=int, default=3, help='ShuffleNet model groups')
parser.add_argument('--extent', type=int, default=0, help='GENet model spatial extent ratio')
parser.add_argument('--theta', dest='theta', action='store_true', help='GENet model parameterising the gather function')
parser.add_argument('--excite', dest='excite', action='store_true', help='GENet model combining the excite operator')

parser.add_argument('--sizes', type=int, nargs='+', default=[224, 192, 160, 128, 96],
                    help='input resolutions.')
parser.add_argument('--delta_size', type=int, default=0)
parser.add_argument('--cal-bn', action='store_true')
parser.add_argument('--kd', action='store_true',
                    help='build losses of knowledge distillation across scales')
parser.add_argument('-t', '--kd-type', metavar='KD_TYPE', default='topdown',
                    choices=['topdown', 'direct'])
parser.add_argument('--save-dicts', action='store_true')

args = parser.parse_args()
n_sizes = len(args.sizes)
assert args.delta_size >= 0
for i in range(n_sizes):
    args.sizes[i] += args.delta_size

def main():
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. This will turn on the CUDNN deterministic setting, which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting from checkpoints.')

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        if args.arch.startswith('resnext'):
            model = models.__dict__[args.arch](
                    baseWidth=args.base_width,
                    cardinality=args.cardinality,
                )
        elif args.arch.startswith('shufflenetv1'):
            model = models.__dict__[args.arch](
                    groups=args.groups
                )
        elif args.arch.startswith('ge_resnet'):
            model = models.__dict__[args.arch](
                    extent=args.extent,
                    theta=args.theta,
                    excite=args.excite
                )
        elif args.arch.startswith('parallel') or args.arch.startswith('meta'):
            model = models.__dict__[args.arch](num_parallel=n_sizes)
        else:
            model = models.__dict__[args.arch]()

    if args.kd:
        alpha = nn.Parameter(torch.ones(n_sizes, requires_grad=True))
        model.register_parameter('alpha', alpha)

    if not args.distributed:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        state_dict = checkpoint['state_dict']
        if args.save_dicts:
            d = dict(state_dict)
            for key in d.keys():
                d[key] = d[key].cpu().numpy()
            np.save('dict.npy', d)
            print('dict saved')
            return
        if args.delta_size != 0 and args.cal_bn:
            state_dict = cal_bn(args.delta_size, state_dict)
        if not args.kd and 'module.alpha' in state_dict:
            del state_dict['module.alpha']
        model.load_state_dict(state_dict, strict=False)
        print('# parameters:', sum(param.numel() for param in model.parameters()))
        print("=> loaded checkpoint '{}' (epoch {})"
                .format(args.resume, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))
        return

    cudnn.benchmark = True

    get_train_loader, get_val_loader = get_pytorch_train_loader, get_pytorch_val_loader
    val_loader, val_loader_len = get_val_loader(args.data, args.batch_size, args.sizes, workers=args.workers)
    validate(val_loader, val_loader_len, model, criterion)
    summary()

    return

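def cal_bn(delta_size, state_dict):
    # Calibrate the resolution-specific BNs for test sizes shifted by delta_size.
    # 'bn_%d' % i is the BN branch of the i-th size in --sizes, and neighboring
    # sizes differ by 32 px, so sizes[i] + delta_size lies between sizes[i] and
    # sizes[i - 1]. Each BN parameter/statistic of branch i (i >= 1) is linearly
    # blended with branch i - 1 using alpha = delta_size / 32; branch 0 (the
    # largest size) is left unchanged.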
    statistics = ['weight', 'bias', 'running_mean', 'running_var']
    state_dict_copy = state_dict.copy()
    alpha = delta_size / 32
    for key in state_dict.keys():
        for s in statistics:
            for i in range(1, 5):
                if 'bn_%d' % i in key and s in key:
                    key_ = key.replace('bn_%d' % i, 'bn_%d' % (i - 1))
                    state_dict_copy[key] = state_dict[key_] * alpha + state_dict[key] * (1 - alpha)
    return state_dict_copy

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape, not view: `correct` is non-contiguous after pred.t()
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        res.append(correct)

        return res

def validate(val_loader, val_loader_len, model, criterion):
    top1 = [AverageMeter() for _ in range(n_sizes)]
    top5 = [AverageMeter() for _ in range(n_sizes)]
    top1_res = [[] for _ in range(n_sizes)]  # per-size list of top-1 hits
    times = []

    # switch to evaluate mode
    model.eval()

    for i, (input, target) in tqdm(enumerate(val_loader), total=val_loader_len):
        target = target.cuda(non_blocking=True)

        with torch.no_grad():
            # compute output
            t = time.time()
            output = model(input)
            times.append(time.time() - t)
            # print('current step time:', times[-1], flush=True)

            for j in range(n_sizes):
                # measure accuracy and record loss
                acc1, acc5, correct = accuracy(output[j], target, topk=(1, 5))
                top1[j].update(acc1.item(), input[0].size(0))
                top5[j].update(acc5.item(), input[0].size(0))
                correct1 = correct[:1].cpu().numpy().tolist()
                top1_res[j] = top1_res[j] + correct1[0]
    print('mean step time:', np.mean(times))

    for j, size in enumerate(args.sizes):
        top1_avg, top5_avg = top1[j].avg, top5[j].avg
        print('size%03d: top1 %.2f, top5 %.2f' % (size, top1_avg, top5_avg))

    with open('top1_val_resnet18_shared_kd.bin','wb') as fp:
        pickle.dump(top1_res,fp)

    return [round(t.avg, 1) for t in top1], [round(t.avg, 1) for t in top5]

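def summary():
    # For every ordered pair (i, j) of resolutions, print the fraction of
    # validation images misclassified (top-1) at size i but correct at size j.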
    with open('top1_val_resnet18_shared_kd.bin','rb') as fp:
        top1_res = pickle.load(fp)

    K = len(top1_res)
    N = len(top1_res[0])

    for i in range(K):
        for j in range(K):
            if i==j:
                continue
            cor_i = np.array(top1_res[i]).astype(np.float32)
            cor_j = np.array(top1_res[j]).astype(np.float32)
            cor_i = 1.0 - cor_i
            _ij = np.multiply(cor_i,cor_j)
            print("(%d,%d) = %f" % (i, j, _ij.sum() / N))

if __name__ == '__main__':
    main()
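
To make the two less obvious helpers concrete, here is a plain-Python sketch on toy numbers (no PyTorch needed). `interpolate_bn` repeats the arithmetic of `cal_bn`: with training sizes 32 px apart, a test size shifted up by `delta_size` uses `alpha = delta_size / 32` to blend the parameters of BN branch `i` with those of the next larger resolution, branch `i - 1`. `cross_error_rate` repeats what `summary()` prints for a pair `(i, j)`: the fraction of images misclassified at resolution `i` but classified correctly at resolution `j`. The function names and toy values are illustrative only; scalars stand in for the tensors of a real `state_dict`.

```python
from typing import Dict, List

# Same statistics cal_bn touches; e.g. num_batches_tracked is skipped.
STATS = ('weight', 'bias', 'running_mean', 'running_var')


def interpolate_bn(params: Dict[str, float], delta_size: int,
                   step: int = 32, n_branches: int = 5) -> Dict[str, float]:
    """Blend BN branch i with branch i - 1, as cal_bn does; branch 0 is unchanged."""
    alpha = delta_size / step
    out = dict(params)
    for key, value in params.items():
        for i in range(1, n_branches):
            tag = 'bn_%d' % i
            if tag in key and any(s in key for s in STATS):
                neighbor = params[key.replace(tag, 'bn_%d' % (i - 1))]
                out[key] = neighbor * alpha + value * (1 - alpha)
    return out


def cross_error_rate(correct: List[List[bool]], i: int, j: int) -> float:
    """Fraction of samples wrong at resolution i but right at resolution j."""
    n = len(correct[i])
    return sum((not a) and b for a, b in zip(correct[i], correct[j])) / n


if __name__ == "__main__":
    sizes = [224, 192, 160, 128, 96]
    delta = 16  # evaluate at 240, 208, 176, 144, 112 with alpha = 0.5
    bn = {'bn_%d.running_mean' % k: float(k) for k in range(5)}
    print([s + delta for s in sizes])
    print(interpolate_bn(bn, delta))

    correct = [[True, True, False, True], [True, False, False, False]]
    print(cross_error_rate(correct, 1, 0))  # 0.5
```
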

Liuyang829 commented 3 years ago

Thank you very much!