jiawei-ren / BalancedMetaSoftmax-Classification

[NeurIPS 2020] Balanced Meta-Softmax for Long-Tailed Visual Recognition
https://github.com/jiawei-ren/BalancedMetaSoftmax

Meta Sampler #3

Closed lailvlong closed 3 years ago

lailvlong commented 3 years ago

https://github.com/jiawei-ren/BalancedMetaSoftmax-Classification/blob/3712cb9b498797714859198dbd010dd0201ab6dd/data/MetaSampler.py#L62 Hello, thanks for open-sourcing your nice work! I have a question about the line linked above: why is curr_sample multiplied twice?

jiawei-ren commented 3 years ago

In short, the first multiplication with curr_sample.detach() prevents gradients from flowing through the unsampled instances, i.e., the zeros.

First, some notation. x on the right-hand side is the loss, with shape [B], where B is the batch size. curr_sample is the sampling result: a batch of one-hot vectors indicating which training image has been selected, with shape [B, N], where N is the total number of training images. In other words, x[i] is the loss for the image whose index is recorded by curr_sample[i]. We multiply them so that the gradient of the loss can flow through the sampling result and then to the sampler (where we apply the Gumbel-Softmax trick).

Here is why direct multiplication fails. For the i-th element in the batch, the direct multiplication x.unsqueeze(-1) * curr_sample gives

[curr_sample[i][0] * x[i], curr_sample[i][1] * x[i], ..., curr_sample[i][N-1] * x[i]]

This goes wrong in the subsequent summation, which returns

curr_sample[i][0] * x[i] + curr_sample[i][1] * x[i] + ... + curr_sample[i][N-1] * x[i]

which equals sum(curr_sample[i]) * x[i] and completely disregards where the 0s and 1s are, i.e., it does not depend on the one-hot vector curr_sample[i]. As a result, the sampler receives the same gradient no matter how it samples and hence learns nothing. In fact, with the direct multiplication, the learned unnormalized sample rate always ends up at 0.5.
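
To see this concretely, here is a minimal toy sketch (assuming PyTorch, toy shapes B and N, and a placeholder sampler parameterized by logits; this is not the repository's actual MetaSampler code). With the direct multiplication, the gradient reaching the sampler carries no information about which images were drawn:

```python
import torch
import torch.nn.functional as F

B, N = 4, 10                                  # toy batch size and dataset size
logits = torch.zeros(N, requires_grad=True)   # placeholder sampler: unnormalized sample rates

# Straight-through Gumbel-Softmax: hard one-hot rows in the forward pass,
# differentiable soft probabilities in the backward pass.
curr_sample = F.gumbel_softmax(logits.expand(B, N), tau=1.0, hard=True)  # [B, N]
x = torch.rand(B)                             # per-image loss, shape [B]

# Direct multiplication: the row-wise sum collapses to sum(curr_sample[i]) * x[i].
# Each row of curr_sample sums to 1 no matter which index was drawn, so the
# backward pass carries no information about the sampling decision.
loss_direct = (x.unsqueeze(-1) * curr_sample).sum(-1).mean()
loss_direct.backward()
print(logits.grad)  # (numerically) zero: independent of the actual draw
```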

Instead, when we first multiply x with curr_sample.detach() and then multiply with curr_sample, we get

[0, 0, ..., curr_sample[i][s] * x[i], ..., 0]

where s is the index of the single 1 in the one-hot vector curr_sample[i]. The subsequent summation then returns curr_sample[i][s] * x[i]. This effectively connects the loss to how the sampler samples.

Alternatively, one may select the non-zero elements of curr_sample and then multiply them with x, as in the sketch below.
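
Here is the same toy sketch with the detach-then-multiply pattern (again with assumed placeholder names, not the repository's code). The gradient on the sampler's logits now depends on which images were drawn; the non-zero-selection alternative is shown as a comment:

```python
import torch
import torch.nn.functional as F

B, N = 4, 10                                  # toy batch size and dataset size
logits = torch.zeros(N, requires_grad=True)   # placeholder sampler: unnormalized sample rates
curr_sample = F.gumbel_softmax(logits.expand(B, N), tau=1.0, hard=True)  # [B, N], one-hot rows
x = torch.rand(B)                             # per-image loss, shape [B]

# curr_sample.detach() zeroes out every unsampled position without opening a
# gradient path; the second, non-detached factor keeps the gradient path only
# at the sampled index s, so the row-wise sum equals curr_sample[i][s] * x[i].
loss = (x.unsqueeze(-1) * curr_sample.detach() * curr_sample).sum(-1).mean()
loss.backward()
print(logits.grad)  # non-trivial: depends on which images were drawn and their losses

# Equivalent alternative (assuming exactly one non-zero entry per row):
# select the sampled entries directly, then multiply by the per-image loss.
# loss_alt = (x * curr_sample[curr_sample.detach() > 0]).mean()
```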

lailvlong commented 3 years ago

I got it, thanks a lot!