jump-dev / MathOptInterface.jl


Debug performance issue in Nonlinear submodule #2496

Closed: odow closed this issue 2 months ago

odow commented 2 months ago

See this discourse question:

https://discourse.julialang.org/t/efficient-dual-problems-in-jump-or-avoiding-symbolic-overheads-with-complex-objectives/113953/6?u=odow

9 seconds to compute 12 iterations of a problem with 121 variables doesn't make sense. There must be something going on.

cc @SebKrantz

odow commented 2 months ago

x-ref https://github.com/SebKrantz/OptimalTransportNetworks.jl/blob/main/src/models/model_fixed_duality.jl#L24-L81

Setup

] add https://github.com/SebKrantz/OptimalTransportNetworks.jl JuMP

Code

using OptimalTransportNetworks
function setup()
    param = init_parameters(;
        labor_mobility = false,
        K = 10,
        gamma = 1,
        beta = 1,
        verbose = true,
        N = 1,
        tol = 1e-5,
        cross_good_congestion = false,
        nu = 1,
    ) 
    param, graph = create_graph(param, 11, 11; type = "map")
    param[:Zjn] = fill(0.1, param[:J], param[:N])
    Ni = find_node(graph, 6, 6)
    param[:Zjn][Ni, :] .= 1
    param[:duality] = true
    model, _ = optimal_network(param, graph; verbose = true, return_model = true)
    return model
end

model = setup()
using JuMP
@time optimize!(model)
using ProfileView
@profview optimize!(model)

Result: [@profview flame graph screenshot]

There is no obvious bottleneck. This is just doing a lot of computation to compute the Hessians.

One reason is that there are 3.5 million nodes in the objective expression!

julia> unsafe_backend(model).nlp_data.evaluator.backend.objective.nodes
3569743-element Vector{MathOptInterface.Nonlinear.Node}:
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 2, -1)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 1, 1)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 1, 2)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 3, 2)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 2, 4)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 5, 4)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 4, 6)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 3, 7)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 4, 8)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 5, 9)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 3, 10)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 3, 11)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 4, 11)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 4, 11)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 5, 14)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 4, 15)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 4, 16)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VARIABLE, 121, 17)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 5, 17)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 6, 16)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 7, 15)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 8, 14)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 9, 10)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 10, 9)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 11, 8)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 12, 7)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 13, 6)
 ⋮
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 5, 3569717)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VARIABLE, 121, 3569718)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VARIABLE, 120, 3569718)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 1072898, 3569717)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 1072899, 3569711)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 1072900, 3569710)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 1072901, 3569709)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 3, 3569706)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 1072902, 3569725)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 4, 3569725)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 4, 3569727)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 9, 3569728)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 3, 3569729)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_UNIVARIATE, 1, 3569730)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 3, 3569731)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 1072903, 3569732)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_PARAMETER, 420, 3569732)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 2, 3569730)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_CALL_MULTIVARIATE, 5, 3569735)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VARIABLE, 120, 3569736)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VARIABLE, 121, 3569736)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 1072904, 3569735)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 1072905, 3569729)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 1072906, 3569728)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_VALUE, 1072907, 3569727)
 MathOptInterface.Nonlinear.Node(MathOptInterface.Nonlinear.NODE_PARAMETER, 420, 3569705)

@SebKrantz, try introducing some intermediate variables to make the problem larger but sparser:

Instead of:

    # Calculate consumption cj
    @expression(model, cj[j=1:graph.J],
        alpha * (sum(Pjn[j, n]^(1-sigma) for n=1:param.N)^(1/(1-sigma)) / omegaj[j])^(-1/(1+alpha*(rho-1))) * hj1malpha[j]^(-((rho-1)/(1+alpha*(rho-1))))
    )

try

    # Calculate consumption cj
    @variable(model, cj[j=1:graph.J])
    @constraint(model, [j=1:graph.J], cj[j] ==
        alpha * (sum(Pjn[j, n]^(1-sigma) for n=1:param.N)^(1/(1-sigma)) / omegaj[j])^(-1/(1+alpha*(rho-1))) * hj1malpha[j]^(-((rho-1)/(1+alpha*(rho-1))))
    )
odow commented 2 months ago

This is really the same problem as https://github.com/jump-dev/MathOptInterface.jl/issues/2488.

We are not detecting common subexpressions.
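
As a toy illustration (not the model above) of what this means in practice: the quadratic expression e below is inlined in full at every place it appears in the nonlinear objective, so the expression graph grows with every use instead of referencing a shared subexpression.

    using JuMP
    model = Model()
    @variable(model, x[1:3])
    @expression(model, e, sum(x[i]^2 for i in 1:3))
    # e is used three times, so three complete copies of the sum end up in
    # the objective's expression graph; it is not cached as a subexpression.
    @objective(model, Min, sin(e) + cos(e) + e^2)
    print(objective_function(model))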

SebKrantz commented 2 months ago

> @SebKrantz, try introducing some intermediate variables to make the problem larger but sparser: instead of the @expression(model, cj[j=1:graph.J], ...) definition of consumption, declare @variable(model, cj[j=1:graph.J]) and add the defining @constraint. [...]

Thanks @odow! However, there already exists a primal version of the problem, which is larger but also a lot faster (0.28 seconds instead of 9). My reason for employing the dual solution is to have as few optimisation variables as possible, so that really large networks (with thousands of nodes and edges) can be solved efficiently. I am thus wondering: if the symbolics cannot be fixed to efficiently process the sequence of expressions used to create the complex objective, could there be a way in JuMP to supply an ordinary (black-box) objective function? It should still be possible to detect sparsity numerically.

odow commented 2 months ago

The issue is not the number of variables, but how long it takes JuMP to compute derivatives.

The AD engine in JuMP is really designed for lots of simpler sparse expressions, rather than one very complicated expression. In addition, at the moment we don't exploit common subexpressions, so if you have cj[j] in multiple places, we copy-paste the entire expression instead of caching it somewhere.

You can add black-box functions to JuMP: https://jump.dev/JuMP.jl/stable/manual/nonlinear/#jump_user_defined_operators

But you need to provide derivatives, or we will use ForwardDiff (which won't scale to high-dimensional inputs).
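
For reference, here is a minimal, self-contained sketch (a toy problem, not the transport model) of a user-defined operator with a hand-written gradient, so that JuMP does not fall back to ForwardDiff; the names f, grad_f, and op_f are illustrative only:

    using JuMP, Ipopt

    # Toy objective and its gradient. For multivariate operators, JuMP
    # expects an in-place gradient with signature grad_f(g, x...).
    f(x...) = sum(xi^2 for xi in x)
    function grad_f(g::AbstractVector, x...)
        for i in eachindex(x)
            g[i] = 2 * x[i]
        end
        return
    end

    model = Model(Ipopt.Optimizer)
    @variable(model, x[1:3] >= 1)
    @operator(model, op_f, 3, f, grad_f)  # gradient supplied, no ForwardDiff
    @objective(model, Min, op_f(x...))
    optimize!(model)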

SebKrantz commented 2 months ago

OK, thanks! However, it does not work in this case when making cj a constraint:

EXIT: Invalid number in NLP function or derivative detected.

ERROR: Solver returned with error code INVALID_MODEL.
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] optimal_network(param::Dict{Any, Any}, graph::Dict{Symbol, Any}; I0::Nothing, Il::Nothing, Iu::Nothing, verbose::Bool, return_model::Int64)
   @ OptimalTransportNetworks ~/Documents/Julia/OptimalTransportNetworks.jl/src/main/optimal_network.jl:167
 [3] optimal_network(param::Dict{Any, Any}, graph::Dict{Symbol, Any})
   @ OptimalTransportNetworks ~/Documents/Julia/OptimalTransportNetworks.jl/src/main/optimal_network.jl:30
 [4] top-level scope
   @ ~/Documents/Julia/OptimalTransportNetworks.jl/examples/example01.jl:32

In any case, I am doubtful that such hacks will eventually make the dual problem faster than the primal one. So I guess in this case the solution of @gdalle (or in general writing the objective as an ordinary Julia function and using other differentiation tools) may be the way forward, or the dual solution is just not feasible in Julia at this point.

Do keep me posted though if there are improvements to the AD backend or in MathOptSymbolicAD.jl, and I'll be happy to reconsider the dual solution.

gdalle commented 2 months ago

> So I guess in this case the solution of @gdalle (or in general writing the objective as an ordinary Julia function and using other differentiation tools) may be the way forward, or the dual solution is just not feasible in Julia at this point.

We're working on a revamp of SparseConnectivityTracer.jl + DifferentiationInterface.jl for a conference submission so it won't be right now, but I hope to be able to help you soon enough!

You can try Symbolics.jl + SparseDiffTools.jl + Optimization.jl, which is the old combo for sparse Hessians, but I can't guarantee it will work better for your specific code (although it is more battle-tested).
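
For what it's worth, a hedged sketch of just the Symbolics.jl part of that combo, building a sparse symbolic Hessian for a made-up toy function (all names here are illustrative, not from the package above):

    using Symbolics
    x = Symbolics.variables(:x, 1:3)          # vector of symbolic variables
    f = sum(xi^2 for xi in x) + x[1] * x[2]   # toy objective
    H = Symbolics.sparsehessian(f, x)         # sparse symbolic Hessian of f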

odow commented 2 months ago

> However, it does not work in this case when making cj a constraint:

This is likely because the starting value is 0 and you have a 1 / cj.

It's annoying that there are a number of "tricks" required to get this to work.
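
For example (a sketch only, building on the @variable form suggested above; the bound and start value are assumptions, not values from OptimalTransportNetworks.jl), giving cj a strictly positive start keeps terms like 1 / cj[j] finite at the initial point:

    # Illustrative: bound cj away from zero and start it at 1.0 so that
    # 1 / cj[j] is well-defined when Ipopt evaluates the initial point.
    @variable(model, cj[j = 1:graph.J] >= 1e-8, start = 1.0)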

odow commented 2 months ago

Closing this in favor of #2488. There's nothing we need to do here, other than re-introduce common subexpressions.