Closed swapnil-chhatre closed 1 year ago
This is interesting. I'm not sure we would want the TaskSampler to call the dataset's __getitem__() for all instances in the dataset, since
Can you give more details about why this is a problem for you? Why can't you create a get_labels() method for your dataset object?
The dataset I am using is not a torchvision.datasets object like Omniglot, so it does not have a get_labels() method. This forced me to manually convert the dataset into a Tuple[Tensor, int] format. As you rightly pointed out, I am now working on creating a custom get_labels() method for my dataset.
The reason for posting this "enhancement" was that such a generalized sampler will enable users to work with a variety of datasets for FSL across multiple domains.
I am also facing the same problem. I am unable to run the EasyFSL code on my own dataset. Please suggest how to run EasyFSL on my own dataset, given the path to the dataset.
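One way to approach this is sketched below, assuming the images are stored under a root folder with one sub-directory per class. The class name FolderFewShotDataset and the loader argument are hypothetical and not part of EasyFSL. The class is written as a plain Python object so that it runs without PyTorch; in a real project it would typically subclass the library's dataset base class, and the loader would open the image and return a tensor.

```python
from pathlib import Path
from typing import Callable, List, Tuple


class FolderFewShotDataset:
    """Map-style dataset over a root folder with one sub-directory per class.

    Hypothetical helper for illustration; not an EasyFSL class.
    """

    def __init__(self, root: str, loader: Callable[[Path], object]):
        self.loader = loader
        root_path = Path(root)
        # Sorted class names give a stable class-name -> integer label mapping.
        self.class_names = sorted(p.name for p in root_path.iterdir() if p.is_dir())
        self.samples: List[Tuple[Path, int]] = [
            (path, label)
            for label, name in enumerate(self.class_names)
            for path in sorted((root_path / name).iterdir())
            if path.is_file()
        ]

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int):
        path, label = self.samples[index]
        # The loader is supplied by the caller (assumption), e.g. image -> tensor.
        return self.loader(path), label

    def get_labels(self) -> List[int]:
        # Labels come from the file listing, so no image is loaded here.
        return [label for _, label in self.samples]
```

Because get_labels() reads only the file listing, a sampler can group items by class without loading a single image.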
Possible solution to explore: adding an option to TaskSampler's init to iterate over the dataset if said dataset does not implement a get_labels() method.
We could also make the position of the label in the tuple returned by the dataset's __getitem__() configurable.
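A minimal sketch of that fallback, not the library's actual implementation: the helper names collect_labels and group_indices_by_label and the label_position parameter are assumptions made for illustration. The first function prefers get_labels() when the dataset provides it and otherwise reads the label from each item. The second groups item indices by class, which is the information a sampler needs to draw the classes of a task.

```python
from collections import defaultdict
from typing import Dict, List


def collect_labels(dataset, label_position: int = 1) -> List[int]:
    """Return one label per item, preferring dataset.get_labels() when present."""
    if hasattr(dataset, "get_labels"):
        return list(dataset.get_labels())
    # Fallback: touches every item, which may be slow if items are loaded from disk.
    return [dataset[i][label_position] for i in range(len(dataset))]


def group_indices_by_label(labels: List[int]) -> Dict[int, List[int]]:
    """Map each label to the list of dataset indices carrying that label."""
    items_per_label = defaultdict(list)
    for index, label in enumerate(labels):
        items_per_label[label].append(index)
    return dict(items_per_label)
```

The fallback only runs once, when the sampler is built, but it still pays the full cost of loading every item, so implementing get_labels() stays preferable for large datasets.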
I've come up with the final version for this:
A method make_few_shot_dataset(dataset: Dataset, image_position_in_get_item_output: int = 0, label_position_in_get_item_output: int = 1) -> FewShotDataset in the utils that takes any Dataset object and returns a new object implementing FewShotDataset:
- __getitem__() -> tuple[Tensor, int]
- a labels field, with the get_labels() method implemented from this field
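The proposal could look roughly like the sketch below. It is an illustration under stated assumptions, not the code that was merged: the wrapper class WrappedFewShotDataset is hypothetical and is a plain Python class so that the example runs without PyTorch, whereas in the library it would implement FewShotDataset. Only the function name and its three parameters come from the proposal itself.

```python
from typing import Any, List, Tuple


class WrappedFewShotDataset:
    """Hypothetical wrapper exposing (image, label) items and get_labels()."""

    def __init__(self, dataset, image_position: int, label_position: int):
        self.dataset = dataset
        self.image_position = image_position
        self.label_position = label_position
        # One full pass over the dataset to build the labels field.
        self.labels: List[int] = [
            int(dataset[i][label_position]) for i in range(len(dataset))
        ]

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        item = self.dataset[index]
        return item[self.image_position], int(item[self.label_position])

    def get_labels(self) -> List[int]:
        return self.labels


def make_few_shot_dataset(
    dataset,
    image_position_in_get_item_output: int = 0,
    label_position_in_get_item_output: int = 1,
) -> WrappedFewShotDataset:
    """Wrap any indexable dataset with a length into a few-shot-ready object."""
    return WrappedFewShotDataset(
        dataset,
        image_position_in_get_item_output,
        label_position_in_get_item_output,
    )
```

For example, a dataset whose items are (label, image, metadata) tuples could be wrapped with image_position_in_get_item_output=1 and label_position_in_get_item_output=0. The cost of reading every item is paid once, at wrapping time.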
Problem: I am following the Prototypical Networks tutorial for a personal project. However, I am using a different dataset that does not belong to the FewShotDataset category mentioned in the definition of TaskSampler. The TaskSampler requires a dataset with a "label" attribute, which my dataset does not have. As a result, I cannot sample from my dataset if I follow the same code flow as in the tutorial.
Desired Solution: The get_label() method from TaskSampler should be able to extract labels from any dataset provided in the format Tuple[tensor, int], where the tensor represents the image and the int is its label.