lezcano / geotorch

Constrained optimization toolkit for PyTorch
https://geotorch.readthedocs.io
MIT License
657 stars 34 forks

can we expect geojax? #29

renjithravindran closed this issue 2 years ago

renjithravindran commented 2 years ago

Hello there, by any chance are you looking to port geotorch to JAX? I am trying to do Tucker factorization of large sparse tensors and wanted a pluggable orthogonality constraint.

Thanks, renjith

lezcano commented 2 years ago

Alas, I don't have the time to do that. That said, if anyone else were to take on that project, I'd be very happy to answer questions about geotorch.

renjithravindran commented 2 years ago

The math behind your work is above me; maybe I can look at the torch implementation of orthogonality and try to do the same with JAX. What are the advantages of using your technique over, say, adding an orthogonality regulariser (X'X = I) to the loss function?

lezcano commented 2 years ago

With a regulariser you don't get an exactly orthogonal matrix; with this approach you do.

For the code, just have a look at the SO class.
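
A minimal JAX sketch of that difference, under stated assumptions: the toy `task_loss`, the step size, and the penalty weight `lam` are all made up for illustration, and the QR factorization is used as the map only because it is short. It is not claimed to be the map that the SO class uses.

```python
import jax
import jax.numpy as jnp

def orth_error(X):
    """Frobenius norm of X^T X - I, i.e. how far X is from orthonormal columns."""
    return jnp.linalg.norm(X.T @ X - jnp.eye(X.shape[1]))

def task_loss(X, target):
    # Hypothetical objective: pull X towards an arbitrary target matrix.
    return jnp.sum((X - target) ** 2)

def penalised_loss(X, target, lam=1.0):
    # Regulariser approach: X itself is the parameter; the penalty only
    # discourages deviation from X^T X = I, it does not forbid it.
    penalty = jnp.sum((X.T @ X - jnp.eye(X.shape[1])) ** 2)
    return task_loss(X, target) + lam * penalty

def parametrised_loss(A, target):
    # Parametrization approach: A is unconstrained and the matrix that enters
    # the loss is the Q factor of A (assumption: QR stands in for the real map).
    Q, _ = jnp.linalg.qr(A)
    return task_loss(Q, target)

key_x, key_t = jax.random.split(jax.random.PRNGKey(0))
target = jax.random.normal(key_t, (5, 3))
X = 0.5 * jax.random.normal(key_x, (5, 3))  # parameter of the penalised problem
A = X                                       # parameter of the parametrised problem

grad_pen = jax.jit(jax.grad(penalised_loss))
grad_par = jax.jit(jax.grad(parametrised_loss))

for _ in range(200):
    X = X - 0.01 * grad_pen(X, target)
    A = A - 0.01 * grad_par(A, target)

Q, _ = jnp.linalg.qr(A)
err_pen = float(orth_error(X))  # generally stays visibly away from zero
err_par = float(orth_error(Q))  # zero up to floating-point round-off
print(err_pen, err_par)
```

The penalised matrix settles at a compromise between the task and the penalty, while the parametrised matrix has orthonormal columns at every step regardless of the task.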

renjithravindran commented 2 years ago

I see that SO is for square orthogonal matrices. Should orthogonal matrices always be square? Factor matrices from an SVD are orthogonal, but with a truncated SVD they are not square. Does that mean the factor matrices of a truncated SVD are not orthogonal?

lezcano commented 2 years ago

I think you are referring to matrices with orthonormal columns. For that, have a look at the Stiefel class.
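
The distinction can be checked numerically. The sketch below (matrix sizes and the truncation rank `k` are arbitrary choices for illustration) shows that a truncated SVD factor `U_k` satisfies `U_k.T @ U_k = I`, so its columns are orthonormal, while `U_k @ U_k.T` is only a rank-`k` projector and not the identity.

```python
import jax
import jax.numpy as jnp

M = jax.random.normal(jax.random.PRNGKey(0), (6, 4))
U, s, Vt = jnp.linalg.svd(M, full_matrices=False)

k = 2                 # truncation rank, chosen arbitrarily
U_k = U[:, :k]        # 6 x 2 factor: rectangular, not square

gram = U_k.T @ U_k    # 2 x 2: identity, because the columns are orthonormal
proj = U_k @ U_k.T    # 6 x 6: orthogonal projector onto span(U_k), rank k

cols_orthonormal = bool(jnp.allclose(gram, jnp.eye(k), atol=1e-5))
is_square_orthogonal = bool(jnp.allclose(proj, jnp.eye(6), atol=1e-3))
print(cols_orthonormal, is_square_orthogonal)
```

Matrices of this kind (tall, with orthonormal columns) are the ones the Stiefel manifold describes.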

renjithravindran commented 2 years ago

Yes, rectangular orthogonality. The code looks simple enough for a blind port; however, I have no clue where it fits into the optimization process. Please give me an idea of how I might add this to a general optimization process if I were to blind-port the Stiefel class. Pointers to your paper would also be welcome if it might give me some clue.

Thanks a lot! renjith

lezcano commented 2 years ago

GeoTorch is nothing but a set of functions that map unconstrained rectangular matrices into matrices with orthonormal columns, or whatever other constraint you want. To find these maps, you can have a look at the cited papers or try to find them yourself in the literature.

Then, the constraint is maintained because you update the unconstrained matrix by differentiating through the map (autograd does this). Have a look at the Symmetric or Skew classes for the simplest examples.
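
As a rough illustration of "differentiating through the map" in JAX, the sketch below optimises an unconstrained matrix `A` while the loss only ever sees `to_manifold(A)`. The maps `symmetric` and `skew`, the toy loss, and the step size are assumptions written for this example; they are not copied from the Symmetric or Skew classes.

```python
import jax
import jax.numpy as jnp

def symmetric(A):
    """Map any square matrix to a symmetric one (assumed map, for illustration)."""
    return 0.5 * (A + A.T)

def skew(A):
    """Map any square matrix to a skew-symmetric one (assumed map, for illustration)."""
    return 0.5 * (A - A.T)

def loss(A, target, to_manifold):
    # The loss is written in terms of the constrained matrix X, but the
    # parameter being optimised is the unconstrained matrix A.
    X = to_manifold(A)
    return jnp.sum((X - target) ** 2)

# jax.grad differentiates w.r.t. the first argument, A, through to_manifold.
grad_fn = jax.grad(loss)

key_a, key_t = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key_a, (4, 4))
target = jax.random.normal(key_t, (4, 4))

for _ in range(100):
    # Plain gradient descent on A; any optimiser could be used here.
    A = A - 0.1 * grad_fn(A, target, symmetric)

X = symmetric(A)  # symmetric by construction, whatever value A has
print(bool(jnp.allclose(X, X.T)))
```

Swapping `symmetric` for `skew`, or for a map onto matrices with orthonormal columns, changes the constraint without touching the optimisation loop, which is why the optimiser never needs to know about the constraint.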

renjithravindran commented 2 years ago

Great! As of now it is not certain that I need an orthogonality constraint for my work. If necessary, I will try my luck at porting.

Thanks a lot, renjith