pymc-devs / pytensor

PyTensor allows you to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays.
https://pytensor.readthedocs.io

ENH: Implement COLA library rewrites for linear algebra functions #573

Open jessegrabowski opened 6 months ago

jessegrabowski commented 6 months ago

Description

This paper and this library describe and implement a number of linear algebra simplifications that we can implement as graph rewrites. This issue is both a tracker for implementing these rewrites and a discussion of how to handle them.

Consider the following graph:

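import pytensor.tensor as pt

x = pt.dvector('x')
y = pt.diag(x)  # y is a diagonal matrix with x on its diagonal
z = pt.linalg.inv(y)

If we could promise at rewrite time that y is diagonal, we could rewrite the last operation as z = pt.diag(1 / x), exploiting the structure of the diagonal matrix: its inverse just holds the reciprocals of the diagonal entries. Other non-trivial examples exist, for example: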

x = pt.dmatrix('x')
y = pt.eye(3)
z = pt.kron(x, y)
z_inv = pt.linalg.inv(z)

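If we could promise at rewrite time that z is a Kronecker product with an identity factor, we could use the identity inv(A ⊗ B) = inv(A) ⊗ inv(B) (with inv(I) = I) to rewrite z_inv = pt.kron(pt.linalg.inv(x), y). This is much faster, since we only invert x, whose dimensions are a third of z's (roughly 27 times less work for an O(n^3) dense inverse).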
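Both identities used above can be sanity-checked numerically. This is a minimal sketch that uses NumPy, not PyTensor, purely to confirm the math. The matrix sizes and the random seed are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Diagonal case: inv(diag(v)) == diag(1 / v), assuming every entry of v is nonzero
v = rng.uniform(1.0, 2.0, size=3)
diag_inv_dense = np.linalg.inv(np.diag(v))  # general-purpose O(n^3) inverse
diag_inv_cheap = np.diag(1.0 / v)           # O(n) elementwise reciprocal

# Kronecker case: inv(kron(x, I)) == kron(inv(x), I), because inv(A ⊗ B) = inv(A) ⊗ inv(B)
n = 4
x = rng.normal(size=(n, n)) + n * np.eye(n)  # shifting the diagonal makes x very likely invertible
y = np.eye(3)
z = np.kron(x, y)                            # shape (3n, 3n)
z_inv_dense = np.linalg.inv(z)               # inverts the full 12 x 12 matrix
z_inv_kron = np.kron(np.linalg.inv(x), y)    # inverts only the 4 x 4 factor

print(np.allclose(diag_inv_dense, diag_inv_cheap), np.allclose(z_inv_dense, z_inv_kron))
```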

The linked paper and library list a huge number of such shortcut computations. The following is a list version of Table 1 from the paper. Under each function are the types of matrices for which a rewrite rule exists to accelerate the function. ~It would be nice to collaboratively update this list with links to the COLA library where the relevant rewrite is implemented:~

Thanks to @tanish1729 for compiling a list of links to relevant rewrites. Missing links indicate that COLA has no direct implementation.

In addition, potential rewrites could also be applied to Toeplitz and circulant matrices, although these are not covered by COLA.

The wrinkle in all this is that we would need more information about matrices as they enter and exit Ops. Right now, we're using tags to accomplish rewrites like this; see for example #303 and #459. Some of these rewrites might be possible via inference. For example, a pt.linalg.block_diag Op always returns a block diagonal matrix, as does pt.kron(pt.eye(n), A). pt.diag applied to a vector always returns a diagonal matrix, as does pt.eye; pt.linalg.cholesky always returns a triangular matrix, etc. Other potential structures, like block, positive definite, PSD, Toeplitz, circulant, etc., would be less trivial to detect automatically.

The other issue is that as these type tags proliferate, we become more and more locked into a somewhat hacky system for marking things. Perhaps putting some thought into how to handle this now will save some refactoring headaches down the road?
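Inference for Ops with a guaranteed output structure could look roughly like the toy below. This is a hypothetical sketch. `Node`, `ALWAYS_STRUCTURED`, and `infer_structure` are made-up names that operate on a tiny stand-in graph, not on PyTensor's `Apply`/`Op` API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for a graph node: an op name plus the nodes feeding into it."""
    op: str
    inputs: tuple = ()


# Ops whose output structure does not depend on their inputs (names mirror the examples above)
ALWAYS_STRUCTURED = {
    "block_diag": "block_diagonal",
    "diag": "diagonal",  # assumes diag is applied to a vector, i.e. it builds a matrix
    "eye": "diagonal",
    "cholesky": "triangular",
}


def infer_structure(node):
    """Return a structure label for the node's output, or None when nothing can be inferred."""
    if node.op in ALWAYS_STRUCTURED:
        return ALWAYS_STRUCTURED[node.op]
    # kron(eye, A) places copies of A along the diagonal, so it is block diagonal
    if node.op == "kron" and len(node.inputs) == 2 and node.inputs[0].op == "eye":
        return "block_diagonal"
    # Anything else, including kron(A, eye) (only a permuted block-diagonal matrix), is not detected
    return None


A = Node("matrix")
print(infer_structure(Node("cholesky", (A,))))          # triangular
print(infer_structure(Node("kron", (Node("eye"), A))))  # block_diagonal
print(infer_structure(Node("kron", (A, Node("eye")))))  # None
```

A lookup table keyed on the producing Op covers the "always structured" cases cheaply. Patterns like kron(eye, A) need a look at the inputs as well.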

ricardoV94 commented 6 months ago

I think the easiest short-term path is to add a feature (like the shape feature) that performs and keeps track of this more specialized type inference.

For instance, when you replace a variable with another, you can assume that what you inferred for the old variable holds for the new one.

When a new variable is added, you can compute new facts. For example, if x is positive (or diagonal), x * 5 is also positive (or diagonal). If it's positive, it's not negative, and so on.

Some type info needs to be seeded by the user initially, as it won't emerge naturally from realistic graphs. Right now we use tags for that, but maybe we could add a dummy pass-through Op that is ultimately discarded but feeds initial type info. The user could write x = pt.hint(x, "positive") or pt.hint(x, x > 0).

Or we could introspect asserts like pt.assert(x, x > 0), but we probably don't want to keep most of these asserts in the final graph.
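A toy version of this feature-plus-hint idea is sketched below. Everything here is hypothetical: `AssumptionsFeature`, its methods, and the string property names are illustrative and are not an existing PyTensor API. The sketch only shows the three mechanics described above: seeding facts with a hint, deriving facts for new variables, and carrying facts over on replacement.

```python
# Hypothetical sketch; not PyTensor's actual Feature interface.
IMPLIES = {"positive": {"nonnegative", "nonzero"}}  # positive => not negative and not zero


def close(props):
    """Add every property implied by the given ones (one level of implication suffices here)."""
    out = set(props)
    for p in props:
        out |= IMPLIES.get(p, set())
    return out


class AssumptionsFeature:
    def __init__(self):
        self.facts = {}  # variable name -> set of known properties

    def hint(self, var, *props):
        # User-seeded facts, in the spirit of the proposed pt.hint(x, "positive")
        self.facts[var] = close(self.facts.get(var, set()) | set(props))
        return var

    def on_mul_const(self, out, var, const):
        # Derive facts for the new variable out = var * const
        known = self.facts.get(var, set())
        derived = set()
        if "diagonal" in known:
            derived.add("diagonal")  # scaling never fills in off-diagonal zeros
        if "positive" in known and const > 0:
            derived.add("positive")
        self.facts[out] = close(derived)

    def on_replace(self, old, new):
        # Whatever was inferred for the old variable also holds for its replacement
        self.facts[new] = close(self.facts.get(new, set()) | self.facts.get(old, set()))


feature = AssumptionsFeature()
feature.hint("x", "positive")
feature.hint("D", "diagonal")
feature.on_mul_const("y", "x", 5)     # y = x * 5
feature.on_mul_const("E", "D", -2.0)  # E = D * -2
feature.on_replace("y", "y_rewritten")
print(sorted(feature.facts["y_rewritten"]))  # ['nonnegative', 'nonzero', 'positive']
```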

jessegrabowski commented 6 months ago

This is how SymPy handles it, with its assumptions system. For example:

import sympy as sp
x = sp.Symbol('x', positive=True)
x._assumptions0

This returns a dict of everything implied by x being positive:

{'positive': True,
 'extended_nonnegative': True,
 'commutative': True,
 'nonpositive': False,
 'extended_positive': True,
 'nonnegative': True,
 'nonzero': True,
 'imaginary': False,
 'infinite': False,
 'extended_nonpositive': False,
 'real': True,
 'extended_negative': False,
 'complex': True,
 'extended_nonzero': True,
 'zero': False,
 'negative': False,
 'hermitian': True,
 'finite': True,
 'extended_real': True}

All of these have implications for algebraic simplification. For example, it will refuse to simplify sp.sqrt(x ** 2) to x if it doesn't know that x is nonnegative.
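A small check of that behavior follows, assuming a reasonably recent SymPy. It uses `sp.sqrt(x**2)` because simplifying it to `x` genuinely depends on sign assumptions: it is false for x = -1.

```python
import sympy as sp

x = sp.Symbol('x')                  # no assumptions: x may be any complex number
x_real = sp.Symbol('x', real=True)
x_pos = sp.Symbol('x', positive=True)

print(sp.sqrt(x**2))       # sqrt(x**2): left alone, since it is not x for e.g. x = -1
print(sp.sqrt(x_real**2))  # Abs(x): knowing x is real is enough to get |x|
print(sp.sqrt(x_pos**2))   # x: positivity licenses the full simplification
print(x_pos.is_negative, x_pos.is_nonnegative)  # False True, both implied by positive=True
```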

It's actually a nice system, but I read that they were trying to revamp it, so I'm curious what they think its big failures are. It might be nice to try to set up a call with their devs to talk about their lessons from going this route before we commit to it.

I also think something like this is really where egg (e-graph) rewrites can shine. Especially if we're just dealing with matrices, we shouldn't have to reason about shapes too much (although Blockwise somewhat ruins that...)

ricardoV94 commented 6 months ago

It might be nice to try to set up a call with their devs to talk about their lessons from going this route before we commit to it.

Definitely

tanish1729 commented 3 months ago

hey @jessegrabowski! are there any more related PRs that I could try working on for this project?

otherwise, from what I understand, this is mainly focused on graph rewriting, and we want to use the computations suggested in the paper in PyTensor?

jessegrabowski commented 3 months ago

Yeah basically we want to introduce simplifications related to compositions of linear algebra operations via rewrites. The COLA paper suggests a big pile, so that's what I suggested as a "roadmap".

The only active one right now is #622; you can look here for a rough template of what a rewrite looks like. Maybe a "simple" rewrite to get your feet wet would be to change det(diag(x)) to prod(x). The rewrite should look for a det Op, then check whether its parent is a diag Op. If it finds one, it should return the product of the input to diag.

Obviously there are other ways a diagonal matrix could arise (for example, eye * x), but that would be a good place to get the ball rolling.
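The pattern match described above can be sketched on a toy expression tree. This is hypothetical and deliberately not PyTensor code. Expressions are nested tuples `(op_name, *inputs)`, and `rewrite_det_diag` and `evaluate` are made-up helpers. A real PyTensor rewrite would instead inspect `node.op` and the `owner` of its input, which is the node that produced it.

```python
import numpy as np

# Hypothetical toy IR: a leaf is a variable name (str); an inner node is a tuple (op_name, *inputs)


def rewrite_det_diag(expr):
    """Recursively replace every det(diag(v)) subexpression with prod(v)."""
    if not isinstance(expr, tuple):
        return expr  # leaf: nothing to rewrite
    op, *inputs = expr
    inputs = [rewrite_det_diag(i) for i in inputs]  # rewrite the children first
    # Look for a det node, then check whether its parent (its input) is a diag node
    if op == "det" and isinstance(inputs[0], tuple) and inputs[0][0] == "diag":
        return ("prod", inputs[0][1])  # det of a diagonal matrix = product of its diagonal
    return (op, *inputs)


def evaluate(expr, env):
    """Evaluate a toy expression with NumPy so the rewrite can be checked numerically."""
    if isinstance(expr, str):
        return env[expr]
    op, *inputs = expr
    funcs = {"det": np.linalg.det, "diag": np.diag, "prod": np.prod, "log": np.log}
    return funcs[op](*(evaluate(i, env) for i in inputs))


graph = ("log", ("det", ("diag", "x")))
rewritten = rewrite_det_diag(graph)
env = {"x": np.array([2.0, 3.0, 4.0])}
print(rewritten)  # ('log', ('prod', 'x'))
print(np.isclose(evaluate(graph, env), evaluate(rewritten, env)))
```

Rewriting children before the parent means a det(diag(...)) buried anywhere in the tree is found in one pass. Other routes to a diagonal matrix, such as eye * x, would need their own pattern.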

tanish1729 commented 3 months ago

alright, I understand. apart from the docs here, are there any other places to read about graph rewriting? I am looking for something that explains the basic relations being used in the code.

tanish1729 commented 3 months ago

hey @jessegrabowski, I went through some of the rewrites, and this one here from the docs is way more straightforward than the one you shared. I can make out that essentially we are just iterating over the graph, checking the operand and operator at each step, and then making the changes, but some of the syntax used in the graph seems confusing to me.

jessegrabowski commented 3 months ago

Maybe this video by @ricardoV94 would be nice to check out. He goes through making a rewrite live and explains a lot about how PyTensor works:

https://www.youtube.com/watch?v=SqXUely5FpQ

ricardoV94 commented 3 months ago

This may also be useful: https://gist.github.com/ricardoV94/deaf7b18660588faac1c30bc5b31c011

tanish1729 commented 3 months ago

these were both really helpful, thanks a lot! I also went through the docs on graph structures and feel much more confident in understanding the task now. do you have any suggestions for breaking the project down into smaller ideas for writing my proposal, or any tips in general?

jessegrabowski commented 3 months ago

If you search the Discourse, you'll find other applications; that should give you a sense of what the structure is meant to be. I'm not sure what resources you get on the Google side.

For breaking down the project, I'd start by looking at Table 1 in the COLA paper and trying to classify the optimizations into "easy", "medium", and "hard", thinking about:

  1. How "tricky" is the rewrite itself? For example, det(diag(x)) -> prod(x) is conceptually quite simple, but converting a matrix inverse into Woodbury form is more complex. Eigenproblems are more complex than block-diagonal problems, and so on.
  2. How difficult is it to detect that a certain rewrite is going to be valid? For example, a number of rewrites apply to triangular matrices; how can these arise? Suppose we bump into inv(x) and want to know if we can apply a triangular rewrite. We know that the output of cholesky(x) is always triangular, so we can look for that. Are there other ways it can arise? The same goes for diag: we could have diag(x), but also eye(x.shape[0]) * x, and some psycho user could come up with more exotic things.
  3. Is the rewrite case even valid in PyTensor? Our support for eigenproblems is spotty: we have Eig, Eigh, and SVD, but only Eigh has gradients. We also don't have robust support for convolutions: it's quite out of date, and there's no JAX/Numba linker for it. These are things to note.
  4. How useful is a given rewrite in the context of other rewrites? The COLA paper notes that some rewrites are compositions of others, so having those "atomic" rewrites is quite valuable.

... and so on.
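To give a sense of what the "more complex" Woodbury rewrite computes, here is a NumPy sketch. It is illustrative only: the sizes, the seed, and the choice of a diagonal A are assumptions. It checks the identity (A + U C V)^-1 = A^-1 - A^-1 U (C^-1 + V A^-1 U)^-1 V A^-1. Once A^-1 is cheap, only a k x k matrix needs a general inverse.

```python
import numpy as np

rng = np.random.default_rng(2)
n, k = 6, 2
a = rng.uniform(1.0, 2.0, size=n)  # A = diag(a): positive definite and trivially invertible
U = rng.normal(size=(n, k))        # tall factor of the low-rank update
C = np.eye(k)
V = U.T                            # gives the symmetric update A + U U^T

A_inv = np.diag(1.0 / a)                        # O(n) inverse of the diagonal part
capacitance = np.linalg.inv(C) + V @ A_inv @ U  # the only k x k matrix needing a general inverse
woodbury = A_inv - A_inv @ U @ np.linalg.inv(capacitance) @ V @ A_inv

direct = np.linalg.inv(np.diag(a) + U @ C @ V)  # dense n x n inverse, for comparison
print(np.allclose(woodbury, direct))
```

Detecting when this pays off (a cheap-to-invert A plus a low-rank update with k much smaller than n) is exactly the kind of validity question raised in point 2.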

Based on that analysis, you can form a rough roadmap of what you think will be realistic to tackle over the course of your GSoC. It will also give you some nice material for a discussion/introduction session and highlight areas you might need to learn more about (for example, different kinds of matrices/decompositions/solvers/etc. that you may never have seen before).

tanish1729 commented 3 months ago

great! I have started working on a proposal; hopefully I can share a first draft by tomorrow so I can revise it before the final deadline.

in this issue, there is a comment about how assumptions are handled in SymPy. Since this project relies heavily on assumptions about the nature of matrices and operations, is there an already existing system (from what I can tell, this has been done using tags), or should I also propose an approach for dealing with assumptions (maybe something similar to SymPy's)?

another thing I wanted to confirm is the duration of the project. the site states that all these projects are 350 hours long, so I am going with that unless a change is required?

jessegrabowski commented 3 months ago

I don't think we should charge headfirst into a big assumptions system, we should just see how far we can get with pure graph introspection. I think there's already a lot we can do with this simpler approach.

Regarding time, I guess I have no idea. It's up to you how big of a commitment you want to make, but certainly this is going to be a big effort.

tanish1729 commented 3 months ago

another quick question: is the list of shortcuts exhaustive? Table 1 in the paper has a lot more operations; have they been excluded on purpose because they are not that relevant here? if not, I will also make a new list the same way with more operations and types of matrices, along with links to the CoLA rewrites.

tanish1729 commented 3 months ago

hello @jessegrabowski @ricardoV94! should I share my proposal on the Discourse for feedback, or just privately with one of the mentors?

jessegrabowski commented 3 months ago

I think discourse makes the most sense.

Sorry I didn't see your other message last night. I think being really complete would be nice, and we can pare it down from there if we agree it's too repetitive/messy.

tanish1729 commented 3 months ago

cool cool. I'm gonna share it there then!

tanish1729 commented 3 months ago

hey! did you guys get time to take a look at my proposal? I have exams at my college starting the day after tomorrow, so I'd really appreciate it if I could make any updates today.

On Fri, 29 Mar 2024, 20:08, Jesse Grabowski wrote:

I think discourse makes the most sense.

Sorry I didn't see your other message last night. I think being really complete would be nice, and we can pare it down from there if we agree it's too repetitive/messy.


ricardoV94 commented 3 months ago

hey! did you guys get time to take a look at my proposal? I have exams at my college starting the day after tomorrow, so I'd really appreciate it if I could make any updates today.


Looks great. Good luck with your exams!

tanish1729 commented 3 months ago

thanks! is there anything you guys found lacking, or should I submit this as-is?

ricardoV94 commented 3 months ago

thanks! is there anything you guys found lacking, or should I submit this as-is?

We think it's good as is

tanish1729 commented 3 months ago

hey @jessegrabowski, I was trying to make the list along with links to the corresponding CoLA rewrites, but it seems like not all the ones listed in the paper are directly there in the code? the eigenvalues file, for example, only has implementations for triangular, diagonal, and identity matrices, despite the paper listing a lot more. this is the case for other functions too. can you help me out with that? I can still update my proposal with this since there's still time left.

jessegrabowski commented 3 months ago

I don't think COLA implements everything. If you can't find it, just make a note that it doesn't seem to be implemented and move on. You need to hurry up and submit!

tanish1729 commented 3 months ago

I have made a list here. Should this be a new PR that I can reference in my proposal, or should I just include the list itself?

jessegrabowski commented 3 months ago

Include the list in your proposal and I'll edit this post to include the links (crediting you in the first post of course)