JuliaNLSolvers / Optim.jl

Optimization functions for Julia
Other
1.12k stars 217 forks source link

Restart and store trace #1031

Open sallyc1997 opened 1 year ago

sallyc1997 commented 1 year ago

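Hi! I'm using this package for my research, which involves estimating a complicated model. I've been using the NelderMead() algorithm. From time to time, the estimation stops due to memory issues. I have three questions:

1. I set store_trace=true. Could this be the cause of the memory issue?
2. Is the trace stored to a file? If so, where? Or is the trace only stored at the end of the execution?
3. Is there any way I can store the trace and restart from it?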

Thank you very much for your help!!

pkofod commented 1 year ago
  • I set store_trace=true. Could this be the cause of the memory issue?

yes

  • Is it storing to a file? Where would it be stored? Or will the trace only be stored at the end of the execution?

No, it's not, and you can't really make it do that.

  • Is there any way I can store the trace and restart with that?

You can only do it by running for a number of iterations and then starting a new run from that final point... sorry. Instead of using the trace system, you can also store information inside your objective function and control it that way. Then you can save it to a file and not worry about memory.

KnutAM commented 1 year ago

Is there any way I can store the trace and restart with that?

In case it is helpful: I hacked together a way to do this. It relies on some Optim internals and works by serializing the state of the optimizer, so I'm not sure how well it generalizes...

using Optim, UUIDs, Serialization

"""
    SaveStateWrapper{TObj,TState}

Callable wrapper around an objective function. Calling it forwards the call to the
wrapped objective and, from time to time, serializes an optimizer state object to a
file on disk.

# Fields
- `obj`: the wrapped objective function.
- `optimstate`: the optimizer state object that is written to disk.
- `filename`: path of the checkpoint file.
- `num_calls_per_save`: controls how often the state is saved, in objective calls.
- `num_calls_since_save`: call counter; the only field that may be mutated.
"""
mutable struct SaveStateWrapper{TObj,TState}
    const obj::TObj                 # objective being wrapped
    const optimstate::TState        # state object that gets serialized
    const filename::String          # checkpoint path
    const num_calls_per_save::Int   # save cadence
    num_calls_since_save::Int       # counter updated on every call
end
"""
    SaveStateWrapper(obj, optimstate; num_calls_per_save, filename = string(uuid1()) * ".state")

Wrap `obj` so that `optimstate` is periodically saved to `filename`. How often it is
saved is controlled by `num_calls_per_save`.

By default, `filename` is a fresh time-based UUID (`uuid1`) with a `.state`
extension. Because this is a relative path, it resolves against the current working
directory. Pass `filename` explicitly to choose the checkpoint location yourself.

The chosen path is logged with `@info`, so the state file can be found after a
crash. The call counter starts at zero.
"""
function SaveStateWrapper(obj, optimstate; num_calls_per_save,
                          filename = string(uuid1()) * ".state")
    # Structured key/value logging: prints `filename = "..."` on its own line instead
    # of a `(message, filename)` tuple.
    @info "Creating SaveStateWrapper with filename" filename
    return SaveStateWrapper(obj, optimstate, filename, num_calls_per_save, 0)
end

"""
    save_optim_state(filename, state)

Serialize `state` to `filename` (readable back with `deserialize(filename)`).

The state is first serialized to the scratch file `filename * "_tmp"`. Only after
that write has completed is the scratch file moved over `filename`. If serialization
is interrupted (an error, running out of memory, the process being killed), the
checkpoint previously saved at `filename` is left untouched and can still be used
for a restart. Writing straight into `filename` would instead leave a truncated file
there.

A scratch file left behind by an earlier crash is simply overwritten. Returns
`nothing`.
"""
function save_optim_state(filename, state)
    tmp_file = filename * "_tmp"
    # Any failure here happens before `filename` is touched.
    serialize(tmp_file, state)
    # Replace the old checkpoint only once the new one is completely on disk.
    mv(tmp_file, filename; force = true)
    return nothing
end

"""
    (ssw::SaveStateWrapper)(args...; kwargs...)

Evaluate the wrapped objective `ssw.obj(args...; kwargs...)` and return its result.
Before evaluating, save `ssw.optimstate` to `ssw.filename` if
`ssw.num_calls_per_save` calls have been made since the last save.

`ssw.num_calls_since_save` counts the evaluations made since the last save,
including the current one. With `num_calls_per_save = N ≥ 1`, the state is
therefore saved before calls `N+1`, `2N+1`, `3N+1`, …, i.e. once every `N`
evaluations. With `N ≤ 0`, it is saved before every call.

Saving happens before the objective runs, so an error thrown by the objective does
not prevent the save scheduled for this call.
"""
function (ssw::SaveStateWrapper)(args...; kwargs...)
    if ssw.num_calls_since_save >= ssw.num_calls_per_save
        save_optim_state(ssw.filename, ssw.optimstate)
        ssw.num_calls_since_save = 0
    end
    # Count this evaluation towards the next save window. Incrementing *before* the
    # check with a strict `>` would make the period N+1 instead of N.
    ssw.num_calls_since_save += 1
    return ssw.obj(args...; kwargs...)
end

"""
    optimize_with_restart(obj, x0, method, options;
                          inplace = true, autodiff = :finite,
                          num_calls_per_save = 10, state = nothing)

Run `Optim.optimize` on `obj` while checkpointing the optimizer state through a
`SaveStateWrapper`.

- When `state` is `nothing`, a fresh state is built with `Optim.initial_state`.
- Otherwise, the given `state` (for example one loaded with `deserialize` from an
  earlier checkpoint file) is passed to `Optim.optimize` instead.

`inplace` and `autodiff` are forwarded to `Optim.promote_objtype`, and
`num_calls_per_save` is forwarded to the wrapper. Returns whatever `Optim.optimize`
returns.
"""
function optimize_with_restart(obj, x0, method, options;
        inplace = true, autodiff = :finite,   # Optim settings
        num_calls_per_save = 10,              # wrapper settings
        state = nothing)
    # The fresh state is built from the *unwrapped* objective. The wrapper needs a
    # reference to the state object, so it cannot exist before the state does.
    # Evaluations made while building the state are therefore not counted by the
    # wrapper.
    start_state = if isnothing(state)
        plain_objective = Optim.promote_objtype(method, x0, autodiff, inplace, obj)
        Optim.initial_state(method, options, plain_objective, x0)
    else
        state
    end
    checkpointing_obj = SaveStateWrapper(obj, start_state;
                                         num_calls_per_save = num_calls_per_save)
    promoted = Optim.promote_objtype(method, x0, autodiff, inplace, checkpointing_obj)
    return Optim.optimize(promoted, x0, method, options, start_state)
end

The following test shows that it works. First run the initial part and note the *.state filename it logs, then paste that filename into `state_file` to test the restart.


# Test objective built around `sum(x)` that
# 1) throws an error once a chosen number of calls is reached, and
# 2) records every objective value it computes.
struct MyObj
    vals::Vector{Float64}  # history of objective values, one entry per call
    fail_at::Int           # calls number `fail_at` and later throw "Planned failure"
end

MyObj(; fail_at = 10) = MyObj(Float64[], fail_at)

function (m::MyObj)(x)
    value = sum(x)
    # Record first, so the failing call also shows up in `vals`.
    push!(m.vals, value)
    if length(m.vals) >= m.fail_at
        error("Planned failure")
    end
    return value
end

# Demo: interrupt an optimization, restart it from the saved state, and check that
# the stitched-together trace matches an uninterrupted run.

# o1 will fail after 10 function calls
o1 = MyObj()
try
    optimize_with_restart(o1, ones(4), NelderMead(), Optim.Options(); num_calls_per_save = 1)
catch e
    println(e) # Simulate failure
end

# Replace with the filename logged (via @info) when o1's SaveStateWrapper was created
state_file = "de16fbc0-352e-11ee-3458-2b759117f9c2.state"

# o2 will fail after 20 calls, but should restart from about where o1 left off
state = deserialize(state_file)
o2 = MyObj(; fail_at = 20)
try
    optimize_with_restart(o2, ones(4), NelderMead(), Optim.Options(); num_calls_per_save = 1, state = state)
catch e
    println(e)  # Simulate failure
end

# The restarted run repeats the evaluations o1 made between its last checkpoint and
# its failure, so o2's history begins with a copy of o1's tail. Find where o2's
# first value sits in o1's history and skip that duplicated prefix of o2. This
# works for any save cadence, unlike a hand-tuned constant offset.
restart_idx = findlast(≈(first(o2.vals)), o1.vals)
restart_idx === nothing && error("o2's first objective value does not occur in o1's history")
offset = length(o1.vals) - restart_idx + 1
restarted_trace = append!(copy(o1.vals), o2.vals[(1+offset):end])

# o3 runs from beginning without interruption (up to 30 function calls)
o3 = MyObj(; fail_at = 30)
try
    optimize_with_restart(o3, ones(4), NelderMead(), Optim.Options(); num_calls_per_save = 1)
catch e
    println(e)  # Simulate failure
end

# `zip` stops at the shorter history, so a short o3 run cannot cause a BoundsError.
for (i, (uninterrupted, restarted)) in enumerate(zip(o3.vals, restarted_trace))
    println("$i: ", uninterrupted, ", ", restarted, ". Same? ", uninterrupted ≈ restarted)
end