KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License
6.02k stars 658 forks

Few-shot learning #661

westbalon closed 1 year ago

westbalon commented 1 year ago

Hi, could you please give some recommendations for few-shot learning with this library?

I have a dataset of 150 classes and 36k images for pre-training, plus a support set of 5 classes with 15 images per class, and I can't get good results yet.

I tried multiple losses, training with a classifier, pre-training on the main dataset (followed by transfer learning on the 75 support images), combining the support set with the main dataset, and performing KNN classification based on centroids after training. So far I get the best results with the code below, but they are still not good enough. Please have a look; maybe you have some thoughts, especially on training with 15 images per class.

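    # Imports inferred from the identifiers used below
    import albumentations as A
    import torch
    import torch.nn as nn
    from albumentations.pytorch import ToTensorV2
    from torchvision.models import resnet50, ResNet50_Weights
    from pytorch_metric_learning import losses, miners, samplers, testers, trainers
    from pytorch_metric_learning.utils import logging_presets
    from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
    from pytorch_metric_learning.utils.inference import FaissKNN, InferenceModel

    if train_transforms: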
        transform = A.Compose(
            [
                A.Resize(224,224),
                A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=20, p=0.75),
                #A.RandomCrop(height=224, width=224),
                A.MotionBlur(p=0.75),
                A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.75),
                A.RandomBrightnessContrast(p=0.75),
                A.OneOf([
                    A.RandomSunFlare(p=1),
                    A.CoarseDropout(max_holes=2, max_height=15, max_width=15, min_holes=1, min_height=5, min_width=10, p=1),
                ], p=0.75),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ToTensorV2(),
            ]
        )
    else:
        transform = A.Compose(
            [
                A.Resize(224,224),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ToTensorV2(),
            ]
        )

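    # ImageNet-pretrained ResNet-50 with the classification head removed
    trunk = resnet50(weights=ResNet50_Weights.DEFAULT).to(device)
    trunk_output_size = trunk.fc.in_features
    trunk.fc = nn.Identity()
    # MLP is a user-defined helper (not shown) that maps trunk features to 1024-d embeddings
    embedder = MLP([trunk_output_size, 1024]).to(device)
    # The optimizers and the trainer below look the networks up by these keys
    models = {'trunk': trunk, 'embedder': embedder}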

    loss_funcs = {
        'metric_loss': losses.SupConLoss(temperature=0.1)
    }
    optimizers = {
        'trunk_optimizer': torch.optim.Adam(models['trunk'].parameters(), lr=0.00001,
                                                      weight_decay=0.0001),
        'embedder_optimizer': torch.optim.Adam(models['embedder'].parameters(), lr=0.0001,
                                                         weight_decay=0.0001)
    }
    lr_schedulers = {
        'trunk_scheduler_by_plateau': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizers['trunk_optimizer'], patience=15),
        'embedder_scheduler_by_plateau': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizers['embedder_optimizer'], patience=10)
    }

    batch_size = 128
    dataset_dict = {'val': val_dataset}

    mining_funcs = {'tuple_miner': miners.MultiSimilarityMiner(epsilon=0.1)}

    hooks = logging_presets.get_hook_container(record_keeper)
    tester = testers.GlobalEmbeddingSpaceTester(
        end_of_testing_hook=hooks.end_of_testing_hook,
        dataloader_num_workers=16,
        accuracy_calculator=AccuracyCalculator(k='max_bin_count'),
    )
    end_of_epoch_hook = hooks.end_of_epoch_hook(
        tester, dataset_dict, model_folder, test_interval=5, patience=20
    )
    trainer = trainers.MetricLossOnly(
        models,
        optimizers,
        batch_size,
        loss_funcs,
        train_dataset,
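        # NOTE: with m=1 each class contributes one sample at a time, so a batch of 128
        # may contain few or no positive pairs for SupConLoss and the miner
        sampler=samplers.MPerClassSampler(train_dataset.targets, m=1, length_before_new_iter=len(train_dataset)),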
        mining_funcs=mining_funcs,
        dataloader_num_workers=params['workers'],
        lr_schedulers=lr_schedulers,
        end_of_iteration_hook=hooks.end_of_iteration_hook,
        end_of_epoch_hook=end_of_epoch_hook,
    )

........

    self.knn = FaissKNN(reset_before=False, reset_after=False)
    self.knn.train(self.centroids)

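    def inference_knn(self, query):
        thr = 0.4  # maximum distance to the nearest centroid for accepting a match
        inference_model = InferenceModel(trunk=self.models['trunk'], embedder=self.models['embedder'], knn_func=self.knn)
        # Apply the validation transform and add a batch dimension
        query = self.val_dataset.transform(image=query)['image'].unsqueeze(0)
        distance, index = inference_model.get_nearest_neighbors(query, k=1)
        if distance <= thr:
            return index
        return None  # nearest centroid is farther than thr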
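
For reference, the centroid-based classification described above can be sketched roughly as follows. This is not the code from the elided section: it is a dependency-free stand-in for the FaissKNN lookup, and the helper names (`l2_normalize`, `class_centroids`, `nearest_centroid`), the L2 normalization, and the toy 2-D vectors are assumptions made for illustration. Only the `thr = 0.4` cutoff comes from the snippet above, and its scale may not match the distances FaissKNN returns.

```python
import math
from collections import defaultdict


def l2_normalize(vec):
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def class_centroids(embeddings, labels):
    """Average the normalized embeddings of each class, then re-normalize."""
    grouped = defaultdict(list)
    for emb, label in zip(embeddings, labels):
        grouped[label].append(l2_normalize(emb))
    centroids = {}
    for label, embs in grouped.items():
        mean = [sum(col) / len(embs) for col in zip(*embs)]
        centroids[label] = l2_normalize(mean)
    return centroids


def nearest_centroid(query, centroids, thr=0.4):
    """Return the label of the closest centroid, or None if it is farther than thr."""
    q = l2_normalize(query)
    best_label, best_dist = None, float("inf")
    for label, c in centroids.items():
        dist = math.dist(q, c)  # Euclidean distance between unit vectors
        if dist < best_dist:
            best_label, best_dist = label, dist
    return best_label if best_dist <= thr else None


# Toy 2-D "embeddings" standing in for trunk + embedder outputs (hypothetical data)
support = [[1.0, 0.1], [0.9, 0.0], [0.0, 1.0], [0.1, 0.9]]
support_labels = ["a", "a", "b", "b"]
centroids = class_centroids(support, support_labels)
```

In the real pipeline the embeddings would be the trunk + embedder outputs for the 75 support images, and a query that is far from every centroid is rejected rather than forced into one of the 5 classes.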

KevinMusgrave commented 1 year ago

Apologies for the late response. I don't have much insight into few-shot learning. I wonder if this other library could help: https://github.com/sicara/easy-few-shot-learning

By the way, what kind of performance are you getting so far?