FluxML / FluxTraining.jl

A flexible neural net training library inspired by fast.ai
https://fluxml.ai/FluxTraining.jl
MIT License
118 stars 25 forks source link

Switch to ParameterSchedulers.jl #106

Closed: darsnack closed this issue 2 years ago

darsnack commented 2 years ago

ParameterSchedulers.jl started in response to the limitations of Animations.jl for hyper-parameter scheduling. The API for ParameterSchedulers.jl has started to stabilize, so maybe we can consider swapping the backend?

The following snippet is some code that worked for a previous version of FluxTraining.jl. Perhaps an enterprising user can adapt it into a complete PR? If not, I'm filing this so that I remember to come back when I have the time.

"""
    Scheduler(schedules...)

Callback for hyperparameter scheduling.
Takes a pair of hyperparameters and schedules from ParameterSchedulers.

## Example
```julia
lrschedule = Exp(0.1, 0.5)
scheduler = Scheduler(
    LearningRate => lrschedule
)

""" mutable struct Scheduler <: Callback schedules::Dict{Type{<:HyperParameter}, ParameterSchedulers.AbstractSchedule} step::Int Scheduler(args...; kwargs...) = new(Dict(args...; kwargs...), 1) end

# Compact display: `Scheduler(` followed by the comma-separated scheduled
# hyperparameter types and a closing `)`.
function Base.show(io::IO, scheduler::Scheduler)
    hpnames = join(keys(scheduler.schedules), ", ")
    return print(io, "Scheduler(", hpnames, ")")
end

"""
    FluxTraining.stateaccess(scheduler::Scheduler)

Return the state permissions requested by `scheduler`: `Read()` on `data`, `Write()` on
`cbstate.hyperparams`, `Read()` on `cbstate.history`, plus the merged `stateaccess` of
every scheduled hyperparameter type.
"""
function FluxTraining.stateaccess(scheduler::Scheduler)
    # TODO: implement proper merging of permissions
    hps = keys(scheduler.schedules)
    # `merge` needs at least one argument, so an empty scheduler is special-cased.
    hpstateaccess = if isempty(hps)
        (;)
    else
        merge(FluxTraining.stateaccess.(hps)...)
    end
    return (
        data = Read(),
        cbstate = (; hyperparams = Write(), history = Read()),
        hpstateaccess...,
    )
end

"""
    FluxTraining.init!(scheduler::Scheduler, learner)

Prepare `scheduler` for a training run.

Creates `learner.cbstate.hyperparams` as a `ValueHistories.MVHistory` if that key is not
present yet, resets the scheduler's step counter to 1, and returns `scheduler`.
"""
function FluxTraining.init!(scheduler::Scheduler, learner)
    # Keep an existing history if one is already present.
    if !haskey(learner.cbstate, :hyperparams)
        learner.cbstate.hyperparams = ValueHistories.MVHistory()
    end
    scheduler.step = 1

    return scheduler
end

"""
    FluxTraining.on(::StepBegin, phase::AbstractTrainingPhase, scheduler::Scheduler, learner)

At the start of each training step, evaluate every schedule at the scheduler's current
step, apply the value through `FluxTraining.sethyperparameter!`, and record it in
`learner.cbstate.hyperparams` under `Symbol(H)` together with
`learner.cbstate.history[phase].steps`. The step counter is then incremented and its new
value returned.
"""
function FluxTraining.on(
    ::StepBegin, phase::AbstractTrainingPhase, scheduler::Scheduler, learner
)
    for (H, schedule) in scheduler.schedules
        value = schedule(scheduler.step)
        FluxTraining.sethyperparameter!(learner, H, value)
        push!(
            learner.cbstate.hyperparams,
            Symbol(H),
            learner.cbstate.history[phase].steps,
            value,
        )
    end
    scheduler.step += 1
    return scheduler.step
end

lorenzoh commented 2 years ago

Sounds good to me!