Felix-Petersen / difftopk

Differentiable Top-k Classification Learning
MIT License

Multi Label #8

Open vateye opened 6 months ago

vateye commented 6 months ago

Hi, is it possible to use the top-k entropy loss for a multi-label classification problem, where each of the gt_labels can be in the top-K?

Felix-Petersen commented 6 months ago

Hi, yes, that should be possible.

In the forward (https://github.com/Felix-Petersen/difftopk/blob/76ef96db648058a73571628f1db5e6a9f4478bfd/difftopk/losses.py#L86), it would primarily require replacing losses of this style

torch.nn.functional.nll_loss(torch.log(topk_distribution * (1 - 2e-7) + 1e-7), labels)

with something like:

l = 0
# Average the single-label NLL loss over all num_labels ground-truth labels.
for label_idx in range(num_labels):
    l = l + torch.nn.functional.nll_loss(torch.log(topk_distribution * (1 - 2e-7) + 1e-7), labels[:, label_idx])
l = l / num_labels

(This is for num_labels labels, with labels being a batch_size x num_labels LongTensor.)
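
As a sanity check, here is an equivalent loop-free variant as a minimal, self-contained sketch (the shapes and the random topk_distribution below are placeholders for illustration only):

import torch

# Placeholder shapes for illustration only.
batch_size, num_classes, num_labels = 4, 10, 3
topk_distribution = torch.rand(batch_size, num_classes).softmax(-1)  # stand-in for the real top-k distribution
labels = torch.randint(num_classes, (batch_size, num_labels))        # batch_size x num_labels LongTensor

# Clamp away from 0 before taking the log, as in the single-label loss above.
log_p = torch.log(topk_distribution * (1 - 2e-7) + 1e-7)             # batch_size x num_classes

# Gathering the log-probability of every ground-truth label and averaging
# is equivalent to the per-label nll_loss loop above.
loss = -log_p.gather(1, labels).mean()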

For the optional self.m is not None part of the forward, keep in mind that this would also require some adjustment.

I'm interested in your results. If you can, please share them; I'm happy to help.

vateye commented 6 months ago

Thanks for your explanation. Here is my case: I have a K-hot label. When K = 1, I do softmax and cross-entropy for training and multinomial sampling for inference. Now, when K > 1, the softmax operator suppresses the top-K probabilities (Pr(top-k) <= 1/K), so I think of this as the problem of making the K-hot labels the top-k predictions. But I cannot tell whether this is what topk_distribution captures, since we would be maximizing the probability sum_{i<=k} sum_{j<=k} P[i, j].

Felix-Petersen commented 6 months ago

In this case, it would be something like

- ((torch.log(topk_distribution * (1 - 2e-7) + 1e-7) * labels).sum(-1) / labels.sum(-1)).mean(0)

where labels is a k-hot FloatTensor of shape batch_size x num_classes. The above expression supports a different k for different elements of the batch and reweighs them correspondingly. In the extended case of using m (which is helpful for large numbers of classes / recommended for >100 classes), the respective selection requires additional care to ensure that the "reduced m classes" always contain the top-k labels as well as the remaining m-k top-scoring predicted classes.
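
As a self-contained sketch of that loss with a k-hot label tensor (the shapes, the random topk_distribution, and the concrete label indices are made up for illustration):

import torch

batch_size, num_classes = 4, 10
# topk_distribution[b, c]: relaxed probability that class c is among the top-k for sample b (stand-in values).
topk_distribution = torch.rand(batch_size, num_classes)

# k-hot FloatTensor; each row may contain a different number of ones (i.e., a different k).
labels = torch.zeros(batch_size, num_classes)
labels[0, [1, 3]] = 1.0          # sample 0: k = 2
labels[1, [0, 2, 5]] = 1.0       # sample 1: k = 3
labels[2, [7]] = 1.0             # sample 2: k = 1
labels[3, [4, 6, 8, 9]] = 1.0    # sample 3: k = 4

log_p = torch.log(topk_distribution * (1 - 2e-7) + 1e-7)
# Per-sample mean over that sample's labels, then mean over the batch.
loss = -((log_p * labels).sum(-1) / labels.sum(-1)).mean(0)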

Is your k the same for all elements of the batch? If not, you could get topk_distribution (assuming p_k = [0, ..., 0, 1, 0, ..., 0] with the 1 at index k for each element) via

topk_distribution = (P_topk * labels.sort(dim=-1, descending=False)[0].unsqueeze(1)[:, :, -self.k:]).sum(-1)

where self.k is the maximum k to be considered (as in the code). Here, the sort has the purpose of producing k-hot vectors whose last k entries are 1.
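
To illustrate just this sorting trick on a toy example (shapes made up; max_k plays the role of self.k):

import torch

# Three samples with k = 2, 1, 3 labels out of 5 classes.
labels = torch.tensor([
    [1., 0., 1., 0., 0.],
    [0., 0., 0., 1., 0.],
    [1., 1., 0., 0., 1.],
])
max_k = 3  # plays the role of self.k

# Sorting each row ascending pushes its ones to the end, so the last max_k
# entries form per-sample weights over the top-max_k rank slots.
mask = labels.sort(dim=-1, descending=False)[0][:, -max_k:]
# mask == tensor([[0., 1., 1.],
#                 [0., 0., 1.],
#                 [1., 1., 1.]])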

If k is the same for each elem in the batch, you can do

topk_distribution = P_topk[:, :, -self.k:].sum(-1)

vateye commented 6 months ago

Yes, here is my implementation.

marginal_probs = [1/K] * K
scores = scores.unsqueeze(-1)                              # [B, M, 1]
sorted_scores = scores.topk(k=K, dim=1, largest=True)[0]   # [B, K, 1] (top-K scores, descending)

soft_permutation = (scores.transpose(1, 2) - sorted_scores).abs().pow(power).neg() / tau
soft_permutation = soft_permutation.softmax(-1)  # [B, K, M]

topk_distribution = 0
for k, p_k in enumerate(marginal_probs):
    if p_k == 0:
        continue
    topk_distribution += p_k * soft_permutation[:, :(k+1)].sum(1)
topk_distribution = -torch.log(topk_distribution + 1e-8)
loss = (topk_distribution * labels).sum(-1).mean(0)

Please correct me if I made a mistake. Besides, if this is a generation problem, how should I do stochastic sampling during inference?

Felix-Petersen commented 6 months ago

Hi, this looks like you are using NeuralSort or SoftSort. I recommend Cauchy Odd-Even Differentiable Sorting Networks for better performance.

In soft_permutation[:, :(k+1)].sum(1), you are summing over the first entries, which seems wrong unless you use a convention opposite to the one in the difftopk library (opposite in two ways: reversed ordering and transposition).

Considering that your K is constant, it should simply be

topk_distribution = P_topk[:, :, -K:].sum(-1)
- (torch.log(topk_distribution * (1 - 2e-7) + 1e-7) * labels).mean(0).sum(-1)

where P_topk is computed as in this line: https://github.com/Felix-Petersen/difftopk/blob/76ef96db648058a73571628f1db5e6a9f4478bfd/difftopk/losses.py#L120
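
As a rough end-to-end sketch of these two lines (the random P_topk below is only a stand-in with an assumed batch_size x num_classes x num_classes shape, where the last columns correspond to the highest ranks; in the library it comes from the line linked above):

import torch

batch_size, num_classes, K = 4, 10, 2
# Stand-in for P_topk: for each class, a distribution over rank slots.
P_topk = torch.rand(batch_size, num_classes, num_classes).softmax(-1)

# K-hot ground-truth labels (toy choice: the first K classes are the targets).
labels = torch.zeros(batch_size, num_classes)
labels[:, :K] = 1.0

# Probability of each class being among the top K = sum over the K highest rank slots.
topk_distribution = P_topk[:, :, -K:].sum(-1)

loss = -(torch.log(topk_distribution * (1 - 2e-7) + 1e-7) * labels).mean(0).sum(-1)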

vateye commented 6 months ago

I think I used a convention opposite to the one in the difftopk library. For example, the shape of soft_permutation would be [B, K, N], and soft_permutation[:, i] would be the probability of being top-i? Is that correct?

Felix-Petersen commented 6 months ago

No, the probability of something being top-i, i.e., among the top i elements is: topk_distribution = P_topk[:, :, -K:].sum(-1). In your convention, probably topk_distribution = soft_permutation[:, :K, :].sum(-2).

It's important to use "among the top-k". Otherwise, it would be the probability of being exactly the kth largest, which implies an ordering, and then the loss doesn't make sense.

vateye commented 6 months ago

I am still a little confused. In my implementation, I used a for loop from top-1 to top-K, where in each iteration k I apply topk_distribution += p_k * soft_permutation[:, :(k+1)].sum(1) to accumulate the probability of being 'among the top-k', and p_k is the probability of being among the top-k?

So should I directly use topk_distribution = soft_permutation[:, :K].sum(1) without the for-loop iteration? Please correct me if I misunderstood.

Is the following updated code snippet correct?

scores = scores.unsqueeze(-1)
sorted_scores = scores.topk(k=K, dim=1, largest=True)[0]

soft_permutation = (scores.transpose(1, 2) - sorted_scores).abs().pow(power).neg() / tau
soft_permutation = soft_permutation.softmax(-1) # [B, K, M]

topk_distribution = soft_permutation[:, :K].sum(1)
topk_distribution = -torch.log(topk_distribution + 1e-8)
loss = (topk_distribution * labels).sum(-1).mean(0)

Felix-Petersen commented 6 months ago

This part

topk_distribution = soft_permutation[:, :K].sum(1)
topk_distribution = -torch.log(topk_distribution + 1e-8)
loss = (topk_distribution * labels).sum(-1).mean(0)

seems correct, assuming you input a correspondingly correct soft_permutation. Again, I'd recommend going with DSNs and integrating the "m" trick for best performance.

Yes, without a loop, the sum is correct. If you did use a loop, you would still need the sum; you just don't need the loop, because your p_k weights should be 0 for the first K-1 places and 1 at place K (you care about top-K, not about some percentage top-1, some percentage top-2, etc.).
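
In other words, with p_k = [0, ..., 0, 1] the loop collapses to the plain slice-and-sum; a toy check in your [B, K, M] convention (random stand-in tensors):

import torch

B, K, M = 2, 3, 5
soft_permutation = torch.rand(B, K, M).softmax(-1)  # stand-in, [B, K, M]

# Loop with p_k = 0 for the first K-1 places and 1 for place K ...
p = [0.0] * (K - 1) + [1.0]
looped = sum(p_k * soft_permutation[:, :(k + 1)].sum(1) for k, p_k in enumerate(p))

# ... is identical to summing the first K rows directly.
direct = soft_permutation[:, :K].sum(1)
assert torch.allclose(looped, direct)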

Felix-Petersen commented 6 months ago

In your application, what's your number of classes, and what is K?

vateye commented 6 months ago

In your application, what's your number of classes, and what is K?

The number of classes depends on the vocab size (i.e., 1k, 4k, 16k, or 64k), and K is usually in [1, 2, 4, 8], since my goal is to predict an unordered set of size K from the classes.

Felix-Petersen commented 6 months ago

In this case, I'd strongly recommend DSNs, and setting m to something like 32, 50, or 64, which empirically stabilizes training drastically.

vateye commented 6 months ago

I don't know much about DSNs. I prefer simplicity in the loss function, which is why I chose SoftSort. Would you mind explaining what DSNs are and how they work better than SoftSort?

Felix-Petersen commented 6 months ago

Sorry for the delay in response.

Differentiable Sorting Networks (DSNs) are a differentiable relaxation of classic sorting networks. Especially monotonic DSNs (like the Cauchy DSN) provide improved gradient quality and better optimization behavior compared to algorithms like SoftSort. I think it will be easiest to understand via my videos on the topic:

https://www.youtube.com/watch?v=38dvqdYEs1o (original DSNs)
https://www.youtube.com/watch?v=Rl-sFaE1z4M (monotonic DSNs extension, animated)

Feel free to ask questions about it here if anything is unclear.