Felix-Petersen / difftopk

Differentiable Top-k Classification Learning
MIT License

About difftopk and torch.topk #9

Open lijun2005 opened 2 months ago

lijun2005 commented 2 months ago

Great job! This is the first time I've seen a differentiable top-k selection implemented this well in an engineering project. But I have a question for the author: what are the differences between the top-k implementation in this project and the one built into PyTorch (torch.topk)?

Felix-Petersen commented 2 months ago

Thank you!

The topk built into PyTorch (torch.topk) is not differentiable, whereas difftopk is. The indices returned by torch.topk are integers and carry no gradient. By returning "indices" in the form of a top-k assignment matrix of probabilities, difftopk makes it possible to optimize with respect to *which* elements belong to the top-k set. Further, while the values returned by torch.topk do carry gradients, in many optimization problems using smooth rather than merely continuous values (as difftopk provides) also improves training.
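To make the contrast concrete, here is a minimal sketch. It is *not* difftopk's actual algorithm (difftopk uses differentiable sorting networks); the `soft_topk` helper below is a hypothetical illustration that builds a k×n probability assignment matrix via repeated temperature-scaled softmaxes, so that gradients flow back to the scores, unlike the integer indices from torch.topk.

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([0.1, 2.0, 0.5, 1.5], requires_grad=True)

# Hard top-k: integer indices, no gradient path through the selection.
values, indices = torch.topk(scores, k=2)
assert indices.dtype == torch.long and indices.grad_fn is None

def soft_topk(scores, k, tau=0.1):
    """Toy soft top-k (illustrative only, not difftopk's method).

    Returns a k x n matrix whose i-th row is a probability
    distribution over which element is the i-th largest.
    """
    n = scores.shape[0]
    mask = torch.zeros(n)
    rows = []
    for _ in range(k):
        # Softmax over scores, with already-selected elements suppressed.
        row = torch.softmax((scores - 1e9 * mask) / tau, dim=0)
        rows.append(row)
        # Hard-mask the element this row mostly selected (no gradient
        # is needed through the mask itself).
        mask = mask + F.one_hot(row.argmax(), n).float()
    return torch.stack(rows)

P = soft_topk(scores, k=2)   # 2 x 4 assignment matrix, rows sum to 1
P.pow(2).sum().backward()    # gradients reach the input scores
assert scores.grad is not None
```

Each row of `P` is a soft "index": instead of the integer 1 for the largest element, you get a distribution peaked at position 1, which a loss such as a top-k cross-entropy can consume end-to-end.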

If you are interested only in the topk aspects (rather than the TopKCrossEntropyLoss), see the documentation of difftopk's DiffTopkNet (https://github.com/Felix-Petersen/difftopk?tab=readme-ov-file#difftopknet) or differentiable sorting networks in general (https://github.com/Felix-Petersen/diffsort).