ins-amu / vbjax

A nascent Jax-based package for virtual brain modeling.
Apache License 2.0

Try out lower precision types for neural networks #65

Closed · maedoc closed 6 months ago

maedoc commented 6 months ago

Jax supports 16-bit floating point types, https://github.com/google/jax/discussions/6700, which would be worth a look for training e.g. neural ODEs.
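
A minimal sketch (not from vbjax) of how the 16-bit types can be tried: cast a parameter pytree and the inputs to `bfloat16` and run the forward pass as usual. The `mlp` function and its parameters here are hypothetical placeholders.

```python
import jax
import jax.numpy as jnp

def mlp(params, x):
    # two-layer tanh network; computes in whatever dtype the arrays carry
    w1, b1, w2, b2 = params
    h = jnp.tanh(x @ w1 + b1)
    return h @ w2 + b2

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
params32 = (
    jax.random.normal(k1, (8, 16)), jnp.zeros(16),
    jax.random.normal(k2, (16, 8)), jnp.zeros(8),
)

# cast the whole pytree to bfloat16 (float16 works the same way)
params16 = jax.tree_util.tree_map(lambda p: p.astype(jnp.bfloat16), params32)
x16 = jnp.ones((4, 8), dtype=jnp.bfloat16)

y16 = jax.jit(mlp)(params16, x16)
print(y16.dtype)  # bfloat16
```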

There's also the recent BitNet work; in particular the 1.58-bit quantization looks interesting.
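
For reference, a hedged sketch of the BitNet b1.58-style "absmean" weight quantization (weights rounded to {-1, 0, +1} with a per-tensor scale), written as plain JAX; this illustrates the published recipe and is not something implemented in vbjax.

```python
import jax.numpy as jnp

def absmean_ternary(w, eps=1e-6):
    # scale by the mean absolute weight, then round and clip to {-1, 0, 1}
    scale = jnp.mean(jnp.abs(w)) + eps
    q = jnp.clip(jnp.round(w / scale), -1.0, 1.0)
    return q, scale

def ternary_matmul(x, w):
    # quantize on the fly and rescale the output; training through the
    # rounding step would additionally need a straight-through estimator
    q, scale = absmean_ternary(w)
    return (x @ q) * scale
```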

maedoc commented 6 months ago

It doesn't work well for the test case I tried; let's set this aside for now.