jump-dev / JuMP.jl

Modeling language for Mathematical Optimization (linear, mixed-integer, conic, semidefinite, nonlinear)
http://jump.dev/JuMP.jl/

DNMY: Enzyme extension #3712

Closed. michel2323 closed this PR 6 months ago.

michel2323 commented 6 months ago

I made this PR initially to Enzyme (https://github.com/EnzymeAD/Enzyme.jl/pull/1337), but @wsmoses recommended making it an extension of JuMP instead. Let me know if this works, and I can add a test for it.

This extends JuMP so that a user can differentiate an external (user-defined) function using Enzyme.

Use case:

using Ipopt
using JuMP
using Enzyme

# Rosenbrock
rosenbrock(x...) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2

model = Model(Ipopt.Optimizer)
op_rosenbrock = model[:op_rosenbrock] = add_nonlinear_operator(model, 2, rosenbrock; name = :op_rosenbrock)
@variable(model, x[1:2])

@objective(model, Min, op_rosenbrock(x[1], x[2]))

optimize!(model)
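
For comparison, without the extension the same operator can be registered today by passing an explicit gradient oracle. A minimal sketch (the wiring here is mine, not part of this PR), using Enzyme for the gradient and omitting the Hessian oracle for brevity:

using Ipopt
using JuMP
import Enzyme

rosenbrock(x...) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2

# Gradient oracle with the in-place signature that JuMP expects.
function ∇rosenbrock(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
    # One reverse-mode sweep returns all partial derivatives.
    g .= Enzyme.autodiff(Enzyme.Reverse, rosenbrock, Enzyme.Active.(x)...)[1]
    return
end

model = Model(Ipopt.Optimizer)
@variable(model, x[1:2])
op = add_nonlinear_operator(model, 2, rosenbrock, ∇rosenbrock; name = :op_rosenbrock)
@objective(model, Min, op(x[1], x[2]))
optimize!(model)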
codecov[bot] commented 6 months ago

Codecov Report

Attention: Patch coverage is 0%, with 44 lines in your changes missing coverage. Please review.

Project coverage is 97.62%. Comparing base (a15daaa) to head (782b3b9).

Files                  Patch %   Lines
ext/JuMPEnzymeExt.jl   0.00%     44 Missing :warning:

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##           master    #3712      +/-   ##
==========================================
- Coverage   98.37%   97.62%   -0.75%
==========================================
  Files          43       44       +1
  Lines        5736     5780      +44
==========================================
  Hits         5643     5643
- Misses         93      137      +44
```


odow commented 6 months ago

Okay, this needs some discussion, likely at a monthly developer call.

Let's also put the exact syntax aside. Instead of pirating a method like this, we'd need to add some sort of type or flag so that people can opt in, but that is a small issue that can be resolved.
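
To illustrate the opt-in idea (the ad_backend keyword below is hypothetical, not existing JuMP API):

op = add_nonlinear_operator(
    model, 2, rosenbrock;
    name = :op_rosenbrock,
    ad_backend = :Enzyme,  # hypothetical opt-in flag, not real JuMP API
)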

I am somewhat in favor of this, but given the experience of https://github.com/jump-dev/JuMP.jl/pull/3707, I think we should be very careful about adding this.

Particularly relevant is this discussion: https://github.com/jump-dev/JuMP.jl/pull/3413#issuecomment-1603425251

I would be strongly in favor of requiring that new extensions have a 1.0 release and no plans for a 2.0 release. This would rule out anything that has moved from v1.0.0 to v5.67.2 in a short period of time, and it would rule out Enzyme, which is on v0.11.

Another option is to add a page to the documentation that shows how to construct the appropriate gradient and Hessian oracles, but not add this to JuMP, either directly or as an extension.

It's also worth evaluating the cost in compilation time for the tests and documentation if we add this. Enzyme is pretty heavy.
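
A rough way to quantify that (results are machine-dependent):

# Package load latency, then first-differentiation compile latency:
@time using Enzyme
@time Enzyme.autodiff(Enzyme.Reverse, x -> x^2, Enzyme.Active(1.0))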

odow commented 6 months ago

Also, using your code I get:

julia> using Enzyme

julia> function jump_operator(f::Function)
           @inline function f!(y, x...)
               y[1] = f(x...)
           end
           function gradient!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
               y = zeros(T,1)
               ry = ones(T,1)
               rx = ntuple(N) do i
                   Active(x[i])
               end
               g .= autodiff(Reverse, f!, Const, Duplicated(y,ry), rx...)[1][2:end]
               return nothing
           end

           function gradient_deferred!(g, y, ry, rx...)
               g .= autodiff_deferred(Reverse, f!, Const, Duplicated(y,ry), rx...)[1][2:end]
               return nothing
           end

           function hessian!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
               y = zeros(T,1)
               dy = ntuple(N) do i
                   ones(1)
               end
               g = zeros(T,N)
               dg = ntuple(N) do i
                   zeros(T,N)
               end
               ry = ones(1)
               dry = ntuple(N) do i
                   zeros(T,1)
               end
               rx = ntuple(N) do i
                   Active(x[i])
               end

               args = ntuple(N) do i
                   drx = ntuple(N) do j
                       if i == j
                           Active(one(T))
                       else
                           Active(zero(T))
                       end
                   end
                   BatchDuplicated(rx[i], drx)
               end
               autodiff(Forward, gradient_deferred!, Const, BatchDuplicated(g,dg), BatchDuplicated(y,dy), BatchDuplicated(ry, dry), args...)
               for i in 1:N
                   for j in 1:N
                       if i <= j
                           H[j,i] = dg[j][i]
                       end
                   end
               end
               return nothing
           end

           return gradient!, hessian!
       end
jump_operator (generic function with 1 method)

julia> foo(x...) = log(sum(exp(x[i]) for i in 1:length(x)))
foo (generic function with 1 method)

julia> ∇foo, ∇²foo = jump_operator(foo)
(var"#gradient!#9"{var"#f!#8"{typeof(foo)}}(var"#f!#8"{typeof(foo)}(foo)), var"#hessian!#12"{var"#gradient_deferred!#11"{var"#f!#8"{typeof(foo)}}}(var"#gradient_deferred!#11"{var"#f!#8"{typeof(foo)}}(var"#f!#8"{typeof(foo)}(foo))))

julia> N = 3
3

julia> x = rand(N)
3-element Vector{Float64}:
 0.23712902725864782
 0.6699243680780806
 0.530669076854107

julia> g = zeros(N)
3-element Vector{Float64}:
 0.0
 0.0
 0.0

julia> H = zeros(N, N)
3×3 Matrix{Float64}:
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0

julia> foo(x...)
1.593666919840647

julia> ∇foo(g, x...)

julia> ∇²foo(H, x...)
ERROR: Attempting to call an indirect active function whose runtime value is inactive:
Backtrace

Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5378
 [2] enzyme_call
   @ ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5056
 [3] AugmentedForwardThunk
   @ ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5009
 [4] runtime_generic_augfwd
   @ ~/.julia/packages/Enzyme/wR2t7/src/rules/jitrules.jl:179
 [5] runtime_generic_augfwd
   @ ~/.julia/packages/Enzyme/wR2t7/src/rules/jitrules.jl:0
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5378 [inlined]
  [2] enzyme_call
    @ ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5056 [inlined]
  [3] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5009 [inlined]
  [4] runtime_generic_augfwd
    @ ~/.julia/packages/Enzyme/wR2t7/src/rules/jitrules.jl:179 [inlined]
  [5] runtime_generic_augfwd
    @ ~/.julia/packages/Enzyme/wR2t7/src/rules/jitrules.jl:0 [inlined]
  [6] fwddiffe3julia_runtime_generic_augfwd_3727_inner_1wrap
    @ ~/.julia/packages/Enzyme/wR2t7/src/rules/jitrules.jl:0
  [7] macro expansion
    @ ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5378 [inlined]
  [8] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::BatchDuplicated{…}, ::BatchDuplicated{…}, ::BatchDuplicated{…}, ::BatchDuplicated{…}, ::BatchDuplicated{…}, ::BatchDuplicated{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5056
  [9] (::Enzyme.Compiler.ForwardModeThunk{…})(::Const{…}, ::Const{…}, ::Vararg{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5001
 [10] runtime_generic_fwd(activity::Type{…}, width::Val{…}, RT::Val{…}, f::typeof(Enzyme.Compiler.runtime_generic_augfwd), df::Nothing, df_2::Nothing, df_3::Nothing, primal_1::Type{…}, shadow_1_1::Nothing, shadow_1_2::Nothing, shadow_1_3::Nothing, primal_2::Val{…}, shadow_2_1::Nothing, shadow_2_2::Nothing, shadow_2_3::Nothing, primal_3::Val{…}, shadow_3_1::Nothing, shadow_3_2::Nothing, shadow_3_3::Nothing, primal_4::Val{…}, shadow_4_1::Nothing, shadow_4_2::Nothing, shadow_4_3::Nothing, primal_5::typeof(foo), shadow_5_1::Nothing, shadow_5_2::Nothing, shadow_5_3::Nothing, primal_6::Nothing, shadow_6_1::Nothing, shadow_6_2::Nothing, shadow_6_3::Nothing, primal_7::Float64, shadow_7_1::Float64, shadow_7_2::Float64, shadow_7_3::Float64, primal_8::Base.RefValue{…}, shadow_8_1::Base.RefValue{…}, shadow_8_2::Base.RefValue{…}, shadow_8_3::Base.RefValue{…}, primal_9::Float64, shadow_9_1::Float64, shadow_9_2::Float64, shadow_9_3::Float64, primal_10::Base.RefValue{…}, shadow_10_1::Base.RefValue{…}, shadow_10_2::Base.RefValue{…}, shadow_10_3::Base.RefValue{…}, primal_11::Float64, shadow_11_1::Float64, shadow_11_2::Float64, shadow_11_3::Float64, primal_12::Base.RefValue{…}, shadow_12_1::Base.RefValue{…}, shadow_12_2::Base.RefValue{…}, shadow_12_3::Base.RefValue{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/wR2t7/src/rules/jitrules.jl:116
 [11] f!
odow commented 6 months ago

Here's some code I had when experimenting with this:

abstract type AbstractADOperator end

#=
    Enzyme
=#

import Enzyme

struct ADOperatorEnzyme <: AbstractADOperator end

function create_operator(f::Function, ::ADOperatorEnzyme)
    @inline f!(y, x::Vararg{T,N}) where {T,N} = (y[1] = f(x...))
    function ∇f!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
        g .= Enzyme.autodiff(
            Enzyme.Reverse,
            f!,
            Enzyme.Const,
            Enzyme.Duplicated(zeros(T, 1), ones(T, 1)),
            Enzyme.Active.(x)...,
        )[1][2:end]
        return
    end
    function ∇²f!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
        dg = ntuple(_ -> zeros(T, N), N)
        args = ntuple(N) do i
            return Enzyme.BatchDuplicated(
                Enzyme.Active(x[i]),
                ntuple(j -> Enzyme.Active(T(i == j)), N),
            )
        end
        function gradient_deferred!(g, y, ry, rx...)
            g .= Enzyme.autodiff_deferred(
                Enzyme.Reverse,
                f!,
                Enzyme.Const,
                Enzyme.Duplicated(y, ry),
                rx...,
            )[1][2:end]
            return
        end
        Enzyme.autodiff(
            Enzyme.Forward,
            gradient_deferred!,
            Enzyme.Const,
            Enzyme.BatchDuplicated(zeros(T, N), dg),
            Enzyme.BatchDuplicated(zeros(T, 1), ntuple(_ -> ones(T, 1), N)),
            Enzyme.BatchDuplicated(ones(T, 1), ntuple(_ -> zeros(T, 1), N)),
            args...,
        )
        for j in 1:N, i in 1:j
            H[j, i] = dg[j][i]
        end
        return
    end
    return f, ∇f!, ∇²f!
end

#=
    ForwardDiff
=#

import ForwardDiff

struct ADOperatorForwardDiff <: AbstractADOperator end

function create_operator(f::Function, ::ADOperatorForwardDiff)
    function ∇f!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
        ForwardDiff.gradient!(g, y -> f(y...), collect(x))
        return
    end
    function ∇²f!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
        h = ForwardDiff.hessian(y -> f(y...), collect(x))
        for i in 1:N, j in 1:i
            H[i, j] = h[i, j]
        end
        return
    end
    return f, ∇f!, ∇²f!
end

#=
    Examples
=#

import LinearAlgebra
using Test

function example_logsumexp()
    f(x...) = log(sum(exp(x[i]) for i in 1:length(x)))
    ∇f(g, x...) = g .= exp.(x) ./ sum(exp.(x))
    function ∇²f(H, x...)
        y = collect(x)
        g = exp.(y) / sum(exp.(y))
        h = LinearAlgebra.diagm(g) - g * g'
        for i in 1:length(y), j in 1:i
            H[i, j] = h[i, j]
        end
        return
    end
    return f, ∇f, ∇²f
end

function example_rosenbrock()
    f(x, y) = (1 - x)^2 + 100 * (y - x^2)^2
    function ∇f(g, x, y)
        g[1] = 2 * (-1 + x - 200 * (y * x + -x^3))
        g[2] = 200 * (y - x^2)
        return
    end
    function ∇²f(H, x, y)
        H[1, 1] = 2 + 1200 * x^2 - 400 * y
        H[2, 1] = -400 * x
        H[2, 2] = 200
        return
    end
    return f, ∇f, ∇²f
end

function test_example(example, N, config::AbstractADOperator)
    true_f, true_∇f, true_∇²f = example()
    f, ∇f, ∇²f = create_operator(true_f, config)
    x = rand(N)
    y = f(x...)
    true_y = true_f(x...)
    @test isapprox(y, true_y)
    g, true_g = zeros(N), zeros(N)
    ∇f(g, x...)
    true_∇f(true_g, x...)
    @test isapprox(g, true_g)
    H, true_H = zeros(N, N), zeros(N, N)
    ∇²f(H, x...)
    true_∇²f(true_H, x...)
    @test isapprox(H, true_H)
    return
end

@testset "Examples" begin
    for config in (ADOperatorForwardDiff(), ADOperatorEnzyme())
        for (example, N) in (
            example_rosenbrock => 2,
            example_logsumexp => 3,
            example_logsumexp => 20,
        )
            @testset "$example - $N - $config" begin
                test_example(example, N, config)
            end
        end
    end
end

Running this yields:

Examples                                           |   16      2     18  11.3s
  example_rosenbrock - 2 - ADOperatorForwardDiff() |    3             3   1.0s
  example_logsumexp - 3 - ADOperatorForwardDiff()  |    3             3   1.1s
  example_logsumexp - 20 - ADOperatorForwardDiff() |    3             3   1.2s
  example_rosenbrock - 2 - ADOperatorEnzyme()      |    3             3   0.4s
  example_logsumexp - 3 - ADOperatorEnzyme()       |    2      1      3   1.0s
  example_logsumexp - 20 - ADOperatorEnzyme()      |    2      1      3   6.6s
ERROR: LoadError: Some tests did not pass: 16 passed, 0 failed, 2 errored, 0 broken.
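
For reference, the analytic oracles in example_logsumexp follow from the standard log-sum-exp (softmax) identities, which is what ∇f and ∇²f implement via exp.(x) ./ sum(exp.(x)) and diagm(g) - g * g':

f(x) = \log \sum_i e^{x_i}, \qquad
\sigma_i(x) = \frac{e^{x_i}}{\sum_j e^{x_j}}, \qquad
\nabla f(x) = \sigma(x), \qquad
\nabla^2 f(x) = \operatorname{diag}(\sigma(x)) - \sigma(x)\,\sigma(x)^\top.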
odow commented 6 months ago

Okay, I've tightened things up considerably, and got rid of the Hessian error:

abstract type AbstractADOperator end

#=
    Enzyme
=#

import Enzyme

struct ADOperatorEnzyme <: AbstractADOperator end

function create_operator(f::Function, ::ADOperatorEnzyme)
    function ∇f!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
        # One reverse-mode sweep returns all N partial derivatives.
        g .= Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active.(x)...)[1]
        return
    end
    function ∇²f!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
        # Forward-over-reverse: forward-differentiate the reverse-mode
        # gradient along the N unit directions, batched in a single call.
        direction(i) = ntuple(j -> Enzyme.Active(T(i == j)), N)
        hess = Enzyme.autodiff(
            Enzyme.Forward,
            (x...) -> Enzyme.autodiff_deferred(Enzyme.Reverse, f, x...)[1],
            Enzyme.BatchDuplicated.(Enzyme.Active.(x), ntuple(direction, N))...,
        )[1]
        # JuMP's Hessian oracle only needs the lower triangle filled.
        for j in 1:N, i in 1:j
            H[j, i] = hess[j][i]
        end
        return
    end
    return f, ∇f!, ∇²f!
end

#=
    ForwardDiff
=#

import ForwardDiff

struct ADOperatorForwardDiff <: AbstractADOperator end

function create_operator(f::Function, ::ADOperatorForwardDiff)
    function ∇f!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
        ForwardDiff.gradient!(g, y -> f(y...), collect(x))
        return
    end
    function ∇²f!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
        h = ForwardDiff.hessian(y -> f(y...), collect(x))
        for i in 1:N, j in 1:i
            H[i, j] = h[i, j]
        end
        return
    end
    return f, ∇f!, ∇²f!
end

#=
    Examples
=#

import LinearAlgebra
using Test

function example_logsumexp()
    f(x...) = log(sum(exp(x[i]) for i in 1:length(x)))
    ∇f(g, x...) = g .= exp.(x) ./ sum(exp.(x))
    function ∇²f(H, x...)
        y = collect(x)
        g = exp.(y) / sum(exp.(y))
        h = LinearAlgebra.diagm(g) - g * g'
        for i in 1:length(y), j in 1:i
            H[i, j] = h[i, j]
        end
        return
    end
    return f, ∇f, ∇²f
end

function example_rosenbrock()
    f(x, y) = (1 - x)^2 + 100 * (y - x^2)^2
    function ∇f(g, x, y)
        g[1] = 2 * (-1 + x - 200 * (y * x + -x^3))
        g[2] = 200 * (y - x^2)
        return
    end
    function ∇²f(H, x, y)
        H[1, 1] = 2 + 1200 * x^2 - 400 * y
        H[2, 1] = -400 * x
        H[2, 2] = 200
        return
    end
    return f, ∇f, ∇²f
end

function test_example(example, N, config::AbstractADOperator)
    true_f, true_∇f, true_∇²f = example()
    f, ∇f, ∇²f = create_operator(true_f, config)
    x = rand(N)
    y = f(x...)
    true_y = true_f(x...)
    @test isapprox(y, true_y)
    g, true_g = zeros(N), zeros(N)
    ∇f(g, x...)
    true_∇f(true_g, x...)
    @test isapprox(g, true_g)
    H, true_H = zeros(N, N), zeros(N, N)
    ∇²f(H, x...)
    true_∇²f(true_H, x...)
    @test isapprox(H, true_H)
    return
end

@testset "Examples" begin
    for config in (ADOperatorForwardDiff(), ADOperatorEnzyme())
        for (example, N) in (
            example_rosenbrock => 2,
            example_logsumexp => 3,
            example_logsumexp => 20,
        )
            @testset "$example - $N - $config" begin
                test_example(example, N, config)
            end
        end
    end
end
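
For completeness, the oracles returned by create_operator match the signatures that add_nonlinear_operator expects, so they can be wired into a model as follows (a sketch that assumes the definitions above):

using JuMP
import Ipopt

f, ∇f!, ∇²f! = create_operator(example_rosenbrock()[1], ADOperatorEnzyme())
model = Model(Ipopt.Optimizer)
@variable(model, x[1:2])
op = add_nonlinear_operator(model, 2, f, ∇f!, ∇²f!; name = :op_rosenbrock)
@objective(model, Min, op(x[1], x[2]))
optimize!(model)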
odow commented 6 months ago

Okay, so since this is 20 lines of code, I think this might be better as a tutorial in the documentation.

@blegat has asked for this before: https://github.com/jump-dev/JuMP.jl/issues/2348#issuecomment-1822073111

It'll also let us show off Enzyme and ForwardDiff.

I'll take a stab, and then we can discuss the relative merits of having the code as a JuMP extension vs asking people to copy-paste a snippet.

odow commented 6 months ago

The developer call concluded that the documentation at https://jump.dev/JuMP.jl/dev/tutorials/nonlinear/operator_ad/ is sufficient.