kirk86 opened 3 years ago
`M` is the ground metric matrix, which should be defined beforehand according to your prior knowledge of the label space (i.e. how similar label a is to label b). Both `pred` and `label` should be `k_class`-by-`n_eg`, and both should be normalized to sum to 1, so `pred` should be a softmax output. If you have hard labels, they need to be one-hot encoded, and if there are multiple labels for one instance, the one-hot labels need to be re-normalized to sum to 1. Hope this helps.
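To make the shapes concrete, here is a minimal numpy sketch of the format described above (the sizes and label values are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: k_class classes, n_eg examples in the batch.
k_class, n_eg = 5, 4

# `pred` should be a softmax output of shape k_class-by-n_eg,
# so every column sums to 1.
logits = rng.standard_normal((k_class, n_eg))
pred = np.exp(logits - logits.max(axis=0, keepdims=True))
pred /= pred.sum(axis=0, keepdims=True)

# Hard labels -> one-hot columns (each column already sums to 1).
hard = np.array([2, 0, 4, 1])
label = np.eye(k_class)[:, hard]

# Multiple labels per instance -> multi-hot columns,
# re-normalized so each column sums to 1.
multi = np.zeros((k_class, n_eg))
multi[[0, 2], 0] = 1.0   # example 0 carries labels 0 and 2
multi[1, 1] = 1.0
multi[[3, 4], 2] = 1.0
multi[2, 3] = 1.0
label_multi = multi / multi.sum(axis=0, keepdims=True)
```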
Thanks, that helps!
> `M` is the ground metric according to your prior knowledge of the label space (i.e. how similar label a is to label b)

So basically `M` is just an adjacency matrix between all pairs of labels? Given `k = num_classes` and `n = num_samples`, would `M` then be `k x k`?
I'm not sure if I follow your notation. Basically, `M` is `k x k` shape, where `M[i, j]` is the distance between class `i` and class `j`. The distance depends on the specific application. For example, if you have classes cat, dog, car, etc., then you might have some ground metric that assigns a longer distance between `cat` and `car` than between `cat` and `dog`. In the most extreme case, where you have no information available on the classes, you can use uninformative distances and say every class is equidistant from every other class. If that is the case, though, this is probably not a good application scenario for the Wasserstein loss anyway.
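As a toy illustration (these distance values are invented, purely for the sake of the example), the cat/dog/car ground metric and the uninformative fallback might look like:

```python
import numpy as np

# Hypothetical hand-specified ground metric for classes [cat, dog, car]:
# M[i, j] is the cost of moving probability mass from class i to class j.
# In practice the values would come from prior knowledge, e.g. a semantic
# hierarchy or distances between word embeddings of the class names.
classes = ["cat", "dog", "car"]
M = np.array([
    [0.0, 1.0, 5.0],   # cat: close to dog, far from car
    [1.0, 0.0, 5.0],   # dog
    [5.0, 5.0, 0.0],   # car
])

# Uninformative fallback: every class equidistant from every other.
# With this M the loss carries no class-similarity information,
# which is why it is not a compelling use case.
k = len(classes)
M_uninformative = np.ones((k, k)) - np.eye(k)
```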
> Basically, `M` is `k x k` shape

Yep, that's what my notation says as well.
> if you have classes cat, dog, car, etc., then you might have some ground metric that assigns a longer distance between `cat` and `car` than between `cat` and `dog`

Thanks, do you have any references or examples of such ground metrics?
I've used the l1-norm between labels to generate `M` above in my notation.
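Roughly like this (a sketch of what I mean; it assumes integer class indices, and `k = 10` is just for illustration):

```python
import numpy as np

# Ground metric from the l1 distance between integer class indices:
# M[i, j] = |i - j|. Note this assumes the index ordering itself is
# meaningful, which is a strong assumption for e.g. digit classes.
k = 10
idx = np.arange(k)
M = np.abs(idx[:, None] - idx[None, :]).astype(float)
```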
When I tried training with SGD, the model would not get beyond 40-50% accuracy. Adam gave better results, but still nowhere near training without the Wasserstein loss. Any thoughts?
Hi, I was wondering if anyone could shed some light on the following questions regarding the Wasserstein layer.

- How is `M` computed, and where can I find this info?
- The `pred` variable: is it logits, softmax outputs, or something else?
- The `label`: is it hard labels like `[3, 2, 5, 7, 9, 1, 0]`, one-hot encoded, or something else?
- Do the above vars change according to the problem (multi-class classification), for instance MNIST with hard labels (i.e. unique labels) vs. MNIST with multi-labels for each class?