JuliaNLSolvers / Optim.jl

Optimization functions for Julia

How to log additional information after each iteration? #1024

Closed renatobellotti closed 1 year ago

renatobellotti commented 1 year ago

Hi,

I have a loss function that is the sum of multiple sub-losses, and I'm optimising it with L-BFGS. Is there a way I could get the values of my sub-losses after each L-BFGS iteration? Currently I'm logging the values at each loss evaluation, but that also includes the evaluations made during the line search...

I think such a thing could be implemented by allowing a callback to be called after each iteration...
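Optim.Options does in fact accept a callback keyword that is invoked once per outer iteration. A minimal sketch of what that could look like here, assuming a hypothetical sub_losses(x) helper and that extended_trace = true exposes the current iterate through the state's metadata:

using Optim

# Hypothetical helper returning the individual sub-loss values at x.
sub_losses(x) = Dict("identity" => x[1], "square" => x[1]^2)
f(x) = sum(values(sub_losses(x)))

iteration_history = Dict{String, Float64}[]

function log_sublosses(opt_state)
    # opt_state is an Optim.OptimizationState; with extended_trace = true its
    # metadata should carry the current iterate under the key "x".
    push!(iteration_history, sub_losses(opt_state.metadata["x"]))
    return false  # returning true would stop the optimization early
end

res = optimize(f, [1000.0], LBFGS(),
               Optim.Options(callback = log_sublosses, extended_trace = true))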

As a quick fix, I tried to write my own optimisation function, but I'm not quite sure how to do this.

I understood the entry point is this function. Then I'd need to call update_state! repeatedly until convergence. Is that correct? I'm not quite sure what the parameter d is that shows up everywhere. Could somebody please explain what it is and how to get it?

renatobellotti commented 1 year ago

I figured out how to do it, so I thought I'd share my minimum working example.

using Optim

# Problem setup.
function loss_parts(x)
    return Dict(
        "identity" => x[1],
        "square" => x[1]^2,
        "-tanh" => -tanh(x[1]),
    )
end

function my_loss(x)
    return sum(values(loss_parts(x)))
end

function my_loss_gradient!(gradient, x)
    # Derivative of x + x^2 - tanh(x).
    gradient[:] .= 1 + 2*x[1] - (1 - tanh(x[1])^2)
end

# Initialisation of the optimisation state.
alg = LBFGS()
options = Optim.Options()
x0 = [1000.]
d = Optim.promote_objtype(alg, x0, :finite, true, my_loss, my_loss_gradient!)
state = Optim.initial_state(alg, options, d, x0);

# Optimisation.
history = []

for i in 1:10
    # Log the sub-losses at the current iterate, i.e. once per iteration.
    sublosses = loss_parts(state.x)
    push!(history, sublosses)
    # Perform one L-BFGS iteration.
    Optim.update_state!(d, state, alg)
    Optim.update_g!(d, state, alg)
    Optim.update_h!(d, state, alg)
end

Of course all of this could be wrapped up in a nice function.

It is a bit unfortunate that the losses are now calculated redundantly: once for logging and then again, several times, while updating the state. This could probably be improved. I'm happy to hear any ideas!
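One possible direction, sketched below with hypothetical names: memoise loss_parts by evaluation point, so the logging call in the loop reuses values that were already computed during the line search (at the cost of a cache that grows with every distinct point visited).

# Cache of sub-losses keyed by the evaluation point.
const subloss_cache = Dict{Vector{Float64}, Dict{String, Float64}}()

function cached_loss_parts(x)
    # Key on a copy of x, because Optim mutates the iterate in place.
    return get!(() -> loss_parts(x), subloss_cache, copy(x))
end

# This would replace my_loss in the setup above.
function my_cached_loss(x)
    return sum(values(cached_loss_parts(x)))
end

# In the loop, log via the cache so state.x is not evaluated a second time:
# push!(history, cached_loss_parts(state.x))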

JeffFessler commented 1 year ago

The d will be a subtype of AbstractObjective, defined here, unfortunately without a docstring: https://github.com/JuliaNLSolvers/NLSolversBase.jl/blob/master/src/objective_types/abstract.jl

For your problem it is probably this instance, whose comments may help you: https://github.com/JuliaNLSolvers/NLSolversBase.jl/blob/master/src/objective_types/oncedifferentiable.jl

(Instead of push! you might want to preallocate history.)
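A minimal sketch of the preallocation, assuming the fixed 10-iteration loop from the comment above:

niter = 10
history = Vector{Dict{String, Float64}}(undef, niter)

for i in 1:niter
    # Assign into the preallocated slot instead of push!-ing.
    history[i] = loss_parts(state.x)
    Optim.update_state!(d, state, alg)
    Optim.update_g!(d, state, alg)
    Optim.update_h!(d, state, alg)
end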

renatobellotti commented 1 year ago

Thank you very much for the feedback. I will preallocate history.

pkofod commented 1 year ago

You can just log it inside your objective function.

julia> using Optim

julia> # Problem setup.
       function loss_parts(x)
           return Dict(
               "identity" => x[1],
               "square" => x[1]^2,
               "-tanh" => -tanh(x[1]),
           )
       end
loss_parts (generic function with 1 method)

julia> my_history = Tuple{Vector{Float64}, Vector{Float64}}[]
Tuple{Vector{Float64}, Vector{Float64}}[]

julia> function my_loss(x)
               push!(my_history, (x, [x[1], x[1]^2, -tanh(x[1])]))
           return sum(my_history[end][end])
       end
my_loss (generic function with 1 method)

julia> optimize(my_loss, rand(1))
 * Status: success

 * Candidate solution
    Final objective value:     1.403528e-10

 * Found with
    Algorithm:     Nelder-Mead

 * Convergence measures
    √(Σ(yᵢ-ȳ)²)/n ≤ 1.0e-08

 * Work counters
    Seconds run:   0  (vs limit Inf)
    Iterations:    10
    f(x) calls:    23

julia> my_history
23-element Vector{Tuple{Vector{Float64}, Vector{Float64}}}:
 ([-4.2992348319219494e-5], [0.205142024116315, 0.04208325005853876, -0.20231197265987563])
 ([-1.1847081703458381e-5], [0.3327130361744725, 0.11069796444043585, -0.32095644989225186])
 ([8.158871814382496e-5], [0.07757101205815747, 0.006017261911726812, -0.0774157972692335])
 ([-1.1847081703458381e-5], [-0.17757101205815756, 0.031531464323358335, 0.17572789700863797])
 ([8.158871814382496e-5], [-0.050000000000000044, 0.0025000000000000044, 0.04995837495788002])
 ([-1.1847081703458381e-5], [-0.3051420241163151, 0.09311165488180181, 0.2960111891988321])
 ([8.158871814382496e-5], [-0.17757101205815756, 0.031531464323358335, 0.17572789700863797])
 ([-1.1847081703458381e-5], [-0.018107246985460665, 0.00032787239339247435, 0.018105268289495217])
 ([8.158871814382496e-5], [0.013785506029078715, 0.00019004017647776558, -0.013784632828789709])
 ([-1.1847081703458381e-5], [0.07757101205815747, 0.006017261911726812, -0.0774157972692335])
 ⋮
 ([8.158871814382496e-5], [-0.01013405873182582, 0.00010269914638009515, 0.01013371182634953])
 ([-1.1847081703458381e-5], [-0.00016757341478226394, 2.8080849341788677e-8, 0.00016757341321372936])
 ([8.158871814382496e-5], [0.0018257236486264473, 3.333266841153867e-6, -0.0018257216200877856])
 ([-1.1847081703458381e-5], [0.00033075085106991386, 1.0939612548347234e-7, -0.0003307508390089605])
 ([8.158871814382496e-5], [-0.0006658976806344417, 4.4341972107432897e-7, 0.0006658975822104046])
 ([-1.1847081703458381e-5], [-4.2992348319219494e-5, 1.8483420140010953e-9, 4.29923482927313e-5])
 ([8.158871814382496e-5], [8.158871814382496e-5, 6.6567189283525115e-9, -8.158871796278724e-5])
 ([-1.1847081703458381e-5], [-1.1847081703458381e-5, 1.4035334488841835e-10, 1.1847081702904122e-5])
 ([-1.1847081703458381e-5], [-1.1847081703458381e-5, 1.4035334488841835e-10, 1.1847081702904122e-5])

renatobellotti commented 1 year ago

Logging in the loss function does not work if you're interested in the iterations rather than the evaluations. The latter are rather noisy... But the custom loop works nicely.
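For reference, a stored trace gives a third option besides the custom loop and a callback: let optimize keep the iterates and recompute the sub-losses from them afterwards. A minimal sketch, reusing my_loss, my_loss_gradient! and loss_parts from the first example above; Optim.x_trace needs both store_trace = true and extended_trace = true:

using Optim

res = optimize(my_loss, my_loss_gradient!, [1000.0], LBFGS(),
               Optim.Options(store_trace = true, extended_trace = true))

# One entry per L-BFGS iteration (the trace typically also records the starting point).
iteration_sublosses = [loss_parts(x) for x in Optim.x_trace(res)]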