asmith26 / jax_toolkit

A collection of jax functions to help with common machine/deep learning related functionality.
https://asmith26.github.io/jax_toolkit/
Apache License 2.0

Consider adding more `vmap` for batching functionality #120

Open asmith26 opened 3 years ago

asmith26 commented 3 years ago

Following the work completed in https://github.com/asmith26/jax_toolkit/pull/119, it could be useful to add `vmap` support to additional losses and metrics, so that they can be evaluated over a batch of examples.
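For illustration, the snippet below is a minimal sketch of the general pattern: a loss written for a single example is wrapped with `jax.vmap` so it returns one value per example in a batch. The function `mean_squared_error` here is hypothetical and is not taken from jax_toolkit or from PR #119; it only shows how `vmap` could be applied to a per-example loss or metric.

```python
import jax
import jax.numpy as jnp


# Hypothetical per-example loss (not jax_toolkit's API): mean squared error
# between a single target vector and a single prediction vector.
def mean_squared_error(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jnp.ndarray:
    return jnp.mean((y_true - y_pred) ** 2)


# jax.vmap maps the per-example function over the leading (batch) axis of
# both arguments, producing one loss value per example.
batched_mse = jax.vmap(mean_squared_error, in_axes=(0, 0))

# A batch of 3 examples, each a 2-element vector.
y_true = jnp.array([[1.0, 2.0], [0.0, 0.0], [3.0, 3.0]])
y_pred = jnp.array([[1.0, 2.0], [1.0, 1.0], [1.0, 3.0]])

per_example_loss = batched_mse(y_true, y_pred)  # shape (3,)
```

Setting an entry of `in_axes` to `None` instead of `0` would broadcast that argument unchanged across the batch, which may be useful for metrics that take a shared parameter alongside batched inputs.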