sicara / easy-few-shot-learning

Ready-to-use code and tutorial notebooks to boost your way into few-shot learning for image classification.
MIT License
1.06k stars, 144 forks

Generalizing TaskSampler for sampling other datasets #82

Closed: swapnil-chhatre closed this issue 1 year ago

swapnil-chhatre commented 1 year ago

Problem: I am following the Prototypical Networks tutorial for a personal project. However, my dataset is not a FewShotDataset, which is the type TaskSampler expects according to its definition. TaskSampler requires a dataset that exposes its labels (through a get_labels() method), and my dataset does not, so I cannot sample it by following the same code flow as the tutorial.

Desired Solution: TaskSampler should be able to extract labels from any dataset whose items are returned as Tuple[Tensor, int], where the int is the label of the image represented by the tensor.

ebennequin commented 1 year ago

This is interesting. I'm not sure we would want TaskSampler to call the dataset's __getitem__() on every instance in the dataset, since:

  1. It could be very costly
  2. It would still require a specific shape to be defined for the dataset (here that an item is a tuple of size 2 with the label at position 1).

Can you give more details about why this is a problem for you? Why can't you create a get_labels() method for your dataset object?
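For reference, a get_labels() method does not have to touch the images at all. The sketch below keeps the labels in memory and loads an image only when __getitem__() is called, so TaskSampler can read every label cheaply. The class name ListFewShotDataset and the loader parameter are hypothetical; in a real project the class would subclass torch.utils.data.Dataset (or easyfsl's FewShotDataset), and it is kept dependency-free here only so it runs on its own.

```python
from typing import Callable, List, Sequence, Tuple


class ListFewShotDataset:
    """Hypothetical map-style dataset whose get_labels() never loads an image.

    Assumption: in practice this would subclass torch.utils.data.Dataset
    (or easyfsl's FewShotDataset); it is dependency-free here for clarity.
    """

    def __init__(
        self,
        items: Sequence[str],
        labels: Sequence[int],
        loader: Callable[[str], object],
    ):
        if len(items) != len(labels):
            raise ValueError("items and labels must have the same length")
        self.items = list(items)    # e.g. image file paths
        self.labels = list(labels)  # one integer class label per item
        self.loader = loader        # turns an item into an image (or tensor)

    def __getitem__(self, index: int) -> Tuple[object, int]:
        # Load the image lazily, only when this item is actually requested.
        return self.loader(self.items[index]), self.labels[index]

    def __len__(self) -> int:
        return len(self.items)

    def get_labels(self) -> List[int]:
        # Cheap: returns the labels stored at construction time,
        # without ever calling __getitem__().
        return list(self.labels)
```

Usage: `ListFewShotDataset(paths, labels, loader=my_loading_function)`, where `my_loading_function` (hypothetical) opens the image and applies your transforms. The design choice is simply to record labels when the dataset is built instead of recovering them from the items afterwards.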

swapnil-chhatre commented 1 year ago

The dataset I am using is not a torchvision.datasets object like Omniglot, so it does not have a get_labels() method. This forced me to manually convert the dataset into the Tuple[Tensor, int] format. As you rightly pointed out, I am now working on a custom get_labels() method for my dataset.

The reason for posting this "enhancement" was that such a generalized sampler would let users work with a variety of datasets for FSL across multiple domains.

Anin1 commented 1 year ago

I am also facing the same problem: I am unable to run the EasyFSL code on my own dataset. Could you suggest how to run EasyFSL on my own dataset, given the path to its directory?
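One way to go from a directory path to something TaskSampler can sample is to index the files once and record each file's label while doing so. The sketch below assumes a folder-per-class layout (root/class_name/image.png); the class name FolderFewShotDataset, the extension set, and the default loader (which just returns the path) are all hypothetical choices for illustration.

```python
from pathlib import Path
from typing import Callable, List, Tuple

# Assumption: which file extensions count as images.
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".bmp"}


class FolderFewShotDataset:
    """Hypothetical dataset built from root/<class_name>/<image files>."""

    def __init__(self, root, loader: Callable[[Path], object] = lambda path: path):
        root = Path(root)
        # Each sub-directory is one class; sorting makes label ids deterministic.
        class_dirs = sorted(d for d in root.iterdir() if d.is_dir())
        self.class_names = [d.name for d in class_dirs]
        # Index (path, label) pairs once, without opening any image.
        self.samples: List[Tuple[Path, int]] = [
            (path, label)
            for label, class_dir in enumerate(class_dirs)
            for path in sorted(class_dir.iterdir())
            if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS
        ]
        self.loader = loader  # replace with a function that opens + transforms

    def __getitem__(self, index: int) -> Tuple[object, int]:
        path, label = self.samples[index]
        return self.loader(path), label

    def __len__(self) -> int:
        return len(self.samples)

    def get_labels(self) -> List[int]:
        # Labels come from the index built in __init__, so no image is loaded.
        return [label for _, label in self.samples]
```

With torchvision installed, torchvision.datasets.ImageFolder builds the same kind of (path, label) index and exposes the labels as its `targets` attribute, so a thin subclass whose get_labels() returns `self.targets` is an alternative to writing the indexing yourself.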

ebennequin commented 1 year ago

Possible solution to explore: add an option to TaskSampler's __init__() to iterate over the dataset when the dataset does not implement a get_labels() method. It could also allow configuring the position of the label in the tuple returned by the dataset's __getitem__().
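The fallback proposed above could look roughly like the sketch below. It is not EasyFSL's code: the function name resolve_labels and its label_position parameter are hypothetical. It prefers get_labels() when the dataset provides one and otherwise makes a single pass over the dataset, which is exactly the costly full read raised earlier in the thread.

```python
from typing import Any, List


def resolve_labels(dataset: Any, label_position: int = 1) -> List[int]:
    """Hypothetical helper: get one integer label per item of a map-style dataset.

    Uses dataset.get_labels() when it exists (cheap). Otherwise it falls back
    to calling __getitem__() on every item, which loads every image and can
    be slow on large datasets.
    """
    get_labels = getattr(dataset, "get_labels", None)
    if callable(get_labels):
        return list(get_labels())
    # Fallback: one full pass, reading the label at the configured tuple position.
    # int() also converts a single-element tensor label to a plain int.
    return [int(dataset[index][label_position]) for index in range(len(dataset))]
```

A sampler built this way could call `resolve_labels(dataset, label_position)` once in its constructor, so the expensive pass, if any, happens a single time rather than for every sampled task.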

ebennequin commented 1 year ago

I've come up with the final version for this:

What

A method make_few_shot_dataset(dataset: Dataset, image_position_in_get_item_output: int = 0, label_position_in_get_item_output: int = 1) -> FewShotDataset in the utils that takes any Dataset object and returns a new object implementing FewShotDataset: