google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0
903 stars, 62 forks

Projected gradient for multidimensional array #594

Open rkruegs123 opened 1 month ago


Hi -- I have an array of shape m x n x n, where each of the m n x n matrices represents a single "frame." I want to enforce that every frame is a doubly stochastic matrix. For a single n x n matrix, I could use projected gradient with the projection onto the Birkhoff polytope (the set of doubly stochastic matrices). However, I would like to apply this constraint to every n x n matrix in the array.

Does anyone know how to best implement this?
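One possible approach is to write (or reuse) a projection for a single n x n matrix and lift it over the leading frame axis with `jax.vmap`. The constraints on different frames are independent, and the Euclidean projection onto a Cartesian product of sets is the per-component projection, so projecting the whole array is the same as projecting each frame separately. The sketch below is written under that assumption. To stay self-contained it does not call jaxopt. Instead it uses a hand-written, approximate Euclidean projection based on Dykstra's alternating projections between the affine set {row sums = 1, column sums = 1} and the nonnegative orthant, with an assumed fixed iteration budget (`num_iters`). The names `project_affine`, `project_birkhoff_dykstra`, `project_frames`, `projected_gradient_step`, and the squared-distance `loss` are hypothetical and for illustration only.

```python
import jax
import jax.numpy as jnp


def project_affine(x):
    """Orthogonal projection of one n x n matrix onto {X : X 1 = 1, X^T 1 = 1}."""
    n = x.shape[-1]
    row_gap = 1.0 - x.sum(axis=1)          # 1 - X 1
    col_gap = 1.0 - x.sum(axis=0)          # 1 - X^T 1
    # Compensates for applying both the row and the column correction to the total sum.
    total_excess = (x.sum() - n) / n**2
    return x + row_gap[:, None] / n + col_gap[None, :] / n + total_excess


def project_birkhoff_dykstra(x, num_iters=1000):
    """Approximate Euclidean projection of one n x n matrix onto the Birkhoff
    polytope, using Dykstra's algorithm on the affine set and the nonnegative
    orthant. num_iters is an assumed fixed budget, not a tuned value."""
    def body(_, state):
        cur, p, q = state
        y = project_affine(cur + p)        # step onto the row/column-sum constraints
        p = cur + p - y                    # Dykstra correction for the affine set
        nxt = jnp.maximum(y + q, 0.0)      # step onto the nonnegative orthant
        q = y + q - nxt                    # Dykstra correction for the orthant
        return nxt, p, q

    zeros = jnp.zeros_like(x)
    out, _, _ = jax.lax.fori_loop(0, num_iters, body, (x, zeros, zeros))
    return out


# Frames are independent, so the projection of the whole (m, n, n) array is
# the single-matrix projection mapped over the leading axis.
project_frames = jax.vmap(project_birkhoff_dykstra)


def loss(frames, target):
    # Hypothetical objective standing in for the real one.
    return jnp.sum((frames - target) ** 2)


@jax.jit
def projected_gradient_step(frames, target, step_size=0.1):
    # Plain gradient step on the full array, then project every frame.
    grads = jax.grad(loss)(frames, target)
    return project_frames(frames - step_size * grads)


m, n = 3, 4
target = jax.random.normal(jax.random.PRNGKey(0), (m, n, n))
frames = jnp.full((m, n, n), 1.0 / n)  # uniform frames are already doubly stochastic
for _ in range(50):
    frames = projected_gradient_step(frames, target)
```

jaxopt's `ProjectedGradient` takes the projection as a callable of the form `projection(params, hyperparams)`. Wrapping a vmapped single-matrix projection, such as the Birkhoff projection mentioned above, in a function with that signature should therefore let it replace the hand-written loop. Treat that as an untested suggestion rather than confirmed library behavior.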