flychen321 / data_aug_reid

1 stars 0 forks source link

Hi #3

Open ssbilakeri opened 3 years ago

ssbilakeri commented 3 years ago

I followed the code below, but the execution ends with ^C.

class SiameseDataset(datasets.ImageFolder):
    """Image-folder dataset that yields (anchor, positive, negative) triplets.

    Train: for each sample (anchor) a positive and a negative sample are
    drawn at random on every access.
    Test: a fixed list of triplets is built once, at construction time,
    from a seeded random generator.
    """

    def __init__(self, root, transform, train=True):
        """Index the image folder and prepare the per-label lookup tables.

        Args:
            root: Directory handed to ``ImageFolder``.
            transform: Callable applied to every loaded image, or ``None``.
            train: ``True`` (default, the previous hard-coded behaviour)
                draws triplets randomly in ``__getitem__``; ``False``
                pre-computes fixed triplets in ``self.test_triplets``.
        """
        super(SiameseDataset, self).__init__(root, transform)
        self.train = train

        # ``self.imgs`` comes from ImageFolder.  Converting the mixed
        # (path, label) pairs to one array stores both columns as strings,
        # so labels are compared and looked up as strings from here on.
        samples = np.array(self.imgs)
        paths, labels = samples[:, 0], samples[:, 1]

        self.labels_set = set(labels)
        # Sorted copy: iteration order of a set of strings changes between
        # interpreter runs, which would defeat the fixed seed in test mode.
        self._ordered_labels = sorted(self.labels_set)
        self.label_to_indices = {label: np.where(labels == label)[0]
                                 for label in self._ordered_labels}

        if self.train:
            self.train_data = paths
            self.train_labels = labels
        else:
            self.test_data = paths
            self.test_labels = labels
            random_state = np.random.RandomState(29)
            self.test_triplets = [
                self._draw_triplet(i, labels[i].item(), random_state)
                for i in range(len(paths))
            ]

    def _draw_triplet(self, index, label, rng):
        """Pick a positive and a negative sample index for one anchor.

        Args:
            index: Position of the anchor sample.
            label: Label of the anchor sample.
            rng: Object exposing ``choice`` (``np.random`` or a
                ``RandomState``).

        Returns:
            ``[index, positive_index, negative_index]``.

        Raises:
            ValueError: If the dataset has no label other than ``label``.
        """
        same_label = self.label_to_indices[label]
        # Exclude the anchor up front instead of re-drawing until the pick
        # differs: re-drawing never terminates for a single-image class.
        candidates = same_label[same_label != index]
        if candidates.size:
            positive_index = rng.choice(candidates)
        else:
            # Only one image carries this label; reuse the anchor itself.
            positive_index = index

        other_labels = [other for other in self._ordered_labels
                        if other != label]
        if not other_labels:
            raise ValueError(
                "SiameseDataset needs at least two classes to draw a "
                "negative sample")
        negative_label = rng.choice(other_labels)
        negative_index = rng.choice(self.label_to_indices[negative_label])
        return [index, positive_index, negative_index]

    def __getitem__(self, index):
        """Load the triplet for ``index``.

        Returns:
            ``((anchor, positive, negative), [])`` - the three images after
            ``self.transform`` (when set) and an empty target list.
        """
        if self.train:
            data = self.train_data
            label = self.train_labels[index].item()
            triplet = self._draw_triplet(index, label, np.random)
        else:
            data = self.test_data
            triplet = self.test_triplets[index]

        # Load all three images before transforming any of them, keeping
        # the original order of loader and transform calls.
        images = [default_loader(data[i]) for i in triplet]
        if self.transform is not None:
            images = [self.transform(image) for image in images]
        return tuple(images), []

    def __len__(self):
        """Return the number of images found by ``ImageFolder``."""
        return len(self.imgs)