stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

added custom wrapper for jax.lax.top_k in new file #38

Closed rohan-mehta-1024 closed 1 year ago

rohan-mehta-1024 commented 1 year ago

Description

Added a custom wrapper for jax.lax.top_k in a new file, specialized_fns.py. The file name is a placeholder for lack of a better one and should probably be renamed later.

Fixes Issues

#32

Unit test coverage

Tests the function along each axis of a random tensor and compares the result against the corresponding output of jax.lax.top_k.
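For readers unfamiliar with the approach, here is a minimal sketch of what such a wrapper and per-axis check could look like. It is not the code from this PR: it uses plain positional axes rather than haliax named axes, and the helper name top_k_along_axis, the tensor shape, and the value of k are all hypothetical. The only assumption about JAX itself is that jax.lax.top_k operates on the last dimension and returns a (values, indices) pair.

```python
import jax
import jax.numpy as jnp


def top_k_along_axis(x, k, axis=-1):
    """Hypothetical helper: top-k values and indices along `axis`."""
    # jax.lax.top_k only operates on the last dimension, so move `axis` there.
    moved = jnp.moveaxis(x, axis, -1)
    values, indices = jax.lax.top_k(moved, k)
    # Move the resulting length-k dimension back to where `axis` was.
    return jnp.moveaxis(values, -1, axis), jnp.moveaxis(indices, -1, axis)


# Per-axis check in the spirit of the unit test described above.
x = jax.random.normal(jax.random.PRNGKey(0), (4, 5, 6))
k = 2
for axis in range(x.ndim):
    values, indices = top_k_along_axis(x, k, axis)
    # Reference: call jax.lax.top_k directly with `axis` moved to the end.
    ref_values, ref_indices = jax.lax.top_k(jnp.moveaxis(x, axis, -1), k)
    assert jnp.array_equal(jnp.moveaxis(values, axis, -1), ref_values)
    assert jnp.array_equal(jnp.moveaxis(indices, axis, -1), ref_indices)
```

A named-axis version would presumably look up the position of the requested axis by name and replace it with a new axis of size k, but that part depends on haliax internals and is left out here.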

Known breaking changes/behaviors

N/A