midasklr / SSD.Pytorch

Pytorch implementation of SSD512
MIT License

A little help for custom data training #1

Open khalidw opened 4 years ago

khalidw commented 4 years ago

Hello! Thank you for making your code public. I am trying to train a custom model to detect boats. I have approximately 3,867 images with VOC annotations (XML files). Could you guide me on how to arrange the data, i.e., what folder structure I should use?

midasklr commented 4 years ago

Hi, just arrange your dataset in VOC format, put it under ./data/, and change the path in ./data/voc0712.py. I will add the training process to the README.
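For reference, here is a minimal sketch of the layout that the voc0712.py posted later in this thread reads by default: `VOC_ROOT` points at `./data/VOCdevkit/` and the default image set is `('2007', 'trainval')`, so annotations go in `VOC2007/Annotations`, images in `VOC2007/JPEGImages`, and the list of image ids in `VOC2007/ImageSets/Main/trainval.txt`. The helper `write_image_set` is hypothetical (not part of SSD.Pytorch) and assumes `.jpg` images whose file names match the XML file names.

```python
from pathlib import Path

# Layout read by voc0712.py with VOC_ROOT = ./data/VOCdevkit/ and
# image_sets=[('2007', 'trainval')]:
#
# data/VOCdevkit/VOC2007/
#     Annotations/      <id>.xml   one VOC XML file per image
#     JPEGImages/       <id>.jpg   the matching image
#     ImageSets/Main/   trainval.txt, one <id> per line, no extension


def write_image_set(voc_year_dir, split="trainval"):
    """Hypothetical helper: create the VOC folders if missing and write every
    image id that has both an XML annotation and a .jpg image to
    ImageSets/Main/<split>.txt. Returns the sorted list of ids."""
    root = Path(voc_year_dir)
    for sub in ("Annotations", "JPEGImages", "ImageSets/Main"):
        (root / sub).mkdir(parents=True, exist_ok=True)

    # An id is the file name without its extension, e.g. "boat_0001"
    ids = sorted(
        ann.stem
        for ann in (root / "Annotations").glob("*.xml")
        if (root / "JPEGImages" / f"{ann.stem}.jpg").exists()
    )
    (root / "ImageSets" / "Main" / f"{split}.txt").write_text(
        "".join(f"{i}\n" for i in ids)
    )
    return ids
```

After copying the XML files and images into place, calling `write_image_set("data/VOCdevkit/VOC2007")` would produce the `trainval.txt` that `VOCDetection` iterates over. Images without an annotation (or vice versa) are skipped rather than listed, since `pull_item` opens both files for every id.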

khalidw commented 4 years ago

Thank you for your help. I have rearranged the data in VOC format, following exactly the same folder hierarchy, and made all the changes you suggested. However, I am now stuck on another error, which I tried to troubleshoot before posting here. I would appreciate it if you could take a look. Thanks again.

PS: I have 13 classes in my boat classification dataset. 'vaporettoactv', the key mentioned in the KeyError, is one of the class names.

[screenshot of the KeyError]

midasklr commented 4 years ago

There may be an error in your voc0712.py when it parses your label "vaporettoactv". Please paste your voc0712.py.

khalidw commented 4 years ago

The key in this KeyError keeps changing. Anyway, here is my voc0712.py and another screenshot of the error.

```python
"""VOC Dataset Classes

Original author: Francisco Massa
https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py

Updated by: Ellis Brown, Max deGroot
"""
from .config import HOME
import os.path as osp
import sys
import torch
import torch.utils.data as data
import cv2
import numpy as np
if sys.version_info[0] == 2:
    import xml.etree.cElementTree as ET
else:
    import xml.etree.ElementTree as ET

VOC_CLASSES = (  # always class 0
    'Alilaguna', 'Ambulanza', 'Barchino', 'Boat', 'Lanciafino10mBianca',
    'Lanciafino10mMarrone', 'Motobarca', 'Mototopo', 'Patanella', 'Polizia',
    'Raccoltarifiuti', 'Topa', 'VaporettoACTV')

VOC_ROOT = osp.join('./', "data/VOCdevkit/")


class VOCAnnotationTransform(object):
    """Transforms a VOC annotation into a Tensor of bbox coords and label index
    Initialized with a dictionary lookup of classnames to indexes

    Arguments:
        class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
            (default: alphabetic indexing of VOC's 20 classes)
        keep_difficult (bool, optional): keep difficult instances or not
            (default: False)
        height (int): height
        width (int): width
    """

    def __init__(self, class_to_ind=None, keep_difficult=False):
        self.class_to_ind = class_to_ind or dict(
            zip(VOC_CLASSES, range(len(VOC_CLASSES))))
        self.keep_difficult = keep_difficult

    def __call__(self, target, width, height):
        """
        Arguments:
            target (annotation) : the target annotation to be made usable
                will be an ET.Element
        Returns:
            a list containing lists of bounding boxes  [bbox coords, class name]
        """
        res = []
        for obj in target.iter('object'):
            difficult = int(obj.find('difficult').text) == 1
            if not self.keep_difficult and difficult:
                continue
            name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')

            pts = ['xmin', 'ymin', 'xmax', 'ymax']
            bndbox = []
            for i, pt in enumerate(pts):
                cur_pt = int(float(bbox.find(pt).text)) - 1
                # scale height or width
                cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
                bndbox.append(cur_pt)
            label_idx = self.class_to_ind[name]
            bndbox.append(label_idx)
            res += [bndbox]  # [xmin, ymin, xmax, ymax, label_ind]
            # img_id = target.find('filename').text[:-4]

        return res  # [[xmin, ymin, xmax, ymax, label_ind], ... ]


class VOCDetection(data.Dataset):
    """VOC Detection Dataset Object

    input is image, target is annotation

    Arguments:
        root (string): filepath to VOCdevkit folder.
        image_set (string): imageset to use (eg. 'train', 'val', 'test')
        transform (callable, optional): transformation to perform on the
            input image
        target_transform (callable, optional): transformation to perform on the
            target annotation
            (eg: take in caption string, return tensor of word indices)
        dataset_name (string, optional): which dataset to load
            (default: 'VOC2007')
    """

    def __init__(self, root,
                 image_sets=[('2007', 'trainval')],
                 transform=None, target_transform=VOCAnnotationTransform(),
                 dataset_name='VOC0712'):
        self.root = root
        self.image_set = image_sets
        self.transform = transform
        self.target_transform = target_transform
        self.name = dataset_name
        self._annopath = osp.join('%s', 'Annotations', '%s.xml')
        self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
        self.ids = list()
        for (year, name) in image_sets:
            rootpath = osp.join(self.root, 'VOC' + year)
            for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
                self.ids.append((rootpath, line.strip()))

    def __getitem__(self, index):
        im, gt, h, w = self.pull_item(index)

        return im, gt

    def __len__(self):
        return len(self.ids)

    def pull_item(self, index):
        img_id = self.ids[index]

        target = ET.parse(self._annopath % img_id).getroot()
        img = cv2.imread(self._imgpath % img_id)
        height, width, channels = img.shape

        if self.target_transform is not None:
            target = self.target_transform(target, width, height)

        if self.transform is not None:
            target = np.array(target)
            img, boxes, labels = self.transform(img, target[:, :4], target[:, 4])
            # to rgb
            img = img[:, :, (2, 1, 0)]
            # img = img.transpose(2, 0, 1)
            target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
        return torch.from_numpy(img).permute(2, 0, 1), target, height, width
        # return torch.from_numpy(img), target, height, width

    def pull_image(self, index):
        '''Returns the original image object at index in PIL form

        Note: not using self.__getitem__(), as any transformations passed in
        could mess up this functionality.

        Argument:
            index (int): index of img to show
        Return:
            PIL img
        '''
        img_id = self.ids[index]
        return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)

    def pull_anno(self, index):
        '''Returns the original annotation of image at index

        Note: not using self.__getitem__(), as any transformations passed in
        could mess up this functionality.

        Argument:
            index (int): index of img to get annotation of
        Return:
            list:  [img_id, [(label, bbox coords),...]]
                eg: ('001718', [('dog', (96, 13, 438, 332))])
        '''
        img_id = self.ids[index]
        anno = ET.parse(self._annopath % img_id).getroot()
        gt = self.target_transform(anno, 1, 1)
        return img_id[1], gt

    def pull_tensor(self, index):
        '''Returns the original image at an index in tensor form

        Note: not using self.__getitem__(), as any transformations passed in
        could mess up this functionality.

        Argument:
            index (int): index of img to show
        Return:
            tensorized version of img, squeezed
        '''
        return torch.Tensor(self.pull_image(index)).unsqueeze_(0)
```

[screenshot of the error]

midasklr commented 4 years ago

Patanella

There might be a label error in your annotations. KeyError: 'patanella' means there is a name "patanella" in your XML files while your class name is "Patanella", and class names are case-sensitive. So you should check your XML files and make sure all labels are right.
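One way to do that check, as a minimal sketch: the hypothetical helper `audit_labels` below counts every `<object><name>` across a folder of VOC XML files and returns the names that do not appear exactly (case-sensitively) in `VOC_CLASSES`, using the class list from the voc0712.py posted above. It compares the raw XML names, before any lowercasing, and assumes every `<object>` has a `<name>` element, as `VOCAnnotationTransform` also does.

```python
import xml.etree.ElementTree as ET
from collections import Counter
from pathlib import Path

# Class list copied from the voc0712.py posted in this thread
VOC_CLASSES = (
    'Alilaguna', 'Ambulanza', 'Barchino', 'Boat', 'Lanciafino10mBianca',
    'Lanciafino10mMarrone', 'Motobarca', 'Mototopo', 'Patanella', 'Polizia',
    'Raccoltarifiuti', 'Topa', 'VaporettoACTV')


def audit_labels(annotation_dir, class_names=VOC_CLASSES):
    """Hypothetical helper: count each <object><name> in the VOC XML files
    under annotation_dir and return (counts, unknown), where unknown holds
    the names that are not an exact, case-sensitive match in class_names."""
    counts = Counter()
    for xml_path in sorted(Path(annotation_dir).glob("*.xml")):
        for obj in ET.parse(xml_path).getroot().iter("object"):
            counts[obj.find("name").text.strip()] += 1
    unknown = {name: n for name, n in counts.items() if name not in class_names}
    return counts, unknown
```

For example, `audit_labels("data/VOCdevkit/VOC2007/Annotations")` would return an empty `unknown` dict if every XML label already matches a class name exactly.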

am-shubh commented 4 years ago

@midasklr @khalidw I was getting the same error. This is how I solved the KeyError: in the voc0712.py script, inside the `__call__` method of the `VOCAnnotationTransform` class, I changed `name = obj.find('name').text.lower().strip()` to `name = obj.find('name').text.strip()`.

That is, removing the `.lower()` call fixed the KeyError for me. As suggested above, I also cross-checked my annotations but could not find any errors, so I think this might be a bug.
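Why that works, as a minimal sketch: `__call__` lowercases each `<name>` before the lookup, but the default `class_to_ind` is built with `dict(zip(VOC_CLASSES, range(len(VOC_CLASSES))))` from the capitalized names, so `Patanella` in the XML becomes `patanella`, which is not a key. Since all 13 class names contain capital letters, this would also explain why the key in the error kept changing. An alternative to deleting `.lower()`, assumed here rather than taken from the repository, is to lowercase the dictionary keys as well so the lookup becomes case-insensitive; `build_class_to_ind` is a hypothetical helper name.

```python
VOC_CLASSES = (
    'Alilaguna', 'Ambulanza', 'Barchino', 'Boat', 'Lanciafino10mBianca',
    'Lanciafino10mMarrone', 'Motobarca', 'Mototopo', 'Patanella', 'Polizia',
    'Raccoltarifiuti', 'Topa', 'VaporettoACTV')


def build_class_to_ind(class_names):
    """Hypothetical helper: map lowercased class names to their indexes."""
    mapping = {name.lower(): idx for idx, name in enumerate(class_names)}
    # Two classes that differ only in case would collide after lowercasing
    if len(mapping) != len(class_names):
        raise ValueError("class names must be unique ignoring case")
    return mapping


# The default mapping in voc0712.py keeps the original capitalization
default_map = dict(zip(VOC_CLASSES, range(len(VOC_CLASSES))))
name = " Patanella ".lower().strip()          # what __call__ produces
print(name in default_map)                    # False, hence KeyError: 'patanella'
print(build_class_to_ind(VOC_CLASSES)[name])  # 8
```

With the posted voc0712.py this mapping could be passed through the existing parameters, e.g. `VOCDetection(root, target_transform=VOCAnnotationTransform(class_to_ind=build_class_to_ind(VOC_CLASSES)))`. Removing `.lower()` is simpler when the XML names already match `VOC_CLASSES` exactly; the lowercased mapping also tolerates annotations that mix cases.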