Closed rohan-mehta-1024 closed 9 months ago
sure, i'd take a PR for that!
Just to talk things through, here's where I'm at right now:
jax.lax.top_k returns two arrays: the top k elements and their indices. It also doesn't fit neatly into a reduction or axis-wise operation, since the axes stay the same and one is just resized. So it seems like it will need a custom wrapper. Would you agree?

jax.lax.top_k doesn't accept an axis to perform the operation on; it always uses the last axis. Would the proper thing to do here be to let the haliax version accept an axis, transpose the underlying array so that the given axis becomes the last axis, perform the operation, and then transpose it back?

(1) I think you'll have to use your own wrapper, yeah.
(2) yep, that would be great!
(3) maybe make a new file? I should move most of the things in __init__ anyway
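
For reference, the axis handling discussed above could be sketched roughly as follows on plain JAX arrays. This is only an illustration, not the haliax implementation: the helper name top_k_along_axis and its positional axis argument are hypothetical, and it assumes jax.lax.top_k operates over the last axis, as described in this thread.

```python
import jax
import jax.numpy as jnp


def top_k_along_axis(x, k, axis=-1):
    """Hypothetical helper: top-k along an arbitrary axis of a plain JAX array."""
    # Move the requested axis to the end, because this sketch relies on
    # jax.lax.top_k working over the last axis.
    moved = jnp.moveaxis(x, axis, -1)
    values, indices = jax.lax.top_k(moved, k)
    # Move the (now length-k) axis back to its original position.
    values = jnp.moveaxis(values, -1, axis)
    indices = jnp.moveaxis(indices, -1, axis)
    return values, indices


x = jnp.array([[1.0, 5.0, 3.0],
               [4.0, 2.0, 6.0]])
# Take the single largest entry down each column (axis 0).
values, indices = top_k_along_axis(x, k=1, axis=0)
```

A named-axis wrapper would additionally have to map the axis name to its position and replace that axis with a new one of size k in the result; the sketch leaves that part out.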
It seems like adding a haliax implementation of jax.lax.top_k would be a good idea, especially since haliax is primarily used for language modeling by levanter. If this makes for a good first issue, I could maybe try my hand at it? The only thing that seems like it might be challenging is that you would have to resize the axis along which you are taking the top_k.