gautamb85 closed this issue 6 years ago.
Hi @gautamb85,
I'm not sure I understand you correctly, but both query and support points are extracted by calling the dunder method `__getitem__` of your dataset class. The batching function (the sampler) simply returns a list of indexes `idx`, so each point gets its corresponding label from the dataset regardless of whether it is used as a support or a query point.
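To illustrate the point above, here is a minimal sketch in plain Python. `ToyAudioDataset` and `episode_indices` are hypothetical names, not this repository's code. The sampler only picks indexes (`n_way` classes, `n_support + n_query` points per class), and the label of every point comes from `__getitem__`, so a given index always yields the same label in every episode. In real PyTorch code the dataset would subclass `torch.utils.data.Dataset` and the sampler would be passed to `DataLoader(dataset, batch_sampler=...)`.

```python
import random
from collections import defaultdict


class ToyAudioDataset:
    """Hypothetical stand-in for a Dataset: each index maps to one (sample, label) pair."""

    def __init__(self, samples, labels):
        self.samples = samples
        self.labels = labels

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # The label is fixed per index; it does not depend on the support/query role.
        return self.samples[idx], self.labels[idx]


def episode_indices(labels, n_way, n_support, n_query, rng=random):
    """Hypothetical episodic sampler: returns a flat list of dataset indexes."""
    by_class = defaultdict(list)
    for i, y in enumerate(labels):
        by_class[y].append(i)
    # Pick n_way distinct classes, then n_support + n_query indexes from each.
    classes = rng.sample(sorted(by_class), n_way)
    idx = []
    for c in classes:
        idx.extend(rng.sample(by_class[c], n_support + n_query))
    return idx


# Usage: 10 classes with 8 clips each; the batch is built by indexing the dataset.
labels = [i // 8 for i in range(80)]
dataset = ToyAudioDataset([f"clip_{i}" for i in range(80)], labels)
idx = episode_indices(labels, n_way=5, n_support=3, n_query=5, rng=random.Random(0))
batch = [dataset[i] for i in idx]
```

How the flat list is split into support and query points (here, the first `n_support` of each class's block could serve as support) is a design choice of the loss function, not of the dataset.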
Check the PyTorch documentation for more info about Dataset, DataLoader and Samplers: http://pytorch.org/docs/master/data.html
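On whether a point gets "the same label" in different mini-batches: its global label from the dataset never changes, but the loss may work with episode-relative class indexes (0 to n_way - 1). One common approach, assumed here and not necessarily what this repository does, is to sort the classes present in the episode and remap each global label to its position. The function name `episode_targets` is hypothetical.

```python
def episode_targets(global_labels):
    """Map the global class labels of one episode to 0..n_way-1.

    Assumption: classes are ordered by sorted global label (one common choice).
    """
    classes = sorted(set(global_labels))
    to_episode = {c: k for k, c in enumerate(classes)}
    return [to_episode[y] for y in global_labels]


# A point with global label 7 in two different episodes:
print(episode_targets([7, 3, 12, 7]))  # [1, 0, 2, 1]  -> 7 becomes target 1
print(episode_targets([41, 7, 20]))    # [2, 0, 1]     -> 7 becomes target 0
```

So the dataset label is stable, while the episode-relative target depends on which other classes were sampled in that episode.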
Hello,
I have been trying to implement prototypical networks on a different dataset (audio), and I have been having some difficulty with training: my loss seems to be decreasing very slowly.
I have a question about your batching function. Suppose there are 32 query points in a batch and 3-5 support points per class.
Is each query point given a label (1-32), and is this assignment arbitrary? Or are data points given labels based on the whole training set? Put another way, does a given query point necessarily get the same label in different mini-batches? An example would be super helpful.
Thanks, Gautam