google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Expression tree API like CVXPY #558

Open carlosgmartin opened 10 months ago

carlosgmartin commented 10 months ago

Does jaxopt have an expression tree API like CVXPY does?

If not, would it be possible to create one? It would make it much easier to build up optimization problems.

vroulet commented 10 months ago

Hello @carlosgmartin, wouldn't jaxprs be what you are looking for? JAX traces the whole computational graph, so you can work on it directly. I may have misunderstood what you are after, though; feel free to describe your objective in more detail.
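To illustrate what tracing gives you, here is a minimal sketch using `jax.make_jaxpr`, which returns the traced program (a `ClosedJaxpr`) whose input variables, output variables and equations can be inspected programmatically. The function `payoff_margin` is a hypothetical example, not part of jaxopt.

```python
import jax
import jax.numpy as jnp


def payoff_margin(x, u, v):
    # Row player's payoff against each column, minus a candidate game value v.
    return x @ u - v


u = jnp.array([[1.0, -1.0], [-1.0, 1.0]])
x = jnp.array([0.5, 0.5])

# Trace the function with concrete example inputs; nothing is solved here.
closed = jax.make_jaxpr(payoff_margin)(x, u, 0.0)
print(closed)

# Each equation records one primitive operation with its inputs and outputs.
prims = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(prims)
```

Note that a jaxpr is a low-level record of primitive operations; it has no notion of decision variables or constraints, so a modeling layer would still have to be built on top of it.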

carlosgmartin commented 10 months ago

@vroulet I meant something like the example shown in the second link (more examples here). Here's another example:

```python
import cvxpy as cp


def find_nash_equilibrium(u):
    """Find the Nash equilibrium of a two-player zero-sum normal-form game.
    u is the payoff matrix for the row player."""
    x = cp.Variable(u.shape[0])  # row player's mixed strategy
    v = cp.Variable()  # value of the game
    objective = cp.Maximize(v)
    constraints = [
        v <= x @ u,  # x guarantees at least v against every column
        x >= 0,
        x.sum() == 1,
    ]
    problem = cp.Problem(objective, constraints)
    problem.solve()
    # The dual values of the first constraint give the column player's strategy.
    return {"v": v.value, "x": x.value, "y": constraints[0].dual_value}
```

That is, letting users create variables and build up expressions from them to define the objective and constraints of a problem. This makes writing linear and quadratic programs much easier than manually assembling the A, b, G, h, Q, c arrays.
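For comparison, here is a minimal sketch of the manual alternative for the same game, building the arrays by hand in JAX. It assumes the QP convention minimize 0.5 z^T Q z + c^T z subject to A z = b, G z <= h, stacks the decision variables as z = [x, v], and uses a hypothetical helper name, `nash_lp_arrays`.

```python
import jax.numpy as jnp


def nash_lp_arrays(u):
    """Hand-build (Q, c, A, b, G, h) for the zero-sum Nash LP (hypothetical helper).

    Decision vector z = [x_1, ..., x_n, v], with the assumed convention
    minimize 0.5 z^T Q z + c^T z  subject to  A z = b,  G z <= h.
    """
    n, m = u.shape
    # Linear program: no quadratic term.
    Q = jnp.zeros((n + 1, n + 1))
    # Maximize v  <=>  minimize -v.
    c = jnp.zeros(n + 1).at[n].set(-1.0)
    # sum(x) == 1; v does not appear in this row.
    A = jnp.concatenate([jnp.ones((1, n)), jnp.zeros((1, 1))], axis=1)
    b = jnp.ones(1)
    # v <= (x @ u)_j for every column j  <=>  -u[:, j] . x + v <= 0.
    G_payoff = jnp.concatenate([-u.T, jnp.ones((m, 1))], axis=1)
    # x >= 0  <=>  -x <= 0.
    G_nonneg = jnp.concatenate([-jnp.eye(n), jnp.zeros((n, 1))], axis=1)
    G = jnp.concatenate([G_payoff, G_nonneg], axis=0)
    h = jnp.zeros(m + n)
    return Q, c, A, b, G, h


# Matching pennies: the uniform strategy with value 0 should be feasible.
u = jnp.array([[1.0, -1.0], [-1.0, 1.0]])
Q, c, A, b, G, h = nash_lp_arrays(u)
z = jnp.array([0.5, 0.5, 0.0])
print(A @ z, G @ z)
```

Every constraint has to be translated into rows by hand, and recovering the column player's strategy y means remembering that the first m rows of G encode `v <= x @ u`. That bookkeeping is what an expression layer would automate. The resulting arrays could then, in principle, be passed to one of jaxopt's QP solvers (for example `jaxopt.OSQP`); check the solver's documentation for the exact `run` arguments.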

vroulet commented 10 months ago

So no, we don't have that at the moment, and we weren't planning on adding it. Wouldn't a package like https://github.com/cvxgrp/cvxpylayers be a good starting point?