aim-uofa / AdelaiDet

AdelaiDet is an open source toolbox for multiple instance-level detection and recognition tasks.
https://git.io/AdelaiDet

Train ABCNet with customized dataset #101

Closed shuangyichen closed 4 years ago

shuangyichen commented 4 years ago

I have prepared my dataset and generated the JSON annotations with generate_bezier_json.py. Then I ran the training command mentioned in #100:

```
OMP_NUM_THREADS=1 python tools/train_net.py --config-file configs/BAText/TotalText/attn_R_50.yaml --num-gpus 1
```

But it failed with this error:

```
[06/16 08:29:02 adet.data.detection_utils]: TransformGens used in training: [ResizeShortestEdge(short_edge_length=(640, 672, 704, 736, 768, 800, 832, 864, 896), max_size=1600, sample_style='choice')]
Traceback (most recent call last):
  File "/root/detectron2/detectron2/data/catalog.py", line 55, in get
    f = DatasetCatalog._REGISTERED[name]
KeyError: 'mydataset_train'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "tools/train_net.py", line 243, in <module>
    args=(args,),
  File "/root/detectron2/detectron2/engine/launch.py", line 57, in launch
    main_func(*args)
  File "tools/train_net.py", line 225, in main
    trainer = Trainer(cfg)
  File "tools/train_net.py", line 62, in __init__
    data_loader = self.build_train_loader(cfg)
  File "tools/train_net.py", line 128, in build_train_loader
    return build_detection_train_loader(cfg, mapper)
  File "/root/detectron2/detectron2/data/build.py", line 333, in build_detection_train_loader
    proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
  File "/root/detectron2/detectron2/data/build.py", line 224, in get_detection_dataset_dicts
    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
  File "/root/detectron2/detectron2/data/build.py", line 224, in <listcomp>
    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
  File "/root/detectron2/detectron2/data/catalog.py", line 59, in get
    name, ", ".join(DatasetCatalog._REGISTERED.keys())
KeyError: "Dataset 'mydataset_train' is not registered! Available datasets are: coco_2014_train, coco_2014_val, coco_2014_minival, coco_2014_minival_100, coco_2014_valminusminival, coco_2017_train, coco_2017_val, coco_2017_test, coco_2017_test-dev, coco_2017_val_100, keypoints_coco_2014_train, keypoints_coco_2014_val, keypoints_coco_2014_minival, keypoints_coco_2014_valminusminival, keypoints_coco_2014_minival_100, keypoints_coco_2017_train, keypoints_coco_2017_val, keypoints_coco_2017_val_100, coco_2017_train_panoptic_separated, coco_2017_train_panoptic_stuffonly, coco_2017_val_panoptic_separated, coco_2017_val_panoptic_stuffonly, coco_2017_val_100_panoptic_separated, coco_2017_val_100_panoptic_stuffonly, lvis_v0.5_train, lvis_v0.5_val, lvis_v0.5_val_rand_100, lvis_v0.5_test, lvis_v0.5_train_cocofied, lvis_v0.5_val_cocofied, cityscapes_fine_instance_seg_train, cityscapes_fine_sem_seg_train, cityscapes_fine_instance_seg_val, cityscapes_fine_sem_seg_val, cityscapes_fine_instance_seg_test, cityscapes_fine_sem_seg_test, voc_2007_trainval, voc_2007_train, voc_2007_val, voc_2007_test, voc_2012_trainval, voc_2012_train, voc_2012_val, pic_person_train, pic_person_val, totaltext_train, totaltext_val, ctw1500_word_train, ctw1500_word_test, syntext1_train, syntext2_train, mltbezier_word_train"
```
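For context: the traceback means detectron2's DatasetCatalog has no entry named mydataset_train, so a custom dataset has to be registered before training. Below is a minimal sketch using detectron2's `register_coco_instances` helper; the dataset name and paths are placeholders for your own layout, not something AdelaiDet ships:

```python
# Minimal sketch (placeholder paths): register a COCO-format JSON so that
# cfg.DATASETS.TRAIN = ("mydataset_train",) can resolve it.
from detectron2.data.datasets import register_coco_instances

register_coco_instances(
    "mydataset_train",                            # name used in the config
    {},                                           # extra metadata (may stay empty)
    "datasets/mydataset/annotations/train.json",  # JSON from generate_bezier_json.py
    "datasets/mydataset/train_images",            # image root
)
```

Note that AdelaiDet registers its built-in text datasets (totaltext_train, ctw1500_word_train, etc.) in adet/data/builtin.py, and plain `register_coco_instances` may not carry the extra `bezier_pts`/`rec` annotation fields through to the loader, so mirroring that built-in text-dataset registration is likely the safer route for ABCNet.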

shuangyichen commented 4 years ago

Also, how many points should be used to annotate a curved text region?

jiangzz1628 commented 4 years ago

@shuangyichen Where is generate_bezier_json.py? Could you tell me how to generate the JSON file? Thanks.

shuangyichen commented 4 years ago

> @shuangyichen Where is generate_bezier_json.py? Could you tell me how to generate the JSON file? Thanks.

It is the link in the second line under "Train Your Own Models" in https://github.com/aim-uofa/AdelaiDet/blob/master/configs/BAText/README.md. Then use your dataset's txt-format annotations to generate train.json. The format looks like this:

```
24.49,22.09,231.04,18.89,229.73,18.78,436.7,16.86,436.12,68.6,230.59,72.38,230.59,72.38,25.07,76.16||||text
25.07,76.16,284.38,73.13,282.66,72.68,542.51,70.35,543.67,117.44,284.95,120.35,284.95,120.35,26.23,123.26||||text
25.98,121.84,282.61,119.3,280.6,118.45,537.86,116.86,539.6,157.56,283.79,160.76,283.79,160.76,27.98,163.95||||text
29.72,168.02,285.94,164.15,284.76,163.48,541.35,161.05,543.09,202.33,286.41,205.52,286.41,205.52,29.72,208.72||||text
27.98,213.95,287.74,210.62,285.56,209.03,546.0,206.98,546.58,248.26,286.99,254.07,286.99,254.07,27.4,259.88||||text
27.4,265.7,282.33,260.69,280.67,259.16,536.12,254.65,537.28,297.09,282.34,300.58,282.34,300.58,27.4,304.07||||text
25.65,303.49,285.68,303.95,283.59,301.72,544.26,303.49,544.26,344.77,284.95,344.77,284.95,344.77,25.65,344.77||||text
18.67,406.4,75.54,411.42,73.47,400.16,130.88,406.4,130.88,448.84,74.78,448.84,74.78,448.84,18.67,448.84||||text
111.12,484.88,332.01,484.72,330.61,482.68,551.81,484.88,551.81,536.05,331.47,536.05,331.47,536.05,111.12,536.05||||text
23.3,356.4,199.56,348.83,376.0,348.73,552.37,347.67,553.53,390.12,376.6,393.01,199.63,392.93,22.72,397.09||||text
```
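For what it's worth, my reading of this format (not an official spec): each line carries 16 numbers, i.e. 8 Bezier control points (4 for the top boundary of the text, 4 for the bottom), followed by `||||` and the transcription. A small parsing sketch:

```python
# Sketch: parse one annotation line "x1,y1,...,x8,y8||||transcription".
line = "24.49,22.09,231.04,18.89,229.73,18.78,436.7,16.86,436.12,68.6,230.59,72.38,230.59,72.38,25.07,76.16||||text"

coords_str, text = line.split("||||")
coords = [float(v) for v in coords_str.split(",")]
assert len(coords) == 16                        # 8 (x, y) control points
points = list(zip(coords[0::2], coords[1::2]))  # [(x1, y1), ..., (x8, y8)]
print(points[:4], "->", text)                   # first four points: top curve
```

This would also answer the earlier question about the number of points: eight per text instance.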

jiangzz1628 commented 4 years ago

@shuangyichen Yes, but I can't open the link on my side. Could you share it with me in another form? Thanks.

shuangyichen commented 4 years ago

> @shuangyichen Yes, but I can't open the link on my side. Could you share it with me in another form? Thanks.

```python
#!/usr/bin/python
# -*- coding: utf-8 -*-
import json
import os
import sys

import cv2

# 96-entry character table; an index into this list becomes the recognition label.
cV2 = [' ', '!', '"', '#', '$', '%', '&', '\'', '(', ')', '*', '+', ',', '-',
       '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';',
       '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I',
       'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W',
       'X', 'Y', 'Z', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e',
       'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's',
       't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~']

if len(sys.argv) < 4:
    print("Usage: python convert_to_detectron_json.py root_path phase split")
    print("For example: python convert_to_detectron_json.py data train 100200")
    sys.exit(1)
root_path = sys.argv[1]
phase = sys.argv[2]
split = int(sys.argv[3])

dataset = {
    'licenses': [],
    'info': {},
    'categories': [],
    'images': [],
    'annotations': []
}
with open(os.path.join(root_path, 'classes.txt')) as f:
    classes = f.read().strip().split()
for i, cls in enumerate(classes, 1):
    # Single hard-coded 'text' category; 'keypoints' is only used for keypoints.
    dataset['categories'].append({
        'id': 1,
        'name': 'text',
        'supercategory': 'beverage',
        'keypoints': ['mean', 'xmin', 'x2', 'x3', 'xmax',
                      'ymin', 'y2', 'y3', 'ymax', 'cross']
    })

def get_category_id(cls):
    for category in dataset['categories']:
        if category['name'] == cls:
            return category['id']

_indexes = sorted([f.split('.')[0]
                   for f in os.listdir(os.path.join(root_path, 'becan_gen_labels'))])
print(_indexes)
if phase == 'train':
    indexes = [line for line in _indexes if int(line) >= split]  # only for this file
else:
    indexes = [line for line in _indexes if int(line) <= split]

j = 1
for index in indexes:  # iterate the phase-filtered indexes, not _indexes
    print('Processing: ' + index)
    im = cv2.imread(os.path.join(root_path, 'ctwtrain_textimage/') + index + '.jpg')
    height, width, _ = im.shape
    dataset['images'].append({
        'coco_url': '',
        'date_captured': '',
        'file_name': index + '.jpg',
        'flickr_url': '',
        'id': int(index.split('_')[0]),
        'license': 0,
        'width': width,
        'height': height
    })
    anno_file = os.path.join(root_path, 'becan_gen_labels/') + index + '.txt'

    with open(anno_file) as f:
        lines = [line for line in f.readlines() if line.strip()]
    for i, line in enumerate(lines):
        pttt = line.strip().split('||||')
        parts = pttt[0].split(',')
        ct = pttt[-1].strip()

        cls = 'text'
        segs = [float(kkpart) for kkpart in parts[:16]]

        xt = [segs[ikpart] for ikpart in range(0, len(segs), 2)]
        yt = [segs[ikpart] for ikpart in range(1, len(segs), 2)]
        # Corner control points (0, 3, 4, 7) give the axis-aligned bounding box.
        xmin = min([xt[0], xt[3], xt[4], xt[7]])
        ymin = min([yt[0], yt[3], yt[4], yt[7]])
        xmax = max([xt[0], xt[3], xt[4], xt[7]])
        ymax = max([yt[0], yt[3], yt[4], yt[7]])
        width = max(0, xmax - xmin + 1)
        height = max(0, ymax - ymin + 1)
        if width == 0 or height == 0:
            continue

        # Encode the transcription as fixed-length indices into cV2;
        # len(cV2) marks an unknown character, len(cV2) + 1 is padding.
        max_len = 100
        recs = [len(cV2) + 1 for ir in range(max_len)]

        # (Python 2 needed ct.decode('utf-8') here; str is already unicode in Python 3.)
        print('ct', ct)

        for ix, ict in enumerate(ct):
            if ix >= max_len:
                continue
            if ict in cV2:
                recs[ix] = cV2.index(ict)
            else:
                recs[ix] = len(cV2)

        dataset['annotations'].append({
            'area': width * height,
            'bbox': [xmin, ymin, width, height],
            'category_id': get_category_id(cls),
            'id': j,
            'image_id': int(index),
            'iscrowd': 0,
            'bezier_pts': segs,
            'rec': recs
        })
        j += 1

folder = os.path.join(root_path, 'annotations')
if not os.path.exists(folder):
    os.makedirs(folder)
json_name = os.path.join(root_path, 'annotations/{}.json'.format(phase))
with open(json_name, 'w') as f:
    json.dump(dataset, f)
```
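A quick sanity check of the generated file (a sketch, assuming the script above was run as `python convert_to_detectron_json.py data train 100200` with the directory layout it expects):

```python
# Sketch: verify the COCO-style JSON written by the conversion script.
import json

with open('data/annotations/train.json') as f:
    coco = json.load(f)

print(len(coco['images']), 'images,', len(coco['annotations']), 'annotations')
ann = coco['annotations'][0]
assert len(ann['bezier_pts']) == 16   # 8 Bezier control points
assert len(ann['rec']) == 100         # fixed-length recognition labels
```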

jiangzz1628 commented 4 years ago

OK, thanks.

MaoShouren commented 3 years ago

I ran into the same problem as you. How did you solve it?

zmm288 commented 2 years ago

Hello, how can I obtain txt annotations with these eight coordinate points?