asmith26/jax_toolkit

jax_toolkit

A collection of JAX functions to help with common machine learning and deep learning functionality.

Documentation, PyPI

This library currently contains the basics for a number of losses and metrics. We intend to add more complexity and functionality as it is needed. Contributions, pull requests and bug reports are very welcome if you discover a problem or need something that is currently missing.
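The snippet below shows the kind of loss function this covers. It is a minimal sketch written directly in JAX, not jax_toolkit's own API: the name `binary_log_loss` and the `eps` clipping parameter are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp


def binary_log_loss(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy (log loss).

    Hypothetical illustration only; not part of jax_toolkit's API.
    """
    # Clip predictions away from exactly 0 and 1 so jnp.log never receives 0.
    y_pred = jnp.clip(y_pred, eps, 1.0 - eps)
    per_sample = -(y_true * jnp.log(y_pred) + (1.0 - y_true) * jnp.log(1.0 - y_pred))
    # Average the per-sample losses into a single scalar.
    return jnp.mean(per_sample)


y_true = jnp.array([1.0, 0.0, 1.0])
y_pred = jnp.array([0.9, 0.1, 0.8])

# Because the loss is a pure jax.numpy function, it can be jit-compiled and
# differentiated with respect to the predictions.
loss = jax.jit(binary_log_loss)(y_true, y_pred)
grads = jax.grad(binary_log_loss, argnums=1)(y_true, y_pred)
print(loss, grads)
```

Writing losses as pure `jax.numpy` functions like this is what lets them compose with `jax.jit` and `jax.grad` inside a training loop.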

Installation

pip install jax_toolkit

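Or, to also install the optional loss function utilities (quote the argument so shells such as zsh do not interpret the square brackets):

pip install "jax_toolkit[losses_utils]"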