SciML / EasyModelAnalysis.jl

High level functions for analyzing the output of simulations
MIT License

Expose standard callback options for `datafit` #158

Open ChrisRackauckas opened 1 year ago

ChrisRackauckas commented 1 year ago

The callback should calculate:

jClugstor commented 1 year ago

Is something like this on the right track? (Very roughly)

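```julia
using Dates

mutable struct TimedDatafitCallback   # mutable: fields are updated on each call
    tolerance::Float64
    effect_function
    elapsed::Dates.Period
    last_callback::Dates.DateTime
    params
    test_data
end

function TimedDatafitCallback(tolerance, effect_function; elapsed = Dates.Second(5))
    # `nothing` (the value), not `Nothing` (the type), as initial placeholders
    TimedDatafitCallback(tolerance, effect_function, elapsed,
                         typemin(Dates.DateTime), nothing, nothing)
end

# Same shape as an Optimization.jl callback: return `true` to halt, `false` to continue
function (cb::TimedDatafitCallback)(p, lossval, args...)
    if lossval ≤ cb.tolerance # maybe just < ?
        return true
    end

    # skip the effect function until `elapsed` has passed since it last ran
    if Dates.now() < cb.last_callback + cb.elapsed
        return false
    end
    cb.last_callback = Dates.now()

    cb.test_data = isempty(args) ? nothing : args[1]
    cb.params = p

    cb.effect_function(p, lossval, args...)

    return false
end
```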
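The time-gating part is the easiest piece to get wrong, and it is easier to test when the clock is injectable. A minimal sketch, assuming a hypothetical `TimeGate` helper and `ready!` function (neither is part of EasyModelAnalysis.jl or Optimization.jl): the gate reports `true` at most once per `elapsed` interval and records when it last fired. Starting from `typemin(Dates.DateTime)` makes the first call always fire.

```julia
using Dates

# Hypothetical helper (not from EasyModelAnalysis.jl): fires at most once per `elapsed`.
mutable struct TimeGate
    elapsed::Dates.Period
    last_fired::Dates.DateTime
end

# Start "infinitely long ago" so the first check always fires.
TimeGate(elapsed::Dates.Period) = TimeGate(elapsed, typemin(Dates.DateTime))

# Return true (and record `t`) if at least `elapsed` has passed since the last firing.
function ready!(gate::TimeGate, t::Dates.DateTime = Dates.now())
    if t < gate.last_fired + gate.elapsed
        return false
    end
    gate.last_fired = t
    return true
end

gate = TimeGate(Dates.Second(5))
t0 = DateTime(2023, 1, 1, 12, 0, 0)
ready!(gate, t0)                      # true: first call always fires
ready!(gate, t0 + Dates.Second(2))    # false: only 2 s since last firing
ready!(gate, t0 + Dates.Second(5))    # true: 5 s since last firing
```

Passing the timestamp explicitly keeps the logic deterministic in tests, while the default `Dates.now()` gives the wall-clock behavior a real callback would use.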

Basically: a TimedDatafitCallback struct with a constructor that takes a tolerance and an effect function, plus a default for how much time must pass between runs of the effect function. I then make the struct callable with the callback signature described at https://docs.sciml.ai/Optimization/stable/API/solve/ and pass it as the callback to the optimizer call inside datafit. This also requires the loss function to return the ODE solution alongside the loss value, since the callback receives it through its extra arguments (`args[1]` above).
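To see how the `true`/`false` return value and the extra loss outputs flow into the callback, here is a toy stand-in for an optimizer loop. It is a sketch under stated assumptions, not Optimization.jl's implementation: the hypothetical `toy_solve` runs plain gradient descent on a loss that returns a tuple `(lossval, extra...)`, forwards the extras to `callback(p, lossval, extra...)` following the convention linked above, and stops as soon as the callback returns `true`. Newer Optimization.jl releases changed the callback to receive an `OptimizationState` first, so check the signature for the version in use.

```julia
# Hypothetical stand-in for an optimizer (NOT Optimization.jl's code): plain gradient
# descent that calls `callback(p, lossval, extra...)` before every step and halts on `true`.
function toy_solve(loss_and_extra, grad, p0, callback; lr = 0.1, maxiters = 1_000)
    p = copy(p0)
    for _ in 1:maxiters
        out = loss_and_extra(p)               # tuple: (lossval, extra...)
        lossval, extra = first(out), Base.tail(out)
        callback(p, lossval, extra...) && return p   # `true` means "stop now"
        p .-= lr .* grad(p)
    end
    return p
end

target = [3.0, -1.0]
# The second value stands in for the ODE solution a real datafit loss would return.
loss_and_extra(p) = (sum(abs2, p .- target), p .- target)
grad(p) = 2 .* (p .- target)

ncalls = Ref(0)
cb = (p, lossval, residual) -> begin
    ncalls[] += 1
    lossval ≤ 1e-6          # halt once the loss is within tolerance
end

p_opt = toy_solve(loss_and_extra, grad, zeros(2), cb)
```

Because `toy_solve` forwards whatever extra values the loss returns, any callable with the `(p, lossval, args...)` signature, such as the `TimedDatafitCallback` sketched above, could be passed in place of `cb`.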

I didn't see any examples of callbacks for Optimization.jl, so this might be completely off.

ChrisRackauckas commented 1 year ago

yup that's on the right track