JDAI-CV / DCL

Destruction and Construction Learning for Fine-grained Image Recognition

New test.py #53

Open Lamborghini1709 opened 4 years ago

Lamborghini1709 commented 4 years ago

```python
# coding=utf-8
import os
import json
import csv
import argparse
import pandas as pd
import numpy as np
from math import ceil
from tqdm import tqdm
import pickle
import shutil

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import CrossEntropyLoss
from torchvision import datasets, models
import torch.backends.cudnn as cudnn
import torch.nn.functional as F

from transforms import transforms
from models.LoadModel import MainModel
from utils.dataset_DCL import collate_fn4train, collate_fn4test, collate_fn4val, dataset
from config import LoadConfig, load_data_transformers
from utils.test_tool import set_text, save_multi_img, cls_base_acc

import pdb

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'


def parse_args():
    parser = argparse.ArgumentParser(description='dcl parameters')
    parser.add_argument('--data', dest='dataset', default='CUB', type=str)
    parser.add_argument('--backbone', dest='backbone', default='resnet50', type=str)
    parser.add_argument('--b', dest='batch_size', default=16, type=int)
    parser.add_argument('--nw', dest='num_workers', default=16, type=int)
    parser.add_argument('--ver', dest='version', default='test', type=str)
    parser.add_argument('--save', dest='resume', default=None, type=str)
    parser.add_argument('--size', dest='resize_resolution', default=512, type=int)
    parser.add_argument('--crop', dest='crop_resolution', default=448, type=int)
    parser.add_argument('--ss', dest='save_suffix', default=None, type=str)
    parser.add_argument('--acc_report', dest='acc_report', action='store_true')
    parser.add_argument('--swap_num', default=[7, 7], nargs=2, metavar=('swap1', 'swap2'),
                        type=int, help='specify a range')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    print(args)

    # the submit branch from the original test.py is disabled here;
    # --submit is not defined in parse_args above
    # if args.submit:
    #     args.version = 'test'
    #     if args.save_suffix == '':
    #         raise Exception('**** miss --ss save suffix is needed. ')
    args.version = 'test'
    Config = LoadConfig(args, args.version)
    transformers = load_data_transformers(args.resize_resolution, args.crop_resolution, args.swap_num)
    data_set = dataset(Config,
                       anno=Config.val_anno if args.version == 'val' else Config.test_anno,
                       # unswap=transformers["None"],
                       swap=transformers["None"],
                       totensor=transformers['test_totensor'],
                       test=True)

    dataloader = torch.utils.data.DataLoader(data_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             collate_fn=collate_fn4test)

    setattr(dataloader, 'total_item_len', len(data_set))

    cudnn.benchmark = True
    print('****************')
    # note: forcing the 2*num_classes swap head here must match how the checkpoint was trained
    Config.cls_2xmul = True
    print(Config.cls_2xmul)
    model = MainModel(Config)
    model_dict = model.state_dict()
    pretrained_dict = torch.load(args.resume)
    # strip the 'module.' prefix added by nn.DataParallel when the checkpoint was saved
    pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    model.cuda()
    model = nn.DataParallel(model)

    model.train(False)
    with torch.no_grad():
        val_corrects1 = 0
        val_corrects2 = 0
        val_corrects3 = 0
        val_size = ceil(len(data_set) / dataloader.batch_size)
        result_gather = {}
        count_bar = tqdm(total=dataloader.__len__())
        for batch_cnt_val, data_val in enumerate(dataloader):
            count_bar.update(1)
            inputs, labels, img_name = data_val
            inputs = Variable(inputs.cuda())
            labels = Variable(torch.from_numpy(np.array(labels)).long().cuda())

            outputs = model(inputs)
            # print('outputs:', outputs)
            outputs_pred = outputs[0] + outputs[1][:, 0:Config.numcls] + outputs[1][:, Config.numcls:2*Config.numcls]
            print('outputs_pred:', outputs_pred)

            top3_val, top3_pos = torch.topk(outputs_pred, 3)

            # note: top-k correctness is only accumulated when running with --ver val
            if args.version == 'val':
                batch_corrects1 = torch.sum((top3_pos[:, 0] == labels)).data.item()
                val_corrects1 += batch_corrects1
                batch_corrects2 = torch.sum((top3_pos[:, 1] == labels)).data.item()
                val_corrects2 += (batch_corrects2 + batch_corrects1)
                batch_corrects3 = torch.sum((top3_pos[:, 2] == labels)).data.item()
                val_corrects3 += (batch_corrects3 + batch_corrects2 + batch_corrects1)

            if args.acc_report:
                for sub_name, sub_cat, sub_val, sub_label in zip(img_name, top3_pos.tolist(), top3_val.tolist(), labels.tolist()):
                    result_gather[sub_name] = {'top1_cat': sub_cat[0], 'top2_cat': sub_cat[1], 'top3_cat': sub_cat[2],
                                               'top1_val': sub_val[0], 'top2_val': sub_val[1], 'top3_val': sub_val[2],
                                               'label': sub_label}

    if args.acc_report:
        result_gather_json = open('result_gather_%s' % args.resume.split('/')[-1][:-4] + '.json', 'w')
        json.dump(result_gather, result_gather_json)
        result_gather_json.close()
        torch.save(result_gather, 'result_gather_%s' % args.resume.split('/')[-1][:-4] + '.pt')

    count_bar.close()
    print(args.acc_report)
    if args.acc_report:
        val_acc1 = val_corrects1 / len(data_set)
        val_acc2 = val_corrects2 / len(data_set)
        val_acc3 = val_corrects3 / len(data_set)
        print('%sacc1 %f%s\n%sacc2 %f%s\n%sacc3 %f%s\n' % (8*'-', val_acc1, 8*'-', 8*'-', val_acc2, 8*'-', 8*'-', val_acc3, 8*'-'))

        cls_top1, cls_top3, cls_count = cls_base_acc(result_gather)

        acc_report_io = open('acc_report_%s_%s.json' % (args.save_suffix, args.resume.split('/')[-1]), 'w')
        json.dump({'val_acc1': val_acc1,
                   'val_acc2': val_acc2,
                   'val_acc3': val_acc3,
                   'cls_top1': cls_top1,
                   'cls_top3': cls_top3,
                   'cls_count': cls_count}, acc_report_io)
        acc_report_io.close()
```

Run the test: `python test.py --save ./net_model/training_descibe_72721_CUB/weights_36_4999_0.8608_0.9998.pth --acc_report`
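
One thing worth flagging in the loading code above: `k[7:]` silently drops the first seven characters of every checkpoint key, which only works if every key carries the `module.` prefix that `nn.DataParallel` adds when saving. A minimal sketch of a more defensive variant (the helper name is mine, not part of the repo):

```python
import torch


def strip_dataparallel_prefix(state_dict):
    """Remove a leading 'module.' from checkpoint keys only when it is actually present."""
    return {
        (k[len('module.'):] if k.startswith('module.') else k): v
        for k, v in state_dict.items()
    }


# usage sketch, mirroring the script above:
# pretrained_dict = strip_dataparallel_prefix(torch.load(args.resume, map_location='cpu'))
# pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model.state_dict()}
```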

yunchangxiaoguan commented 4 years ago

Hello, I used the code, but I got an error:

```
python new_test.py --save net_model/_8514_CUB/weights_20_0_1.0000_1.0000.pth --acc_report
Namespace(acc_report=True, backbone='resnet50', batch_size=16, crop_resolution=448, dataset='CUB', num_workers=16, resize_resolution=512, resume='net_model/_8514_CUB/weights_20_0_1.0000_1.0000.pth', save_suffix=None, swap_num=[7, 7], version='test')
True
resnet50
Traceback (most recent call last):
  File "new_test.py", line 95, in <module>
    model.load_state_dict(model_dict)
  File "/home/guanxiao/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 847, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for MainModel:
    size mismatch for classifier_swap.weight: copying a param with shape torch.Size([2, 2048]) from checkpoint, the shape in current model is torch.Size([6, 2048]).
```

How can I fix this? Thanks.

Lamborghini1709 commented 4 years ago

Check your num_classes (the number of output classes).
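
To make that check concrete: before calling `load_state_dict`, you can compare the checkpoint's tensor shapes with the model built from your current config and see exactly which layer disagrees. A rough diagnostic sketch (the function name is mine; it assumes the same `MainModel`, `Config`, and `args` objects as in the script above):

```python
import torch


def report_shape_mismatches(model, checkpoint_path):
    """Print every parameter whose shape differs between the checkpoint and the current model."""
    model_dict = model.state_dict()
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    for k, v in ckpt.items():
        key = k[len('module.'):] if k.startswith('module.') else k
        if key in model_dict and model_dict[key].shape != v.shape:
            print(key, 'checkpoint:', tuple(v.shape), 'model:', tuple(model_dict[key].shape))


# usage sketch:
# report_shape_mismatches(MainModel(Config), args.resume)
```

If `classifier_swap.weight` is the only offender, the class count (or the hard-coded `Config.cls_2xmul = True` in the test script) does not match what the checkpoint was trained with.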

yunchangxiaoguan commented 4 years ago

> Check your num_classes (the number of output classes).

Thanks, I've solved it.

BaofengZan commented 4 years ago

@Lamborghini1709 @yunchangxiaoguan Hello, could you share your trained CUB model? This model really demands serious hardware, and training with a small batch size is too slow.

JiCheng12138 commented 3 years ago

@yunchangxiaoguan Hello, how did you solve the `size mismatch for classifier_swap.weight` issue? Thanks.

Lamborghini1709 commented 3 years ago

> @yunchangxiaoguan Hello, how did you solve the `size mismatch for classifier_swap.weight` issue? Thanks.

Check your output class count, num_classes.
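
For readers hitting the same error: as far as I can tell from the traceback above, the swap head is a linear layer whose output width depends on the configured class count, so a checkpoint trained with one setting cannot be loaded into a model built with another. A tiny illustrative sketch (the `num_classes = 3` value is hypothetical, chosen only because 2 * 3 = 6 matches the shapes in the error):

```python
import torch.nn as nn

num_classes = 3  # hypothetical; 2 * 3 = 6 matches the "[6, 2048]" in the error above

# binary swap head: weight shape [2, 2048], as stored in the checkpoint
swap_binary = nn.Linear(2048, 2, bias=False)
# class-count-dependent swap head: weight shape [6, 2048], as built at test time
swap_2xmul = nn.Linear(2048, 2 * num_classes, bias=False)

print(swap_binary.weight.shape)  # torch.Size([2, 2048])
print(swap_2xmul.weight.shape)   # torch.Size([6, 2048])
```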