FluxML / FluxTraining.jl

A flexible neural net training library inspired by fast.ai
https://fluxml.ai/FluxTraining.jl
MIT License

Record time trained, training loss, validation loss and performance #119

Closed KronosTheLate closed 2 years ago

KronosTheLate commented 2 years ago

For my application, I would love to be able to record the time trained, the training loss, the validation loss and the classification performance at given time intervals during training. Currently, however, History seems to store only the number of epochs, the number of steps, and the number of steps in the current epoch.

Would there be a way to make History extendable, so that users can record anything they want?

A final detail: I want to record these stats only each time the training time has grown by a constant factor, so that when I plot, e.g., the training loss against a logarithmic time axis, the points are roughly evenly spaced. I am not sure how to make that happen, and I do not expect it to be built-in functionality; I am just mentioning it in case it would be simple enough to implement.
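A minimal sketch of the "record only after a factor increase" idea, in plain Julia. LogSchedule and shouldrecord! are hypothetical names, not part of the FluxTraining API, and the sketch assumes elapsed training time is measured in seconds.

```julia
# Hypothetical helper, not part of FluxTraining: decides whether stats should be
# recorded at a given elapsed training time so that records end up roughly
# evenly spaced on a logarithmic time axis.
mutable struct LogSchedule
    next::Float64    # elapsed time (in seconds) at which the next record is due
    factor::Float64  # ratio between consecutive recording times, must be > 1

    function LogSchedule(start::Real, factor::Real)
        start > 0 || throw(ArgumentError("start must be positive"))
        factor > 1 || throw(ArgumentError("factor must be greater than 1"))
        return new(Float64(start), Float64(factor))
    end
end

# Returns `true` if `elapsed` has reached the next due time, then pushes the due
# time past `elapsed` by repeated multiplication, so that a long gap (e.g. a slow
# step) yields a single record rather than a burst of them.
function shouldrecord!(s::LogSchedule, elapsed::Real)
    elapsed < s.next && return false
    while s.next <= elapsed
        s.next *= s.factor
    end
    return true
end

s = LogSchedule(1.0, 2.0)
[shouldrecord!(s, t) for t in (0.5, 1.0, 1.5, 2.0, 3.9, 4.0, 10.0)]
# returns [false, true, false, true, false, true, true]
```

Inside a training callback, one could compare time() minus a stored start time against such a schedule and push stats only when it returns true; with factor 2.0 that gives one record per doubling of training time.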

lorenzoh commented 2 years ago

Hey @KronosTheLate, this is a great idea! I think the simplest solution is to add a callback that records the time at which every step ends and stores it somewhere like learner.cbstate.timesstep. Since the Metrics callback already stores the metric values for every step (in learner.cbstate.metricsstep), the two can then be used together, for example with the times as the x-axis and a metric like the loss as the y-axis.

Starting point for writing this (will not necessarily run, but may give you an idea of how to go about this):


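# implementation
using FluxTraining, FluxTraining.Events, FluxTraining.Phases
import FluxTraining: on, init!, stateaccess, Write

# Callback that records a wall-clock timestamp at the end of every training step
struct TimeSteps <: FluxTraining.Callback end

# Declare write access to `learner.cbstate.timesstep` so the callback may modify it
stateaccess(::TimeSteps) = (; cbstate = (; timesstep = Write()))

# Dispatch on `TimeSteps` only, so that `init!` of other callbacks is not overridden
function init!(::TimeSteps, learner)
    learner.cbstate.timesstep = Float64[]
end

function on(::StepEnd, ::FluxTraining.AbstractTrainingPhase, ::TimeSteps, learner)
    push!(learner.cbstate.timesstep, time())
end

## Usage

using Plots
learner = Learner(model, lossfn; data = (traindata, valdata), callbacks = [TimeSteps()])
fit!(learner, 10)
# `metricsstep[phase]` is a ValueHistories `MVHistory`; `get` returns (steps, values)
_, losses = get(learner.cbstate.metricsstep[TrainingPhase()], :Loss)
times = learner.cbstate.timesstep .- first(learner.cbstate.timesstep)
plot(times, losses)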
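Since the callback approach stores a timestamp for every training step, the log-spaced view from the original question can also be produced after training by thinning those records. A minimal sketch in plain Julia; thin_logspaced is a hypothetical helper, not FluxTraining API, and it assumes absolute timestamps from time() paired one-to-one with per-step metric values. Elapsed time is measured from the first recorded step, so that step (at elapsed 0) is never kept, which suits a log axis anyway.

```julia
# Hypothetical helper, not part of FluxTraining: thins per-step records after
# training so that the kept points are roughly evenly spaced on a log time axis.
# `timestamps` are absolute times (e.g. from `time()`), `values` the matching
# per-step metric values; elapsed time is measured from the first timestamp.
function thin_logspaced(timestamps::AbstractVector{<:Real}, values::AbstractVector;
                        start::Real = 1.0, factor::Real = 2.0)
    length(timestamps) == length(values) ||
        throw(DimensionMismatch("timestamps and values must have the same length"))
    start > 0 || throw(ArgumentError("start must be positive"))
    factor > 1 || throw(ArgumentError("factor must be greater than 1"))

    elapsed = isempty(timestamps) ? Float64[] : timestamps .- first(timestamps)
    keep = Int[]
    threshold = float(start)
    for (i, t) in enumerate(elapsed)
        t < threshold && continue
        push!(keep, i)            # first step at or after the current grid point
        while threshold <= t      # skip grid points this step already covers
            threshold *= factor
        end
    end
    return elapsed[keep], values[keep]
end

t, v = thin_logspaced([100.0, 100.5, 101.0, 102.5, 103.0, 105.0, 109.0], collect(1:7))
# t == [1.0, 2.5, 5.0, 9.0], v == [3, 4, 6, 7]
```

Recording every step and thinning afterwards keeps the callback trivial and lets the spacing be changed without retraining, at the cost of storing one timestamp per step.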

Hope this helps!

KronosTheLate commented 2 years ago

Thanks a lot for the quick and extensive reply. I have not gotten around to playing with FluxTraining.jl yet, but I just wanted to say that it is not forgotten.

lorenzoh commented 2 years ago

Let me know if there are any hiccups when you do! I'll close this issue for now.