sicara / easy-few-shot-learning

Ready-to-use code and tutorial notebooks to boost your way into few-shot learning for image classification.
MIT License

How to build my own train_set using my own data #64

Closed cy2333ytu closed 2 years ago

cy2333ytu commented 2 years ago

Problem: Thanks for sharing your work on FSL. I have one question: having finished the tutorial 'Discovering Prototypical Networks', I want to use my own photo data to build a test_set. How can I do that, and how should I structure my data?

ebennequin commented 2 years ago

Hi! This seems to be related to #44. You can build a FewShotDataset object from custom data using EasySet.

More recently, we added the SupportSetFolder class that helps you to build a support set easily from custom data.

So depending on your use case, it will be one or the other.

cy2333ytu commented 2 years ago

Thanks for your suggestion, @ebennequin. I used SupportSetFolder to build my dataset, but I get an error: 'ValueError: Sample larger than population or is negative'. My code is as follows:

path = 'D:/G/few_shot_learning/try_code/flowers'

device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
train_set = SupportSetFolder(root=path, device=device, image_size=28, transform=trans_train)
test_set = SupportSetFolder(root=path, device=device, image_size=28, transform=trans_test)
convolutional_network = resnet18(pretrained=True)
convolutional_network.fc = nn.Flatten()

model = PrototypicalNetworks(convolutional_network).cuda()

N_WAY = 5
N_SHOT = 5
N_QUERY = 5
N_EVALUATION_TASKS = 1

test_sampler = TaskSampler(test_set, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_EVALUATION_TASKS)

test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=0,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

(
    example_support_images,
    example_support_labels,
    example_query_images,
    example_query_labels,
    example_class_ids,
) = next(iter(test_loader))

The structure of the flowers folder is as follows: there are 8 subfolders in flowers, and 20 images in each subfolder.

ebennequin commented 2 years ago

This error is raised in TaskSampler (see #44 and #45), usually because a class has fewer elements than n_shot + n_query, or because there are fewer classes than n_way. That doesn't seem to be the case in your situation.
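To make the two usual causes concrete, here is a minimal sketch in plain Python. The helper `check_task_feasibility` is hypothetical (it is not part of easyfsl); it only counts labels and compares them with the task parameters. The last lines reproduce the same error message with the standard library's `random.sample`, which raises it whenever more items are requested than are available.

```python
import random
from collections import Counter


def check_task_feasibility(labels, n_way, n_shot, n_query):
    """Hypothetical helper: list the reasons why sampling a task could fail.

    `labels` is a plain list of integer class labels, one per image.
    An empty result means there are enough classes and enough images per class.
    """
    counts = Counter(labels)
    problems = []
    if len(counts) < n_way:
        problems.append(f"only {len(counts)} classes, but n_way={n_way}")
    needed = n_shot + n_query
    for label, count in sorted(counts.items()):
        if count < needed:
            problems.append(
                f"class {label} has {count} images, but n_shot + n_query = {needed}"
            )
    return problems


# Asking random.sample for more items than the population holds
# raises the ValueError quoted in this thread.
try:
    random.sample(range(3), 5)
except ValueError as error:
    message = str(error)
```

With 8 classes of 20 images each and N_WAY = N_SHOT = N_QUERY = 5, this check reports no problem, which is consistent with the remark that the usual causes do not apply here.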

You are using a task sampler on a dataset that is not meant for sampling few-shot tasks. As its name and docstring indicate, SupportSetFolder is only meant to handle the support set of a single few-shot task. I think the error here might come from the fact that SupportSetFolder.get_labels() returns a torch Tensor, while TaskSampler needs a get_labels() method returning a list of integers (like EasySet).
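One plausible mechanism, sketched below under stated assumptions: if a sampler groups item indices by label in a dict, labels that hash by object identity (as tensor elements do, to my understanding) each end up in their own group of size one, so sampling n_shot + n_query items from any group fails. `OpaqueLabel` and `group_indices_by_label` are hypothetical stand-ins written for this illustration; they are not easyfsl or torch code.

```python
class OpaqueLabel:
    """Stand-in for a tensor element: hashed by identity, not by value."""

    def __init__(self, value):
        self.value = value

    def __int__(self):
        return self.value


def group_indices_by_label(labels):
    """Group item indices by label, as a dict-based sampler might do (assumption)."""
    groups = {}
    for index, label in enumerate(labels):
        groups.setdefault(label, []).append(index)
    return groups


int_labels = [0, 0, 0, 1, 1, 1]
opaque_labels = [OpaqueLabel(value) for value in int_labels]

# Plain integers with equal values share a dict key: 2 groups of 3 items.
groups_from_ints = group_indices_by_label(int_labels)

# Identity-hashed objects never share a key: 6 groups of 1 item each.
groups_from_opaque = group_indices_by_label(opaque_labels)

# Casting each label to int restores the expected grouping.
groups_after_cast = group_indices_by_label([int(label) for label in opaque_labels])
```

If this is indeed what happens, every "class" seen by the sampler would contain a single image, which would explain the ValueError even though each real class has 20 images.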

I made this choice for SupportSetFolder so the result of get_labels() may be directly fed to a few-shot method. However, I see that it is problematic in the sense that it shares the name of the FewShotDataset.get_labels() method but has a different behavior. I think we can do better so I'm flagging this issue as an enhancement opportunity.

For your issue, I suggest you use EasySet, which is meant to allow sampling of various few-shot tasks.
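As a starting point, here is a minimal sketch that builds a JSON specs file from a folder containing one subfolder per class, such as the flowers folder described above. It assumes that EasySet reads a JSON file with the keys "class_names" and "class_roots"; please check the EasySet docstring for the exact format of your installed version. The function name `write_easyset_specs` is hypothetical and not part of easyfsl.

```python
import json
from pathlib import Path


def write_easyset_specs(root, specs_path):
    """Hypothetical helper: write a specs file listing every class subfolder of `root`.

    Assumes the specs format is {"class_names": [...], "class_roots": [...]},
    with one entry per class in both lists.
    """
    root = Path(root)
    # One class per subfolder, sorted so that the class order is reproducible.
    class_dirs = sorted(path for path in root.iterdir() if path.is_dir())
    specs = {
        "class_names": [path.name for path in class_dirs],
        "class_roots": [str(path) for path in class_dirs],
    }
    Path(specs_path).write_text(json.dumps(specs, indent=2))
    return specs
```

The resulting file could then be passed to EasySet (for example as its specs_file argument, if that matches your version), and the EasySet object given to TaskSampler in place of SupportSetFolder.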

cy2333ytu commented 2 years ago

I have built custom data by using EasySet, thanks very much! @ebennequin