microsoft / Relation-Aware-Global-Attention-Networks

We design an effective Relation-Aware Global Attention (RGA) module for CNNs to globally infer the attention.
MIT License
341 stars 65 forks source link

with marker dataset #6

Open rrjia opened 4 years ago

rrjia commented 4 years ago

Evaluated with "feat_" features and "cosine" metric: Mean AP: 0.4% CMC Scores top-1 0.1% top-5 0.4% top-10 0.5% Evaluated with "feat" features and "cosine" metric: Mean AP: 0.3% CMC Scores top-1 0.1% top-5 0.2% top-10 0.5%

market.py as fellow from future import absolute_import from future import division from future import print_function

import os import glob import re import sys import urllib import tarfile import zipfile import os.path as osp from scipy.io import loadmat import numpy as np import h5py from scipy.misc import imsave

from ..utils.osutils import mkdir_if_missing from ..utils.serialization import write_json, read_json

class Market(object): """ Args: split_id (int): split index (default: 0) cuhk03_labeled (bool): whether to load labeled images; if false, detected images are loaded (default: False) """ dataset_dir = 'cuhk03'

dataset_dir = 'CUHK03_New'

def __init__(self, root='data', split_id=0, cuhk03_labeled=False, cuhk03_classic_split=False, verbose=True, **kwargs):
    super(Market, self).__init__()
    self.dataset_name = "market"
    self.num_train_pids = 0
    self.num_train_imgs = 0
    self.num_query_pids = 0
    self.num_query_imgs = 0
    self.num_gallery_pids = 0
    self.num_gallery_imgs = 0

    ##############################################################
    # 以下代码是自己的
    data_dir = "/data/server77_data/rrjia/Market-1501-v15.09.15"
    # data_dir = "Z:\\rrjia\\Market-1501-v15.09.15"
    if osp.isdir(data_dir):
        self.data_dir = data_dir

    self.train_dir = osp.join(self.data_dir, 'bounding_box_train')
    self.query_dir = osp.join(self.data_dir, 'query')
    self.gallery_dir = osp.join(self.data_dir, 'bounding_box_test')

    required_files = [
        self.data_dir,
        self.train_dir,
        self.query_dir,
        self.gallery_dir,
    ]
    self.check_before_run(required_files)

    train = self.process_dir(self.train_dir)
    query = self.process_dir(self.query_dir, is_train=False)
    gallery = self.process_dir(self.gallery_dir, is_train=False)

    self.train = train
    self.query = query
    self.gallery = gallery

    if verbose:
        print("Dataset statistics:")
        print("  ------------------------------")
        print("  subset   | # ids | # images")
        print("  ------------------------------")
        print("  train    | {:5d} | {:8d}".format(self.num_train_pids, self.num_train_imgs))
        print("  query    | {:5d} | {:8d}".format(self.num_query_pids, self.num_query_imgs))
        print("  gallery  | {:5d} | {:8d}".format(self.num_gallery_pids, self.num_gallery_imgs))
        print("  ------------------------------")

def check_before_run(self, required_files):
    """Checks if required files exist before going deeper.
    Args:
        required_files (str or list): string file name(s).
    """
    if isinstance(required_files, str):
        required_files = [required_files]

    for fpath in required_files:
        if not os.path.exists(fpath):
            raise RuntimeError('"{}" is not found'.format(fpath))

def process_dir(self, dir_path, is_train=True):
    img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
    pattern = re.compile(r'([-\d]+)_c(\d)')

    data = []
    pid_set = set()
    image_num = 0
    pid_list = list()
    for img_path in img_paths:
        pid, camid = map(int, pattern.search(img_path).groups())
        if pid == -1:
            continue  # junk images are just ignored
        assert 0 <= pid <= 1501  # pid == 0 means background
        assert 1 <= camid <= 6
        camid -= 1  # index starts from 0
        if pid in pid_list:
            pass
        else:
            pid_list.append(pid)
        new_pid = pid_list.index(pid)
        pid_set.add(new_pid)
        image_num += 1
        data.append((img_path, new_pid, camid))

    if dir_path.endswith("bounding_box_train"):
        self.num_train_pids = len(pid_set)
        self.num_train_imgs = image_num
    elif dir_path.endswith("query"):
        self.num_query_pids = len(pid_set)
        self.num_query_imgs = image_num
    else:
        self.num_gallery_pids = len(pid_set)
        self.num_gallery_imgs = image_num
    return data
guanyonglai commented 4 years ago

are you solve this problem?

rrjia commented 4 years ago

I didn't peruse the paper and gave up analyzing it

guanyonglai commented 4 years ago

nice