google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Support for CSR format sparse matrix in optimizer? #994

Closed: MoFHeka closed this issue 3 weeks ago

MoFHeka commented 3 months ago

Matrix format: https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.BCSR.html#jax.experimental.sparse.BCSR

A related feature in DeepSpeed: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/csrc/adam/fused_adam_frontend.cpp

fabianp commented 3 months ago

I believe that if the matrix format is a pytree (which I think it is), then things should work out of the box?

It would be great if you could check whether that is actually the case.

It would be even more awesome if you could contribute an example of using these sparse matrix formats 😉