Closed min9kwak closed 2 years ago
Hi @min9kwak, thanks for your interest.
Yes, your observation is correct, but this is actually the intended behaviour. To clarify, `ClassStratifiedSampler` should sample `torch.cat([[a,b,c,...], [a,b,c,...], ..., [a,b,c,...]])`, where `[a,b,c,...]` are randomly sampled classes in each iteration.
Since the loss is non-parametric, `label_matrix` is just used to identify which samples belong to the same class. For example, if on CIFAR10 you set `classes_per_batch: 7` in your config, then `ClassStratifiedSampler` will only sample 7 random classes in each iteration, and `label_matrix` will be of shape `[num_support_imgs x 7]`. In that case you cannot compare it to `sdata[-1]`, which is of shape `[num_support_imgs x 10]`.
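A minimal sketch of this behaviour (not the repo's actual implementation; the variable names and `imgs_per_class` parameter are illustrative) showing why `label_matrix` has only `classes_per_batch` columns:

```python
import torch

# Hypothetical sketch: sample `classes_per_batch` random classes and build a
# one-hot label matrix whose columns index only the sampled classes, not all 10.
num_classes = 10        # e.g. CIFAR10
classes_per_batch = 7
imgs_per_class = 2      # support images drawn per sampled class

sampled = torch.randperm(num_classes)[:classes_per_batch]  # e.g. tensor([1, 8, 7, ...])
# repeat the sampled-class pattern: [a, b, c, ..., a, b, c, ...]
labels = sampled.repeat(imgs_per_class)

# one-hot over the *sampled* classes only -> shape [num_support_imgs x 7]
label_matrix = (labels.unsqueeze(1) == sampled.unsqueeze(0)).float()
print(label_matrix.shape)  # torch.Size([14, 7])
```

Each row has exactly one nonzero entry, so the matrix tells you which batch samples share a class without committing to the global class ids.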
@MidoAssran Thanks for your quick and clear explanation.
As I understand it, regardless of whether the order is `[a, b, c, a, b, c, ...]` or `[b, c, a, b, c, a, ...]`, it is fine to use the `label_matrix` because it doesn't have to indicate specific class information. It's like a cluster assignment for each iteration or batch.
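To illustrate the "cluster assignment" point, here is a small hedged example (the helper name is made up for illustration): two batches with different class ids but the same grouping produce identical same-class masks, which is all a non-parametric loss needs.

```python
import torch

def same_class_mask(labels):
    """Hypothetical helper: pairwise mask marking which samples share a class."""
    return (labels.unsqueeze(1) == labels.unsqueeze(0)).float()

a = torch.tensor([0, 1, 2, 0, 1, 2])  # [a, b, c, a, b, c]
b = torch.tensor([7, 9, 4, 7, 9, 4])  # different class ids, same grouping

# The masks agree, so the loss is unaffected by which classes were drawn.
print(torch.equal(same_class_mask(a), same_class_mask(b)))  # True
```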
It was really helpful. Thank you!!
Yes precisely :)
Thank you for the code.
I am trying to reproduce the results and found something strange in the code. I ran main.py with cifar10_train.yaml. As demonstrated, I thought that `ClassStratifiedSampler` should produce `[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ...]`, and I believe that this is what lets us use `label_matrix` for convenience. However, when I checked the actual labels (`sdata[-1]` on line 280 of paws_train.py), the labels were sampled in a stratified manner but not in that order. For example, `[1, 8, 7, 9, 5, 2, 6, 0, 4, 3]` is repeated, and `sdata[-1]` and `labels` in paws_train.py do not match. Nevertheless, the model trains properly. Did I miss something? Thank you in advance.