asmith26 / jax_toolkit

A collection of jax functions to help with common machine/deep learning related functionality.
https://asmith26.github.io/jax_toolkit/
Apache License 2.0

jax_toolkit

Documentation, PyPI

This library currently provides basic implementations of a number of losses and metrics. We intend to add more functionality as it is needed. Contributions, pull requests and bug reports are very welcome if you discover a problem or need something that is currently missing.

Installation

pip install jax_toolkit

Or, for the additional loss function utilities (the quotes stop shells such as zsh from interpreting the square brackets):

pip install "jax_toolkit[losses_utils]"
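
To confirm that the install succeeded, one option is to query the installed package metadata from Python. The following is a minimal sketch using only the standard library; the helper name installed_version is hypothetical and is not part of jax_toolkit.

```python
from importlib import metadata
from typing import Optional


def installed_version(package_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent.

    Hypothetical helper for illustration; it is not provided by jax_toolkit.
    """
    try:
        return metadata.version(package_name)
    except metadata.PackageNotFoundError:
        # The distribution is not installed in the current environment.
        return None


if __name__ == "__main__":
    version = installed_version("jax_toolkit")
    if version is None:
        print("jax_toolkit is not installed in this environment")
    else:
        print(f"jax_toolkit {version} is installed")
```

Running this in the same environment used for pip install should report a version string; a "not installed" message usually means pip targeted a different Python interpreter.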