jonasrauber / eagerpy

PyTorch, TensorFlow, JAX and NumPy — all of them natively using the same code
https://eagerpy.jonasrauber.de
MIT License

topk #23

Closed: mglisse closed this issue 3 years ago

mglisse commented 3 years ago

Hello,

I noticed that PyTorch and TensorFlow both have a top-k function (`torch.topk` and `tf.math.top_k`), which can be faster than a full sort when k is small. With NumPy, it can be emulated using (arg)partition (`np.partition` / `np.argpartition`). I think it would be convenient if eagerpy provided a wrapper for this; I have two unrelated pieces of code where I would likely use it.
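Here is a minimal sketch of the NumPy emulation. It assumes signed integer or floating-point input, because it negates the array. The helper `topk_numpy` is hypothetical and is not an existing eagerpy function. It uses `np.argpartition` to select the k largest entries in O(n) average time, then sorts only those k entries. It returns a values/indices pair in descending order, mirroring what `torch.topk` returns by default.

```python
import numpy as np


def topk_numpy(x, k, axis=-1):
    """Hypothetical helper: k largest entries of x along `axis`, sorted descending.

    Assumes a signed or floating-point dtype (x is negated, which would wrap
    around for unsigned integers). Returns (values, indices) like torch.topk.
    """
    x = np.asarray(x)
    n = x.shape[axis]
    if not 0 < k <= n:
        raise ValueError("k must satisfy 0 < k <= x.shape[axis]")
    # Partition -x so that the first k slots hold the indices of the k largest
    # entries of x (in arbitrary order); O(n) on average instead of O(n log n).
    part = np.argpartition(-x, k - 1, axis=axis)
    idx = np.take(part, np.arange(k), axis=axis)
    vals = np.take_along_axis(x, idx, axis=axis)
    # Sort only the k selected entries, in descending order.
    order = np.argsort(-vals, axis=axis, kind="stable")
    idx = np.take_along_axis(idx, order, axis=axis)
    vals = np.take_along_axis(vals, order, axis=axis)
    return vals, idx


vals, idx = topk_numpy(np.array([3, 1, 4, 1, 5, 9, 2, 6]), 3)
print(vals, idx)  # [9 6 5] [5 7 4]
```

Only the final sort of k elements is O(k log k), which is where the savings come from when k is much smaller than n.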

(Of course, it is disappointing that TensorFlow does not seem to support a `dim` argument: `tf.math.top_k` always works along the last dimension. Transposing the input before the call and the outputs after it is easy enough, though.)
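The transposition workaround might look like the following sketch. `last_axis_topk` is a hypothetical NumPy stand-in for a backend routine that, like `tf.math.top_k`, only operates on the last axis. `topk_along` (also hypothetical) moves the requested axis to the end, calls the last-axis routine, and then moves the result axis back to its original position. With TensorFlow itself, the same idea would use `tf.transpose` with an explicit permutation instead of `np.moveaxis`.

```python
import numpy as np


def last_axis_topk(x, k):
    """Hypothetical stand-in for a top-k that, like tf.math.top_k,
    only works along the last axis. Returns (values, indices), descending."""
    idx = np.argsort(-x, axis=-1, kind="stable")[..., :k]
    return np.take_along_axis(x, idx, axis=-1), idx


def topk_along(x, k, axis, last_axis_fn=last_axis_topk):
    """Apply a last-axis-only top-k along an arbitrary `axis`."""
    # Move the target axis to the end so the backend routine sees it as last.
    x_t = np.moveaxis(x, axis, -1)
    vals, idx = last_axis_fn(x_t, k)
    # Move the (now length-k) axis back to where it was in the input.
    return np.moveaxis(vals, -1, axis), np.moveaxis(idx, -1, axis)


vals, idx = topk_along(np.array([[1, 5, 3], [7, 2, 8]]), 1, axis=0)
print(vals, idx)  # [[7 5 8]] [[1 0 1]]
```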