omoindrot / tensorflow-triplet-loss

Implementation of triplet loss in TensorFlow
https://omoindrot.github.io/triplet-loss
MIT License

Batch hard and balanced batch #45

Open batrlatom opened 5 years ago

batrlatom commented 5 years ago

Hello. I can't get hard triplet mining to work with balanced batches. I think we had a discussion about it here, but the embeddings always collapse into a single point. I have tried many combinations of margin, num_classes_per_batch, and num_images_per_class, but nothing seems to work. Could you please take a look at the code below for any obvious problem? Note that with the batch_all strategy it works well. Thanks, Tom


def train_input_fn(data_dir, params):
    data_root = pathlib.Path(data_dir)
    all_image_paths = list(data_root.glob('**/*.jpg'))
    all_directories = {'/'.join(str(i).split("/")[:-1]) for i in all_image_paths}
    print("-----")
    print("num of labels: ")
    print(len(all_directories))
    print("-----")
    # Sort the directories so the label indices are deterministic across runs
    labels_index = [i.split("/")[-1] for i in sorted(all_directories)]

    # Create the list of datasets creating filenames
    datasets = [tf.data.Dataset.list_files("{}/*.jpg".format(image_dir), shuffle=False) for image_dir in all_directories]

    num_labels = len(all_directories)
    print(datasets)
    num_classes_per_batch = params.num_classes_per_batch
    num_images_per_class = params.num_images_per_class

    def get_label_index(s):
        return labels_index.index(s.numpy().decode("utf-8").split("/")[-2])

    def preprocess_image(image):
        image = tf.cast(image, tf.float32)
        image = tf.math.divide(image, 255.0)
        return image

    def load_and_preprocess_image(path):
        image = tf.read_file(path)
        # Decode the JPEG bytes before preprocessing; without a decode step,
        # preprocess_image receives a raw string tensor and tf.cast fails.
        image = tf.image.decode_jpeg(image, channels=3)
        return tf.py_function(preprocess_image, [image], tf.float32), tf.py_function(get_label_index, [path], tf.int64)

    def generator():
        while True:
            # Sample the labels that will compose the batch
            labels = np.random.choice(range(num_labels),
                                      num_classes_per_batch,
                                      replace=False)
            for label in labels:
                for _ in range(num_images_per_class):
                    yield label

    choice_dataset = tf.data.Dataset.from_generator(generator, tf.int64)
    dataset = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)

    dataset = dataset.map(load_and_preprocess_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    batch_size = num_classes_per_batch * num_images_per_class
    print("----------------------")
    print(batch_size)
    print("----------------------")
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(params.num_epochs)
    dataset = dataset.prefetch(1)

    print(dataset)
    return dataset
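As a quick sanity check independent of TensorFlow, the sampling logic of the generator above can be exercised on its own to confirm that each batch really is balanced. This is a hypothetical numpy sketch (the constant values are placeholders, not the repo's params), not code from the repository:

```python
import numpy as np

# Placeholder values standing in for the real dataset and params
num_labels = 10               # assumed number of class directories
num_classes_per_batch = 4     # assumed params.num_classes_per_batch
num_images_per_class = 3      # assumed params.num_images_per_class

def sample_batch_labels():
    # Mirrors generator(): pick distinct classes, then repeat each
    # class label num_images_per_class times.
    labels = np.random.choice(range(num_labels),
                              num_classes_per_batch,
                              replace=False)
    return [int(label) for label in labels
            for _ in range(num_images_per_class)]

batch = sample_batch_labels()
classes, counts = np.unique(batch, return_counts=True)
assert len(batch) == num_classes_per_batch * num_images_per_class
assert len(classes) == num_classes_per_batch
assert all(c == num_images_per_class for c in counts)
```

If these assertions hold, the collapse is unlikely to come from the batch composition itself.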
omoindrot commented 5 years ago

I don't see anything wrong.

If the batch all loss works but the batch hard triplet loss does not, this might indicate that your dataset is a bit noisy, so the hardest triplets are actually mislabeled examples.
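To see why collapse shows up specifically with batch hard, here is a numpy sketch of batch-hard mining (not the repo's TF implementation): each anchor is paired with its farthest positive and closest negative. Note that if all embeddings collapse to one point, every distance is ~0 and the loss plateaus exactly at the margin, which is a useful diagnostic to watch for:

```python
import numpy as np

def batch_hard_loss(embeddings, labels, margin):
    """Batch-hard triplet loss on one batch (numpy sketch)."""
    # Pairwise Euclidean distances, shape (B, B)
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    dist = np.sqrt(np.maximum((diff ** 2).sum(-1), 1e-16))
    same = labels[:, None] == labels[None, :]
    # Hardest positive: largest distance among same-label pairs
    hardest_pos = np.where(same, dist, 0.0).max(axis=1)
    # Hardest negative: smallest distance among different-label pairs
    hardest_neg = np.where(same, np.inf, dist).min(axis=1)
    return np.maximum(hardest_pos - hardest_neg + margin, 0.0).mean()
```

With well-separated classes the loss goes to zero; with collapsed embeddings it sticks at the margin because the hardest positive and hardest negative distances are both zero.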

You can also train first with batch all, then fine-tune at the end with batch hard.
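That two-stage schedule could be wired up with something as simple as the following sketch (the function name and the 80/20 split are illustrative assumptions, not from the repo):

```python
def pick_strategy(epoch, num_epochs, switch_fraction=0.8):
    # Assumed split: batch all for the first 80% of epochs,
    # batch hard for the final 20% to fine-tune.
    if epoch < switch_fraction * num_epochs:
        return "batch_all"
    return "batch_hard"
```

The training loop would then select the corresponding loss function each epoch, so the embedding space is already roughly structured before hard mining kicks in.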