FluxML / FluxTraining.jl

A flexible neural net training library inspired by fast.ai
https://fluxml.ai/FluxTraining.jl
MIT License

Checkpointer for best checkpoint only #147

Closed: RomeoV closed this 1 year ago

RomeoV commented 1 year ago

Motivation and description

The existing Checkpointer callback saves the model after every epoch. Over a long run this can take up a lot of disk space, and usually we are only interested in the best checkpoint, and perhaps the most recent one.

Possible Implementation

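From the second epoch on, three logical checkpoints are involved: prev (saved in the previous epoch), current (just saved), and best. If prev is not the best checkpoint, delete prev. If current scores better than best (metric(current) > metric(best)), delete the file at path[best] and make current the new best.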
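The decision rule above can be written as a pure function, which makes it easy to check before any files are touched. This is a minimal sketch, not part of FluxTraining: stale_checkpoints and the Ckpt alias are hypothetical names, and checkpoints are assumed to be (path, metric) named tuples.

```julia
# Hypothetical helper, not part of FluxTraining: decide which checkpoint files
# become stale once `current` has been saved.
const Ckpt = NamedTuple{(:path, :metric), Tuple{String, Float64}}

function stale_checkpoints(prev::Ckpt, current::Ckpt, best::Ckpt;
                           ord::Base.Order.Ordering = Base.Order.Forward)
    stale = String[]
    # prev is only worth keeping if it is also the best checkpoint
    prev.path != best.path && push!(stale, prev.path)
    # lt(ord, a, b) means "a comes before b": with Forward a larger metric wins,
    # with Reverse (e.g. for a loss) a smaller one does
    if Base.Order.lt(ord, best.metric, current.metric)
        push!(stale, best.path)  # the old best has been beaten
        best = current
    end
    return stale, best
end

# Example with a loss, where lower is better
prev    = (path = "epoch_3.bson", metric = 0.40)
best    = (path = "epoch_2.bson", metric = 0.35)
current = (path = "epoch_4.bson", metric = 0.30)
stale, best = stale_checkpoints(prev, current, best; ord = Base.Order.Reverse)
# stale == ["epoch_3.bson", "epoch_2.bson"]; best.path == "epoch_4.bson"
```

Passing the ordering as a keyword mirrors the ord argument in the proposal, so the same rule works for accuracy-style metrics (Forward) and losses (Reverse).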

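We could implement this by adding a dictionary field to the Checkpointer callback, models::Dict{Symbol, Any}, that maps :latest and :best to (path, metric) named tuples; prev is then whatever :latest holds when a new checkpoint arrives. Since a Dict can be updated in place, the struct itself does not strictly need to be mutable. Implementing the above logic would then look roughly like this: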

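function update!(cb::Checkpointer, new_path, new_metric; ord::Base.Order.Ordering=Base.Order.Forward)
  latest = get(cb.models, :latest, nothing)
  best = get(cb.models, :best, nothing)
  # Delete the previous checkpoint unless it is also the best one
  if latest !== nothing && (best === nothing || latest.path != best.path)
    rm(latest.path; force=true)
  end

  cb.models[:latest] = (; path=new_path, metric=new_metric)

  # With Forward ordering a larger metric is better; use Base.Order.Reverse for a loss
  if best === nothing || Base.Order.lt(ord, best.metric, new_metric)
    best === nothing || rm(best.path; force=true)
    cb.models[:best] = cb.models[:latest]
  end
end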
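To try the policy end to end without FluxTraining, here is a self-contained sketch that applies it to real files in a temporary directory. KeepBestCheckpointer and its update! method are hypothetical names, not the FluxTraining API, and small text files stand in for saved models.

```julia
# Hypothetical stand-in for a Checkpointer that keeps only the best and latest files.
const Ckpt = NamedTuple{(:path, :metric), Tuple{String, Float64}}

mutable struct KeepBestCheckpointer
    latest::Union{Nothing, Ckpt}
    best::Union{Nothing, Ckpt}
end
KeepBestCheckpointer() = KeepBestCheckpointer(nothing, nothing)

function update!(cb::KeepBestCheckpointer, new_path::String, new_metric::Real;
                 ord::Base.Order.Ordering = Base.Order.Forward)
    # Remove the previous latest file unless it is also the best one
    if cb.latest !== nothing && (cb.best === nothing || cb.latest.path != cb.best.path)
        rm(cb.latest.path; force = true)
    end
    cb.latest = (path = new_path, metric = Float64(new_metric))
    # Forward: larger metric is better; pass Base.Order.Reverse for a loss
    if cb.best === nothing || Base.Order.lt(ord, cb.best.metric, cb.latest.metric)
        cb.best === nothing || rm(cb.best.path; force = true)
        cb.best = cb.latest
    end
    return cb
end

# Usage: simulate four epochs of validation accuracy
dir = mktempdir()
cb = KeepBestCheckpointer()
for (epoch, acc) in enumerate([0.5, 0.7, 0.6, 0.65])
    path = joinpath(dir, "epoch_$epoch.bson")
    write(path, "weights")  # placeholder for actually saving the model
    update!(cb, path, acc)
end
println(readdir(dir))  # ["epoch_2.bson", "epoch_4.bson"]
```

After four epochs only two files remain on disk: epoch 2 (best accuracy, 0.7) and epoch 4 (latest), so storage stays bounded regardless of how long training runs.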
ToucheSir commented 1 year ago

This is a great feature proposal, but it may take a while for someone to get around to it. Would you be interested in trying for a PR?