CarryHJR / LogDet

2022 Tianchi product logo object detection

Preprocessing code for splitting off the val set #10

Open CarryHJR opened 2 years ago

CarryHJR commented 2 years ago

Posting the code I use to split off the val set:

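import os
import random

from pycocotools.coco import COCO

# `root` is the dataset root directory; set it before running this snippet.
coco = COCO(os.path.join(root, 'annotations/instances_train2017.json'))
class_names = [coco.cats[catId]['name'] for catId in coco.getCatIds()]
# Keep the original category ids so they still match `category_id` in the annotations.
categories = [dict(id=catId, name=coco.cats[catId]['name']) for catId in coco.getCatIds()]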

annotations_train, images_train = [], []
annotations_val, images_val = [], []
assigned = set()  # image ids that have already been placed in a split
for catId in coco.getCatIds():
    # Skip images already assigned through another category, so that no image
    # is duplicated or ends up in both train and val.
    imgIds = [i for i in coco.getImgIds(catIds=[catId]) if i not in assigned]
    random.shuffle(imgIds)
    assigned.update(imgIds)
    # Up to 10 images per category go to val, the rest go to train.
    for imgId in imgIds[:10]:
        images_val.append(coco.imgs[imgId])
        annotations_val.extend(coco.imgToAnns[imgId])
    for imgId in imgIds[10:]:
        images_train.append(coco.imgs[imgId])
        annotations_train.extend(coco.imgToAnns[imgId])
json_dict_train = {"images": images_train, "type": "instances", "annotations": annotations_train, "categories": categories}
json_dict_val = {"images": images_val, "type": "instances", "annotations": annotations_val, "categories": categories}

After that, just save json_dict_train and json_dict_val to disk.
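
For the saving step, something along these lines should work. This is only a minimal sketch: `save_split` is a hypothetical helper name, not something from the repo, and the output file names in the usage note are placeholders.

```python
import json
import os


def save_split(json_dict, path):
    """Write one COCO-style dict to `path` as JSON and return the path.

    `save_split` is a hypothetical helper used for illustration only.
    """
    # Create the target directory if it does not exist yet.
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        # ensure_ascii=False keeps non-ASCII category names readable in the file.
        json.dump(json_dict, f, ensure_ascii=False)
    return path
```

Usage would look like `save_split(json_dict_train, os.path.join(root, 'annotations/train_split.json'))` and the same for `json_dict_val`; choose whatever file names your training config expects.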