Description
Added a custom wrapper for jax.lax.top_k in a new file, specialized_fns.py. The file name is probably not ideal, but I could not think of a better one; it should likely be renamed in the future.
Fixes Issues
32
Unit test coverage
Tests the function along each axis of a random tensor and compares the results with the corresponding output of jax.lax.top_k.
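This is a minimal sketch of what such a wrapper and its per-axis test might look like. It assumes the wrapper exists to let top-k run along an arbitrary axis, which is inferred from the test description and is not the actual contents of specialized_fns.py. jax.lax.top_k(operand, k) selects along the last axis, so the sketch moves the requested axis to the end, calls lax.top_k, and moves the results back. The names top_k_along_axis and check_each_axis are hypothetical.

```python
import jax
import jax.numpy as jnp


def top_k_along_axis(operand, k, axis=-1):
    """Hypothetical wrapper (not the PR's code): top-k along any axis.

    Moves `axis` to the end, applies jax.lax.top_k (which works on the
    last axis), then moves the result back. Returns (values, indices).
    """
    axis = axis % operand.ndim  # normalize negative axes such as -1
    moved = jnp.moveaxis(operand, axis, -1)
    values, indices = jax.lax.top_k(moved, k)
    return jnp.moveaxis(values, -1, axis), jnp.moveaxis(indices, -1, axis)


def check_each_axis(shape=(4, 5, 6), k=3, seed=0):
    """Hypothetical per-axis test: compare against jax.lax.top_k directly."""
    x = jax.random.normal(jax.random.PRNGKey(seed), shape)
    for axis in range(x.ndim):
        values, indices = top_k_along_axis(x, k, axis=axis)
        # Reference output: lax.top_k on the input with `axis` moved last.
        ref_values, ref_indices = jax.lax.top_k(jnp.moveaxis(x, axis, -1), k)
        assert jnp.array_equal(values, jnp.moveaxis(ref_values, -1, axis))
        assert jnp.array_equal(indices, jnp.moveaxis(ref_indices, -1, axis))
        # Consistency check: the indices must point at the returned values.
        assert jnp.array_equal(jnp.take_along_axis(x, indices, axis=axis), values)
    return True
```

For example, top_k_along_axis(x, 1, axis=0) on a 2x2 matrix returns the largest entry of each column with shape (1, 2), keeping the selected axis in its original position rather than moving it last.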
Known breaking changes/behaviors
N/A