Open abderrazzak555 opened 5 years ago
class KittiLoader(data.Dataset):
    """Dataset over a KITTI-style directory tree of images and label files.

    Images are collected from ``<root>/<split>/image_2/*.png``. For every
    split other than ``"testing"``, label files are collected from
    ``<root>/<split>/label_2/*.txt`` and matched to images by file stem
    (``000123.png`` <-> ``000123.txt``). Images without a matching label
    file are dropped, so ``len(dataset)`` never exceeds the number of
    usable (image, label) pairs.

    Args:
        root: Dataset root directory.
        split: Which split to serve, ``"training"`` or ``"testing"``.
        img_size: An int or a tuple; an int ``n`` is stored as ``(n, n)``.
        transforms: Optional callable invoked as
            ``transforms(img, boxes, labels)`` and expected to return
            ``(img, boxes, labels)``.
        target_transform: Optional callable invoked as
            ``target_transform(label_lines, width, height)``. When
            ``transforms`` is used, its result must convert to a 2-D array
            whose first four columns are box values and whose fifth column
            is the label (this is how the array is sliced below).
    """

    def __init__(self, root, split="training",
                 img_size=512, transforms=None, target_transform=None):
        super().__init__()
        self.root = root
        self.split = split
        self.target_transform = target_transform
        self.n_classes = 2
        self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
        self.files = collections.defaultdict(list)
        self.labels = collections.defaultdict(list)
        self.transforms = transforms
        self.name = 'kitti'

        # Use a distinct loop name so the ``split`` argument is not shadowed.
        for subset in ["training", "testing"]:
            # glob() returns paths in arbitrary order; sort for determinism.
            images = sorted(glob(os.path.join(root, subset, 'image_2', '*.png')))
            if subset == 'testing':
                # The testing subset has no label files.
                self.files[subset] = images
                continue

            label_by_stem = {
                self._stem(path): path
                for path in glob(os.path.join(root, subset, 'label_2', '*.txt'))
            }
            # Keep only images that have a label file, so that
            # files[subset][i] and labels[subset][i] always refer to the
            # same sample and both lists have the same length.
            paired = [
                (image, label_by_stem[self._stem(image)])
                for image in images
                if self._stem(image) in label_by_stem
            ]
            self.files[subset] = [image for image, _ in paired]
            self.labels[subset] = [label for _, label in paired]

    @staticmethod
    def _stem(path):
        """Return the file name of *path* without directory or extension."""
        return os.path.splitext(os.path.basename(path))[0]

    def __len__(self):
        """Return the number of samples available in the selected split."""
        return len(self.files[self.split])

    def __getitem__(self, index):
        """Load the sample at *index*.

        Returns:
            For the ``"testing"`` split: the image exactly as returned by
            ``cv2.imread``. ``transforms`` is not applied there because it
            requires boxes and labels, which the testing split lacks.

            Otherwise a tuple ``(image, target, height, width)`` where
            ``image`` is a tensor permuted to channel-first order,
            ``height``/``width`` are the dimensions of the image as read
            from disk, and ``target`` is:

            * the raw label lines if ``target_transform`` is ``None``;
            * the ``target_transform`` output if ``transforms`` is ``None``;
            * otherwise the boxes and labels returned by ``transforms``,
              stacked column-wise.

        Raises:
            IOError: If the image file cannot be read.
            ValueError: If ``transforms`` is set without ``target_transform``.
        """
        img_path = self.files[self.split][index]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise IOError("Could not read image: {}".format(img_path))
        height, width = img.shape[:2]

        if self.split == "testing":
            return img

        lbl_path = self.labels[self.split][index]
        # Context manager so the label file handle is always closed.
        with open(lbl_path, 'r') as lbl_file:
            lbl_lines = lbl_file.readlines()

        if self.target_transform is not None:
            target = self.target_transform(lbl_lines, width, height)
        else:
            # No parser supplied: hand back the raw label lines.
            target = lbl_lines

        if self.transforms is not None:
            if self.target_transform is None:
                raise ValueError(
                    "transforms requires target_transform to convert label "
                    "lines into numeric boxes and labels"
                )
            target = np.array(target)
            img, boxes, labels = self.transforms(img, target[:, :4], target[:, 4])
            # Reverse the channel axis (presumably BGR -> RGB, since the
            # image was loaded with OpenCV - confirm against the transforms).
            img = img[:, :, (2, 1, 0)]
            target = np.hstack((boxes, np.expand_dims(labels, axis=1)))

        return torch.from_numpy(img).permute(2, 0, 1), target, height, width
You should make sure that you have downloaded the data and placed it in the right location.
I encountered the same problem as you. My bug was in __len__(self),
whose return value was larger than it should have been.
Hello @LongyuanCode, can you please elaborate on how you tackled the problem? I'm stuck with the same issue. Thanks.
def main():
    from torchvision import transforms
    from torch.utils import data

if __name__ == '__main__':
    main()