orobix / Prototypical-Networks-for-Few-shot-Learning-PyTorch

Implementation of Prototypical Networks for Few-shot Learning (https://arxiv.org/abs/1703.05175) in PyTorch
MIT License

Question about batching function #1

Closed by gautamb85 6 years ago

gautamb85 commented 6 years ago

Hello,

I have been trying to implement prototypical networks on a different dataset (audio), and I have been having some difficulty with training: my loss seems to be decreasing very slowly.

I have a question about your batching function. Suppose there are 32 query points in a batch and 3-5 support points per class.

Does each query point get a label (1-32), and is that assignment arbitrary? Or are data points given labels based on the whole training set? Put another way, does a given query point necessarily get the same label in different mini-batches? If you could give me an example, that would be super helpful.

Thanks, Gautam

dnlcrl commented 6 years ago

Hi @gautamb85,

I'm not sure I understand you correctly. Both query and support points are extracted by calling the dunder method __getitem__ of your dataset class. The batching function (the sampler) simply returns a list of indexes idx, so each point keeps its corresponding label regardless of whether it is used as a support or a query point.

Check the PyTorch documentation for more info about Dataset, DataLoader and Samplers: http://pytorch.org/docs/master/data.html
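
To make the point above concrete, here is a minimal sketch in plain Python. It is not the repository's code: ToyDataset and EpisodeSampler are hypothetical names, and the class counts and sizes are made up for illustration. It only shows that a sampler yields lists of indexes, and that the label of each point comes from the dataset's __getitem__, so the same index carries the same label in every mini-batch.

```python
import random


class ToyDataset:
    """Hypothetical dataset: item i is (feature, label); labels are fixed dataset-wide."""

    def __init__(self, n_classes=5, per_class=10):
        # Indexes 0-9 belong to class 0, 10-19 to class 1, and so on.
        self.labels = [c for c in range(n_classes) for _ in range(per_class)]
        self.features = [float(i) for i in range(len(self.labels))]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # The label is looked up by index, not assigned per batch.
        return self.features[idx], self.labels[idx]


class EpisodeSampler:
    """Hypothetical episodic sampler: yields lists of dataset indexes and nothing else."""

    def __init__(self, labels, classes_per_it, samples_per_class, iterations, seed=0):
        self.classes_per_it = classes_per_it
        self.samples_per_class = samples_per_class
        self.iterations = iterations
        self.rng = random.Random(seed)
        # Group dataset indexes by their class label.
        self.by_class = {}
        for idx, label in enumerate(labels):
            self.by_class.setdefault(label, []).append(idx)

    def __len__(self):
        return self.iterations

    def __iter__(self):
        for _ in range(self.iterations):
            # Pick a few classes, then a few indexes from each picked class.
            classes = self.rng.sample(sorted(self.by_class), self.classes_per_it)
            batch = []
            for c in classes:
                batch.extend(self.rng.sample(self.by_class[c], self.samples_per_class))
            yield batch


dataset = ToyDataset()
sampler = EpisodeSampler(dataset.labels, classes_per_it=3, samples_per_class=4, iterations=2)

# Resolve each sampled index through __getitem__ and keep (index, label) pairs.
episodes = [[(idx, dataset[idx][1]) for idx in batch] for batch in sampler]
for episode in episodes:
    print(episode)
```

With PyTorch installed, an object shaped like this sampler can be passed as the batch_sampler argument of torch.utils.data.DataLoader, which then calls the dataset's __getitem__ for every index in each yielded list.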