stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX

Worth adding top_k function? #32

Closed rohan-mehta-1024 closed 9 months ago

rohan-mehta-1024 commented 10 months ago

It seems like adding a haliax implementation of jax.lax.top_k would be a good idea, especially since haliax is primarily used by levanter for language modeling. If this makes for a good first issue, I could try my hand at it? The only part that seems challenging is that you have to resize the axis along which you take the top_k. For reference, here's the behavior at the plain JAX level (just to illustrate the resizing):

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1.0, 4.0, 2.0, 3.0],
               [7.0, 5.0, 8.0, 6.0]])

# jax.lax.top_k works on the last axis and shrinks it to k
values, indices = jax.lax.top_k(x, k=2)

print(values.shape)  # (2, 2) -- last axis went from 4 to 2
print(values)        # [[4. 3.]
                     #  [8. 7.]]
print(indices)       # [[1 3]
                     #  [2 0]]
```
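So a haliax version would presumably need to return NamedArrays whose named axis has been shrunk to size k.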

dlwh commented 10 months ago

sure, i'd take a PR for that!

rohan-mehta-1024 commented 10 months ago

Just to talk things through, here's where I'm at right now:

  1. My first thought was to wrap it using the existing wrapping functionality, but that doesn't seem like it will work: jax.lax.top_k returns two arrays, the top_k values and their indices. It also doesn't fit neatly into a reduction or an axiswise operation, since all the axes remain, one is just resized. So it seems like it will need a custom wrapper; would you agree?
  2. Another thing is that jax.lax.top_k doesn't accept an axis to operate on; it always uses the last axis. Would the proper approach be to let the haliax version accept an axis, transpose the underlying array so that the given axis becomes the last one, perform the operation, and then transpose it back? (Rough sketch below.)
  3. Structurally, where does it belong in the codebase? Lots of similar-ish functions, e.g. hax.argmax, are defined in the init file, but none of those have custom wrappers. Would putting it in core.py (or somewhere else) make more sense?
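
For concreteness, here's roughly the shape of the wrapper I have in mind. This is just a sketch: it assumes NamedArray exposes .axes and .array and that hax.named rebuilds a NamedArray from a raw array plus axes, and the final name/signature/home are exactly what I'm asking about.

```python
import jax
import jax.numpy as jnp
import haliax as hax
from haliax import Axis, NamedArray


def top_k(arr: NamedArray, axis: Axis, k: int) -> tuple[NamedArray, NamedArray]:
    """Sketch: top k values and their indices of `arr` along the named `axis`."""
    # locate the requested axis (assuming .axes is a tuple of Axis)
    pos = arr.axes.index(axis)

    # jax.lax.top_k only operates on the last axis, so move the target axis there
    raw = jnp.moveaxis(arr.array, pos, -1)
    values, indices = jax.lax.top_k(raw, k)

    # move the (now resized) axis back to its original position
    values = jnp.moveaxis(values, -1, pos)
    indices = jnp.moveaxis(indices, -1, pos)

    # the named axis shrinks from axis.size to k
    new_axes = arr.axes[:pos] + (Axis(axis.name, k),) + arr.axes[pos + 1:]

    # (assuming hax.named is the right way to rewrap raw arrays)
    return hax.named(values, new_axes), hax.named(indices, new_axes)
```

Then calling it on, say, logits with a Vocab axis would give back two NamedArrays where the vocab axis has size k instead of the full vocabulary size.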
dlwh commented 10 months ago

(1) I think you'll have to use your own wrapper, yeah.

(2) yep, that would be great!

(3) maybe make a new file? I should move most of the things in init anyway