Closed: stoneyang closed this issue 6 years ago
I guess your labels are in one-hot format, with shape (100, 82398).
In my code the labels are in integer format like this:
labels = [1, 9, 3, 2, 4, 5, ...]
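If your labels are currently one-hot, the conversion to the integer format above is a single `argmax` over the class axis. A minimal sketch with numpy (the toy `one_hot` array here is illustrative, not from the issue; in TensorFlow the equivalent call is `tf.argmax(one_hot, axis=1)`):

```python
import numpy as np

# Hypothetical one-hot labels of shape (batch_size, num_classes),
# standing in for the (100, 82398) tensor discussed above.
one_hot = np.array([
    [0, 1, 0, 0],
    [0, 0, 0, 1],
    [1, 0, 0, 0],
])

# argmax along the class axis recovers the integer labels
# that the triplet loss expects.
labels = np.argmax(one_hot, axis=1)
print(labels)  # [1 3 0]
```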
Also if you have a high number of classes you need to make sure to have multiple instances of the same class in your batch (otherwise no triplet will be found).
cf. issue #7 for the idea of balanced batches
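One way to get such balanced batches is to sample P classes and then K examples from each, so every batch is guaranteed to contain positives. A sketch with the standard library (the `labels_to_indices` mapping and function name are hypothetical, not from the repo):

```python
import random

def balanced_batch(labels_to_indices, num_classes_per_batch=4, num_samples_per_class=2):
    """Sample a batch containing several instances of each chosen class.

    labels_to_indices: dict mapping label -> list of example indices
    (build it once from your label list).
    """
    classes = random.sample(list(labels_to_indices), num_classes_per_batch)
    batch = []
    for c in classes:
        # sample with replacement in case a class has fewer than K examples
        batch.extend(random.choices(labels_to_indices[c], k=num_samples_per_class))
    return batch

# Toy usage: 3 example indices for each of 10 labels
toy = {c: [c * 3, c * 3 + 1, c * 3 + 2] for c in range(10)}
batch = balanced_batch(toy)
print(len(batch))  # 8 indices: 4 classes x 2 samples each
```

With 82398 classes this matters a lot: a uniformly random batch of 100 would almost never contain two examples of the same class, so no valid triplet could be formed.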
> I guess your labels are in one-hot format, with shape (100, 82398).

Yes, you are right: `labels` in my code is in one-hot format.

> In my code the labels are in integer format like this: `labels = [1, 9, 3, 2, 4, 5, ...]`

I get it. :)

> Also if you have a high number of classes you need to make sure to have multiple instances of the same class in your batch (otherwise no triplet will be found).

I generated a sample list and load it every `batch_size` lines. Thanks for your reminder. :)

> cf. issue #7 for the idea of balanced batches

Thanks for your pointer; I will be back once this issue gets solved.
@omoindrot, I've solved this problem in my code and everything looks fine now. Just before `one_hot_encoding`, the `labels` vector is in its original integer form, which is what the triplet loss functions require.
This issue is resolved, so I will close it now.
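In other words, the triplet loss only compares labels for equality to find positive and negative pairs, so it should receive the raw integer labels; one-hot encoding belongs only where a loss actually needs it (e.g. a softmax head). A small sketch of that split, with made-up data (`num_classes` and the toy label vector are illustrative assumptions):

```python
import numpy as np

labels = np.array([1, 9, 3, 2, 4, 5])  # original integer labels

# Triplet loss consumes the integer labels directly: a boolean
# equality matrix is enough to identify positive pairs.
positive_pairs = labels[:, None] == labels[None, :]

# One-hot encoding is only needed later, e.g. for a softmax loss.
num_classes = 10
one_hot = np.eye(num_classes)[labels]
print(one_hot.shape)  # (6, 10)
```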
Hi, @omoindrot
Thanks for sharing your code! I am currently building a small project using triplet loss.
I cloned the code, ran it, and everything was fine. But when applying it to my own code, in a mini-batch context, it crashed at this line. I examined the shapes of the tensors and compared them to this mnist example, and I found that the dimension of `label` is the most likely cause.
Hope you can shed some light on this. I am really stuck... Thanks in advance.
Following is the error message I got:
Next, the shapes of the tensors. In my code (batch_size is 100, number of classes is 82398):
In your code: