lijun2005 opened 2 months ago

Great job! This is the first time I've seen a differentiable top-k selection so well implemented in an engineering project. But I have a question for the author: what are the differences between the top-k implementation in this project and the one built into PyTorch (`torch.topk`)?

---

Thank you!

The `topk` built into PyTorch (`torch.topk`) is not differentiable, whereas `difftopk` is differentiable.

The `indices` returned by `torch.topk` are integers and do not have an attached gradient. By returning "indices" in the form of a top-k assignment matrix of probabilities, `difftopk` enables optimization with respect to "which elements are in the set of top-k".

Further, while the `values` return value of `torch.topk` does have attached gradients, in many optimization problems, using smooth rather than merely continuous values (as provided in `difftopk`) also improves training.

If you are interested only in the top-k aspects (rather than the `TopKCrossEntropyLoss`), see the documentation of `DiffTopkNet` (https://github.com/Felix-Petersen/difftopk?tab=readme-ov-file#difftopknet) or differentiable sorting networks in general (https://github.com/Felix-Petersen/diffsort).
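To make the contrast concrete, here is a small sketch. The hard `torch.topk` part is standard PyTorch; the soft relaxation below is a generic, illustrative softmax-based approximation (the `soft_topk` helper is hypothetical and is NOT difftopk's actual differentiable-sorting-network algorithm) that shows what a probability assignment matrix over "which elements are selected" looks like:

```python
import torch

x = torch.tensor([0.2, 1.5, -0.3, 2.1, 0.9], requires_grad=True)

# Hard top-k: the values carry gradients, but the indices are plain
# int64 with no gradient -- "which elements are selected" cannot be
# optimized through them.
values, indices = torch.topk(x, k=2)
print(indices.dtype)         # torch.int64
print(values.requires_grad)  # True

# Generic soft relaxation (illustrative only, not difftopk's method):
# k rounds of temperature-scaled softmax yield a k x n assignment
# matrix P, where P[i, j] approximates the probability that element j
# is the i-th selected element.
def soft_topk(x, k, tau=0.1):
    logits = x.clone()
    rows = []
    for _ in range(k):
        p = torch.softmax(logits / tau, dim=-1)
        rows.append(p)
        # Down-weight already-selected probability mass for the next
        # round (detach so masking itself is not differentiated).
        logits = logits + torch.log1p(-p.detach() + 1e-9)
    return torch.stack(rows)

P = soft_topk(x, k=2)
print(P.shape)  # torch.Size([2, 5])

# The relaxation is differentiable with respect to x: soft "values"
# are weighted averages P @ x, and gradients flow back to x.
soft_values = P @ x
soft_values.sum().backward()
print(x.grad is not None)  # True
```

With a small temperature `tau`, each row of `P` concentrates near a one-hot vector and the soft values approach the hard `torch.topk` values, while remaining smooth in `x`.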