jump-dev / JuMP.jl

Modeling language for Mathematical Optimization (linear, mixed-integer, conic, semidefinite, nonlinear)
http://jump.dev/JuMP.jl/

Fix adjoint of GenericNonlinearExpr #3724

Closed. odow closed this 2 months ago.

odow commented 2 months ago

Closes #3723

So I didn't really think through the implications of this beforehand.

We kinda punted on the question of whether ScalarNonlinearFunction is <:Real or <:Complex. But I think it's time to pick <:Real, and if we need it, we can introduce something like:

struct ComplexScalarNonlinearFunction
    real::ScalarNonlinearFunction
    imag::ScalarNonlinearFunction
end
blegat commented 2 months ago

Not sure about this. Here, you could even have a GenericAffExpr{ComplexF64} inside, so it's weird to say it's real. Why not implement these so that they work for the most common operators and error for user-defined operators, or even for our own operators that we haven't handled yet? I think erroring is better than unexpected behavior here.

blegat commented 2 months ago

Actually, we could also do

function Base.real(x::GenericNonlinearExpr{V}) where {V}
    return GenericNonlinearExpr{V}(:real, x)
end

and support it in the AD
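
The same pattern would presumably extend to imag and conj (a sketch, not something this PR implements):

function Base.imag(x::GenericNonlinearExpr{V}) where {V}
    return GenericNonlinearExpr{V}(:imag, x)
end

function Base.conj(x::GenericNonlinearExpr{V}) where {V}
    return GenericNonlinearExpr{V}(:conj, x)
end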

blegat commented 2 months ago

In view of https://github.com/jump-dev/MathOptInterface.jl/pull/2468, we probably indeed need a way to tell whether we have a guarantee that a MOI.ScalarNonlinearFunction will always have real output or whether it may be complex. So maybe we also need MOI.ComplexScalarNonlinearFunction, which has the same fields as ScalarNonlinearFunction and behaves in exactly the same way, except for is_maybe_real. At the JuMP level, we can have only one function type. In the conversion from the JuMP nonlinear expression to the MOI nonlinear function, we need to decide whether we can guarantee real-ness. There are three scenarios.

In the third case, we should use a MOI.ComplexScalarNonlinearFunction.

We should have a bridge that converts MOI.ComplexScalarNonlinearFunction-in-MOI.EqualTo into MOI.ScalarNonlinearFunction-in-MOI.EqualTo by adding real and imag at the root of the expression tree (actually, maybe we could make it work with SplitComplexEqualToBridge?). A rough sketch of the rewrite is below.
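
Roughly, the transformation such a bridge would apply, assuming :real and :imag become supported operator heads (this only illustrates the data-structure rewrite, not a bridge implementation):

import MathOptInterface as MOI

# Hypothetical complex-valued function (1 + 2im) * sin(x), with x a variable index.
x = MOI.VariableIndex(1)
f = MOI.ScalarNonlinearFunction(
    :*,
    Any[1 + 2im, MOI.ScalarNonlinearFunction(:sin, Any[x])],
)

# A constraint `f == 3 + 4im` would be rewritten as two real equality constraints
# by wrapping the root of the expression tree:
f_re = MOI.ScalarNonlinearFunction(:real, Any[f])  # constrained to MOI.EqualTo(3.0)
f_im = MOI.ScalarNonlinearFunction(:imag, Any[f])  # constrained to MOI.EqualTo(4.0)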

For user-defined functions, we could add a way for the user to say that an operator is always real (and maybe that should be the default). Whenever the operator is evaluated, we should then check that the output is indeed real, probably with a type assert. For complex ones, we could have a ComplexNonlinearOperator, or a trait may_be_complex(::NonlinearOperator{F}) that returns true by default (that's probably simpler). If, when we evaluate a nonlinear operator that the user declared real-valued, we see a complex value as output, then we error and tell the user to implement may_be_complex or use ComplexNonlinearOperator. A sketch of the trait is below.
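
A sketch of what the trait could look like (may_be_complex and the opt-out method are placeholder names; none of this exists in JuMP today):

using JuMP

# Conservative default: assume a user-defined operator may return complex values.
may_be_complex(::NonlinearOperator) = true

# A user who knows their operator is real-valued opts out by dispatching on the
# wrapped function type.
square_plus_one(x) = x^2 + 1
may_be_complex(::NonlinearOperator{typeof(square_plus_one)}) = false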

odow commented 2 months ago

I guess the issue is that x' is a very common operation when building models, and people probably expect that it does transpose instead of adjoint.

Let's ignore real and imag for now, and focus on conj. Can we actually define this for arbitrary nonlinear expressions?

blegat commented 2 months ago

We probably replied at the same time. I think in the case where we can guarantee that the expression is real, adjoint is a no-op. Otherwise, we should add it to the expression tree. It's fine if the AD errors for now, saying that adjoint is not a supported operator; it's still an improvement over incorrect behavior. I also expect that we can assert that the expression is real most of the time, especially if we consider that NonlinearOperator is real by default. In code, something like the sketch below.
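
A sketch of that adjoint method (_is_guaranteed_real is a hypothetical helper, and :conj would still need support in the AD):

function Base.adjoint(x::GenericNonlinearExpr{V}) where {V}
    if _is_guaranteed_real(x)  # hypothetical check over the expression tree
        return x               # adjoint of a real scalar is a no-op
    end
    return GenericNonlinearExpr{V}(:conj, x)  # record it; the AD may error later
end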

odow commented 2 months ago

The other option is NonlinearExpr(:complex, Any[real_part, imag_part]).

But it's also not obvious what to do for sqrt(-x) where x >= 0.

Part of the problem is that we don't have a solver, or an AD system, that is capable of handling Complex. So perhaps it's best just to state that ScalarNonlinearFunction <: Real for now.

I'll check through MOI, but we shouldn't have any users relying on this yet.

Here's an email I sent to @ccoffrin in August:

So this was sufficient

struct ComplexExpr{F<:AbstractJuMPScalar}
    real::F
    imag::F
end

Base.zero(::Type{ComplexExpr{F}}) where {F} = ComplexExpr(zero(F), zero(F))

Base.real(x::ComplexExpr) = x.real

Base.imag(x::ComplexExpr) = x.imag

function Base.:*(x::Complex, y::ComplexExpr)
    return ComplexExpr(
        real(x) * real(y) - imag(x) * imag(y),
        imag(x) * real(y) + real(x) * imag(y),
    )
end

function Base.:*(x::Complex, y::NonlinearExpr)
    return ComplexExpr(real(x) * y, imag(x) * y)
end

Base.:*(y::NonlinearExpr, x::Complex) = x * y

function Base.:+(x::ComplexExpr, y::ComplexExpr)
    return ComplexExpr(real(x) + real(y), imag(x) + imag(y))
end

function Base.:+(x::AbstractJuMPScalar, y::ComplexExpr)
    return ComplexExpr(x + real(y), imag(y))
end

to get this to work

using JuMP, LinearAlgebra
model = Model()
@variable(model, vm[1:2, 1:3])
@variable(model, va[1:2, 1:3])
@variable(model, p[1:3])
@variable(model, q[1:3])
VV_real = vm[1,:] .* vm[2,:]' .* cos.(va[1,:] .- va[2,:]')
VV_imag = vm[1,:] .* vm[2,:]' .* sin.(va[1,:] .- va[2,:]')
Y = [
    1 + 1im  2 + 2im 3 + 3im;
    4 + 4im  5 + 5im 6 + 6im;
    5 + 5im  6 + 6im 7 + 7im;
]
VV = VV_real .+ VV_imag .* im
@constraint(model, p .== real.(diag(Y * VV)))
@constraint(model, q .== imag.(diag(Y * VV)))

So I'm willing to declare this as "will work, given time to implement and test."

We don't have complex-valued nonlinear expressions, but we have complex values with nonlinear components.

We wouldn't support something like cos(::ScalarAffineFunction{<:Complex}) because there's no application need.

blegat commented 2 months ago

My only worry is that users might pass complex expressions into a nonlinear expression, and with this PR that would lead to silent bugs here. Unless the AD throws an error later if it encounters complex values?

odow commented 2 months ago

Yes, currently the AD errors on things it doesn't understand.

odow commented 2 months ago

Here's an extension-tests run: https://github.com/jump-dev/JuMP.jl/actions/runs/8623134630

blegat commented 2 months ago

Yes, but other solvers like Convex.jl might not error. We may be passing incorrect models (incorrect because of the methods in this PR) with complex expressions to a solver, assuming it will error if there are complex expressions. I think it would be safer to error in the conversion from the JuMP expression to the MOI function as well.

odow commented 2 months ago

> I think it would be safer to error in the conversion from the JuMP expression to the MOI function as well

Do you mean erroring if the input is complex in the ScalarNonlinearFunction constructor of MOI?

You can't create complex nonlinear expressions without manually constructing them: https://github.com/jump-dev/JuMP.jl/blob/05d48766a933c1e4d48fb5e23b28eb4d3dcbace1/src/nlp_expr.jl#L296-L334

blegat commented 2 months ago

The check doesn't seem to be called in every case:

julia> model = Model();

julia> @variable(model, x in ComplexPlane())
real(x) + imag(x) im

julia> cos(x)
ERROR: Cannot build `GenericNonlinearExpr` because a term is complex-valued: `(real(x) + imag(x) im)::GenericAffExpr{ComplexF64, VariableRef}`
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] _throw_if_not_real(x::GenericAffExpr{ComplexF64, VariableRef})
   @ JuMP ~/.julia/dev/JuMP/src/nlp_expr.jl:307
 [3] cos(x::GenericAffExpr{ComplexF64, VariableRef})
   @ JuMP ~/.julia/dev/JuMP/src/nlp_expr.jl:325
 [4] top-level scope
   @ REPL[23]:1

julia> x^3
(real(x) + imag(x) im) ^ 3

julia> @constraint(model, x^3 == 1)
((real(x) + imag(x) im) ^ 3.0) - 1.0 = 0
odow commented 2 months ago

Ah. I think you found the one method that has a bug: https://github.com/jump-dev/JuMP.jl/blob/05d48766a933c1e4d48fb5e23b28eb4d3dcbace1/src/operators.jl#L210-L218

blegat commented 2 months ago

You can also still do this:

julia> model = Model();

julia> @variable(model, x in ComplexPlane())
real(x) + imag(x) im

julia> GenericNonlinearExpr(:+, x)
+(real(x) + imag(x) im)
odow commented 2 months ago

If people are manually constructing expressions then we don't check anything, correct.

blegat commented 2 months ago

We can add a check right here: https://github.com/jump-dev/JuMP.jl/blob/05d48766a933c1e4d48fb5e23b28eb4d3dcbace1/src/nlp_expr.jl#L95 with `for arg in args; _throw_if_not_real(arg); end`

codecov[bot] commented 2 months ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 97.03%. Comparing base (617f961) to head (e654125). Report is 1 commit behind head on master.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master    #3724      +/-   ##
==========================================
- Coverage   98.42%   97.03%   -1.40%
==========================================
  Files          43       43
  Lines        5825     5287     -538
==========================================
- Hits         5733     5130     -603
- Misses         92      157      +65
```


odow commented 2 months ago

Checking https://github.com/jump-dev/JuMP.jl/actions/runs/8637031676 again before merging