Closed by maedoc 6 months ago
Jax supports 16-bit floating point types, https://github.com/google/jax/discussions/6700, which would be worth a look for training e.g. neural ODEs.
There's also the recent BitNet work; in particular the 1.58-bit quantization looks interesting.
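As a quick way to try this out, here's a minimal sketch comparing a toy Euler ODE integration in `bfloat16` versus `float32` with JAX. The ODE `dx/dt = -x` and the `euler_integrate` helper are illustrative assumptions, not code from this project.

```python
import jax.numpy as jnp

def euler_integrate(x0, dt, n_steps, dtype=jnp.float32):
    # Integrate dx/dt = -x with forward Euler in the given dtype,
    # so precision loss from 16-bit arithmetic accumulates over steps.
    x = jnp.asarray(x0, dtype=dtype)
    dt = jnp.asarray(dt, dtype=dtype)
    for _ in range(n_steps):
        x = x + dt * (-x)
    return x

x16 = euler_integrate(1.0, 0.01, 100, dtype=jnp.bfloat16)
x32 = euler_integrate(1.0, 0.01, 100, dtype=jnp.float32)
drift = abs(float(x16) - float(x32))  # rough measure of 16-bit precision loss
```

Comparing the drift against a float32 baseline like this would give a quick sense of whether 16-bit precision is workable for a given ODE before committing to it for training.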
Doesn't work well for the test case I tried, let's set this aside for now.