stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

top_k returns indices in addition to values #61

blahBlahhhJ closed this 5 months ago

blahBlahhhJ commented 5 months ago

Couldn't find an easy way to maintain backward compatibility.

Also, let me know if the tests make sense.

dlwh commented 5 months ago

i don't think it's worth maintaining backward compat here. we could add a flag like return_indices=True and type it with typing.overload, but it's not worth the work.
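For reference, the flag-plus-overload option mentioned above could look roughly like the sketch below. It is a plain-Python stand-in on a list of floats, not haliax's actual `top_k` signature or implementation. The parameter name `return_indices` comes from the comment; everything else, including the list-based inputs, is assumed for illustration.

```python
from typing import List, Literal, Sequence, Tuple, overload


# Overload 1: flag omitted or False -> values only (the old behavior).
@overload
def top_k(
    xs: Sequence[float], k: int, return_indices: Literal[False] = ...
) -> List[float]: ...


# Overload 2: flag True -> (values, indices).
@overload
def top_k(
    xs: Sequence[float], k: int, return_indices: Literal[True]
) -> Tuple[List[float], List[int]]: ...


def top_k(xs, k, return_indices=False):
    """Hypothetical stand-in: the k largest entries of xs, largest first."""
    # Sort positions by descending value; ties keep the lower index
    # because Python's sort is stable.
    order = sorted(range(len(xs)), key=lambda i: -xs[i])[:k]
    values = [xs[i] for i in order]
    return (values, order) if return_indices else values
```

With this shape, existing callers that unpack a single result keep working and a type checker can infer the tuple return when `return_indices=True` is passed literally. The cost is two extra stubs per function and a return type that depends on a runtime flag, which is the extra work the comment above decides against.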

cc @rohan-mehta-1024 we're gonna break compatibility here