xingyizhou / GTR

Global Tracking Transformers, CVPR 2022

about lvis version #31

Open HanGuangXin opened 2 years ago

HanGuangXin commented 2 years ago

Hi there! Thanks for your work.

I have two questions about the version of the LVIS dataset:

  1. Why did you use v1.0 instead of v0.5?
  2. Could you please share the code that re-maps the labels of v1.0 back to v0.5?

Looking forward to your reply!

xingyizhou commented 2 years ago

Hi,

  1. We used LVIS v1.0. We mapped the labels back to v0.5 when doing test set evaluation.
  2. Please find my (un-cleaned) script below:
```python
import argparse
import json
import os

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--pred', default='')         # predictions: .json list or detectron2 .pth
    parser.add_argument('--ann', default='')          # LVIS v1.0 annotation file
    parser.add_argument('--convert_v05', default='')  # LVIS v0.5 annotation file
    parser.add_argument('--minus1', action='store_true')  # shift predicted ids by -1
    parser.add_argument('--plus1', action='store_true')   # shift predicted ids by +1
    args = parser.parse_args()

    print('Loading', args.ann)
    with open(args.ann, 'r') as f:
        data = json.load(f)
    print('Done')

    print('Loading', args.convert_v05)
    with open(args.convert_v05, 'r') as f:
        data_v05 = json.load(f)
    cats_v05 = data_v05['categories']
    # Match v1.0 and v0.5 categories through their shared WordNet synset;
    # v1.0 categories without a v0.5 counterpart are left out of the map.
    catid2synset = {x['id']: x['synset'] for x in data['categories']}
    synset2v05 = {x['synset']: x['id'] for x in cats_v05}
    catid2v05 = {
        x['id']: synset2v05[catid2synset[x['id']]]
        for x in data['categories'] if catid2synset[x['id']] in synset2v05}

    print('Loading', args.pred)
    if args.pred.endswith('.pth'):
        import torch
        pred_data = torch.load(args.pred)
        # detectron2 saves one dict per image; collect all its instances
        preds = []
        for x in pred_data:
            preds.extend(x['instances'])
    else:
        with open(args.pred, 'r') as f:
            preds = json.load(f)
    print('Done')

    # Strip the extension (works for both .json and .pth inputs)
    out_path = os.path.splitext(args.pred)[0] + '_v05.json'
    ret = []
    for x in preds:
        cat_id = x['category_id']
        if args.minus1:
            cat_id = cat_id - 1
        if args.plus1:
            cat_id = cat_id + 1
        if cat_id not in catid2v05:
            continue  # no v0.5 counterpart: drop the prediction
        x['category_id'] = catid2v05[cat_id]
        ret.append(x)
    print('Writing to', out_path)
    with open(out_path, 'w') as f:
        json.dump(ret, f)
```
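The core of the script is a join of the two category tables on their WordNet `synset` field, after which any prediction whose v1.0 category has no v0.5 synset is dropped. Below is a minimal sketch of just that step, pulled into two plain functions so it can be checked on toy data. The function names, the `offset` parameter (standing in for `--minus1`/`--plus1`), and the example categories are hypothetical; they are not taken from the GTR repo or from the real LVIS files.

```python
from typing import Dict, List


def build_v1_to_v05_map(cats_v1: List[dict], cats_v05: List[dict]) -> Dict[int, int]:
    """Map LVIS v1.0 category ids to v0.5 ids by joining on the `synset` field."""
    synset_to_v05 = {c["synset"]: c["id"] for c in cats_v05}
    return {
        c["id"]: synset_to_v05[c["synset"]]
        for c in cats_v1
        if c["synset"] in synset_to_v05  # skip categories absent from v0.5
    }


def remap_predictions(preds: List[dict], v1_to_v05: Dict[int, int],
                      offset: int = 0) -> List[dict]:
    """Return relabelled copies of `preds`; unmatched categories are dropped.

    `offset` is a hypothetical stand-in for the script's --minus1 (-1)
    and --plus1 (+1) flags.
    """
    out = []
    for p in preds:
        cat_id = p["category_id"] + offset
        if cat_id not in v1_to_v05:
            continue
        out.append({**p, "category_id": v1_to_v05[cat_id]})
    return out


if __name__ == "__main__":
    # Toy tables with hypothetical ids/synsets, not the real LVIS categories.
    cats_v1 = [
        {"id": 1, "synset": "aerosol.n.02"},
        {"id": 2, "synset": "air_conditioner.n.01"},
        {"id": 3, "synset": "only_in_v1.n.01"},
    ]
    cats_v05 = [
        {"id": 10, "synset": "air_conditioner.n.01"},
        {"id": 11, "synset": "aerosol.n.02"},
    ]
    preds = [
        {"image_id": 7, "category_id": 1, "score": 0.9},
        {"image_id": 7, "category_id": 2, "score": 0.8},
        {"image_id": 7, "category_id": 3, "score": 0.7},
    ]
    mapping = build_v1_to_v05_map(cats_v1, cats_v05)
    print(mapping)  # {1: 11, 2: 10}
    print([p["category_id"] for p in remap_predictions(preds, mapping)])  # [11, 10]
```

Returning copies via `{**p, ...}` leaves the input predictions untouched, whereas the script edits `x['category_id']` in place; for a one-off conversion either choice works.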