Closed: MoFHeka closed this 3 weeks ago
I believe that if the sparse matrix format is a pytree (which I think it is), then things should work out of the box.
It would be great if you could check whether things actually work out of the box.
It would be even more awesome if you can contribute an example on using these sparse matrix formats 😉
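For anyone who wants to run that check, here is a minimal sketch of how one might test the pytree claim for `jax.experimental.sparse.BCSR`. It only verifies that a BCSR matrix flattens, unflattens, and passes through `jax.jit`; it does not exercise any optimizer, and the matrix values are made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

# Illustrative 2x3 matrix; the values are arbitrary.
dense = jnp.array([[1.0, 0.0, 2.0],
                   [0.0, 0.0, 3.0]])
mat = sparse.BCSR.fromdense(dense)

# If BCSR is a registered pytree, it flattens into array leaves
# and can be rebuilt from them.
leaves, treedef = jax.tree_util.tree_flatten(mat)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

# A pytree can also be passed straight into a jitted function.
@jax.jit
def row_sums(m):
    return m.todense().sum(axis=1)

sums = row_sums(mat)

print("number of leaves:", len(leaves))
print("leaf dtypes:", [leaf.dtype for leaf in leaves])
print("round trip ok:", bool(jnp.array_equal(rebuilt.todense(), dense)))
```

One thing to look at in the output is the leaf dtypes: if the integer index arrays show up as leaves next to the data, code that applies the same arithmetic to every leaf (for example a plain `jax.tree_util.tree_map` update) would touch the indices as well, so "works out of the box" may depend on what is done per leaf.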
Matrix format: https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.BCSR.html#jax.experimental.sparse.BCSR
A related feature in DeepSpeed: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/csrc/adam/fused_adam_frontend.cpp