sicara / easy-few-shot-learning

Ready-to-use code and tutorial notebooks to boost your way into few-shot learning for image classification.
MIT License

Low accuracy when using SupportSetFolder. #149

Open TeddyPorfiris opened 1 month ago

TeddyPorfiris commented 1 month ago

Hello! Thanks so much for easyfsl, it's fantastic. I am testing my Prototypical Network (trained on mini-ImageNet) with SupportSetFolder. When I test it on the attached folder dataset1 (containing photos from the internet), I get very accurate results. But when I test it on the attached folder dataset2 (containing photos I took myself), I get very inaccurate results. If you could help me figure out why, I'd appreciate it so much. Thanks again.

pip install easyfsl

import torch
from typing import Optional

from PIL import Image
from torch import nn
from torchvision import transforms
from torchvision.models import resnet18

from easyfsl.datasets import SupportSetFolder
from easyfsl.methods import FewShotClassifier

class PrototypicalNetworks(FewShotClassifier):
    def __init__(
        self,
        backbone: Optional[nn.Module] = None,
    ):
        """
        Initialize the Prototypical Networks Few-Shot Classifier
        Args:
            backbone: the feature extractor used by the method. Must output a tensor of the
                appropriate shape (depending on the method).
                If None is passed, the backbone will be initialized as nn.Identity().
        """
        super().__init__(backbone=backbone)

    def forward(
        self,
        support_images: torch.Tensor,  # Support images
        support_labels: torch.Tensor,  # Support labels
        query_images: torch.Tensor,    # Query images
    ) -> torch.Tensor:
        """
        Predict classification labels.
        Args:
            support_images: images of the support set of shape (n_support, **image_shape)
            support_labels: labels of support set images of shape (n_support, )
            query_images: images of the query set of shape (n_query, **image_shape)
        Returns:
            a prediction of classification scores for query images of shape (n_query, n_classes)
        """
        # Compute prototypes from the support set
        # (compute_prototypes_and_store_support_set also extracts the support features,
        # so a separate compute_features call on the support images would be redundant)
        self.compute_prototypes_and_store_support_set(support_images, support_labels)

        # Compute query features and score them by L2 distance to the prototypes
        z_query = self.compute_features(query_images)
        logits = self.l2_distance_to_prototypes(z_query)
        return self.softmax_if_specified(logits)

    @staticmethod
    def is_transductive() -> bool:
        # Prototypical Networks classify each query independently of the others,
        # so the method is inductive
        return False

# Initialize the backbone (pretrained ResNet18 with the fully connected layer replaced by a Flatten layer)
convolutional_network = resnet18(pretrained=True)  # on torchvision >= 0.13: resnet18(weights="IMAGENET1K_V1")
convolutional_network.fc = nn.Flatten()
# print(convolutional_network)

# Create the Prototypical Networks model using resnet18 as the feature extractor CNN

device = "cuda"

model_path = '/content/MIN_model.pth'
model = PrototypicalNetworks(convolutional_network).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))

# Define transformations to be applied to images
transform = transforms.Compose(
    [
        transforms.Resize([348, 348]),  # already outputs 348x348 images,
        transforms.CenterCrop(348),     # so this CenterCrop is a no-op
        transforms.ToTensor(),
    ]
)
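# Note (assumption, not from the original snippet): torchvision's ImageNet-pretrained
# models are usually fed normalized inputs. If MIN_model.pth was trained with the
# standard ImageNet statistics, the same step likely belongs in the Compose above:
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])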

support_set = SupportSetFolder(root='/content/dataset2/support_set', transform=transform, device=device)

query_image_path = '/content/dataset2/query_set/chips1.jpg'
query_image_PIL = Image.open(query_image_path).convert("RGB")  # make sure we get 3 channels
query_images = transform(query_image_PIL).unsqueeze(0)  # ToTensor already returns floats; add a batch dimension

with torch.no_grad():
    model.eval()

    class_names = support_set.classes
    print(f"Class names: {class_names}")

    # forward() already computes and stores the prototypes from the support set,
    # so a separate call to process_support_set() would be redundant here
    predicted_labels = model(
        support_set.get_images(),  # SupportSetFolder already placed these on `device`
        support_set.get_labels(),
        query_images.to(device),
    ).argmax(dim=1)

    predicted_classes = [support_set.classes[label] for label in predicted_labels]
    print(f"Predicted classes: {predicted_classes}")

Link to download MIN_model.pth: https://drive.google.com/file/d/1q6sfNYcYSTUJzEiHq1T-nJ5R31EZ8dio/view?usp=sharing

Attachments: dataset1.zip, dataset2.zip

ebennequin commented 1 month ago

There are numerous reasons why a model can perform well on one few-shot task and poorly on another. The good thing about few-shot learning is that since the volumes are small, it is easier to investigate, image by image, what went wrong with a prediction and why it was assigned to the wrong class.

I suggest you use classic investigation tools (confusion matrix, etc.) and complement them with a visualization of the misclassified images and of the support images to which they are closest, as in the sketch below.
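For example, here is a minimal sketch of that last suggestion. It assumes the model, transform, support_set, and device defined in the snippet above; query_dir is a hypothetical folder of query images, and nothing here beyond compute_features, get_images, get_labels, and classes is part of the easyfsl API.

# Minimal sketch: for each query image, find the closest support image
# in the backbone's feature space.
from pathlib import Path

query_dir = Path("/content/dataset2/query_set")  # hypothetical query folder
query_paths = sorted(query_dir.glob("*.jpg"))
query_batch = torch.stack(
    [transform(Image.open(path).convert("RGB")) for path in query_paths]
).to(device)

with torch.no_grad():
    model.eval()
    z_support = model.compute_features(support_set.get_images())
    z_query = model.compute_features(query_batch)
    # Pairwise Euclidean distances between query and support embeddings,
    # then the index of the nearest support image for each query
    nearest_support = torch.cdist(z_query, z_support).argmin(dim=1)

for path, support_index in zip(query_paths, nearest_support):
    label = support_set.get_labels()[support_index].item()
    print(f"{path.name} -> closest support image is of class {support_set.classes[label]}")

Comparing these nearest neighbors with the predicted and true classes should make it visible whether the errors come from a few ambiguous support images or from a systematic mismatch between the two datasets.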