FluxML / FluxTraining.jl

A flexible neural net training library inspired by fast.ai
https://fluxml.ai/FluxTraining.jl
MIT License

Hard error on EarlyStopping() #159

Open cirobr opened 7 months ago

cirobr commented 7 months ago

Cheers. When I train a model with FluxTraining.fit!(learner, epochs) and an early-stopping condition is met, I get a hard error that terminates the Julia script, so none of the code placed after the fit! call is executed. I believe this is unintended behavior; could you please verify? Thanks in advance.

The code is as follows (early-stopping parameters purposely set to small numbers):

ms = [accuracy,
      t.Metric(LibML.IoU, device=gpu, name="IoU"),
]

cbs = [ToGPU(),
       StopOnNaNLoss(),
       Checkpointer(modelsfolder),
       EarlyStopping(1),
       EarlyStopping(NumberSinceBest(1)),
       EarlyStopping(Threshold(0.5)),
       Metrics(ms...),
       LogMetrics(TensorBoardBackend(tbfolder)),
       ]

learner = t.Learner(model, lossFunction;
                    data=(trainset, validset),
                    optimizer=modelOptimizer,
                    callbacks=cbs,
)

epochs = 100
FluxTraining.fit!(learner, epochs)
@info "project finished"

The error message is as follows:

ERROR: CancelFittingException("Stop triggered by EarlyStopping.Patience(1) stopping criterion. ")
Stacktrace:
 [1] on(::FluxTraining.Events.EpochEnd, phase::ValidationPhase, cb::EarlyStopping, learner::FluxTraining.Protected{Learner})
   @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/earlystopping.jl:72
 [2] _on(e::FluxTraining.Events.EpochEnd, p::ValidationPhase, cb::EarlyStopping, learner::Learner)
   @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/callback.jl:254
 [3] handle(runner::FluxTraining.LinearRunner, event::FluxTraining.Events.EpochEnd, phase::ValidationPhase, learner::Learner)
   @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/execution.jl:12
 [4] (::FluxTraining.var"#handlefn#81"{Learner, ValidationPhase})(e::FluxTraining.Events.EpochEnd)
   @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:102
 [5] runepoch(epochfn::FluxTraining.var"#71#72"{…}, learner::Learner, phase::ValidationPhase)
   @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:106
 [6] epoch!
   @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:22 [inlined]
 [7] fit!(learner::Learner, nepochs::Int64, ::Tuple{MLUtils.DataLoader{…}, MLUtils.DataLoader{…}})
   @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:169
 [8] fit!(learner::Learner, nepochs::Int64)
   @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:174
 [9] top-level scope
   @ ~/projects/pascalvoc-segmentation/8-training.jl:123
Some type information was truncated. Use `show(err)` to see complete types.

darsnack commented 7 months ago

Indeed, that error should be caught by the standard validation phase rather than thrown, but I can't see why the current code doesn't do that. Let me try to reproduce it.

cirobr commented 7 months ago

For the epoch in which the early-stopping condition is met, the TrainingPhase() printout shows up in the terminal/REPL, but the ValidationPhase() printout does not.

darsnack commented 7 months ago

Ah, I see: the issue is that early stopping throws a CancelFittingException, but the epoch and step handlers only catch CancelEpochException/CancelStepException. See here.
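The failure mode can be reproduced without a model or GPU. The following is a minimal mock, assuming (as stated in the comment above) that the epoch-level handler only absorbs epoch- and step-level cancellations. The exception types, `run_epoch_sketch` and `escaped_from_epoch` are hypothetical stand-ins, not FluxTraining's own definitions.

```julia
# Hypothetical stand-ins for the package's cancellation exceptions.
struct CancelStepSketch <: Exception end
struct CancelEpochSketch <: Exception end
struct CancelFittingSketch <: Exception
    msg::String
end

# Mock epoch runner: like the handlers described above, it only absorbs
# step- and epoch-level cancellations and rethrows everything else.
function run_epoch_sketch(epoch_end_hook)
    try
        epoch_end_hook()   # e.g. an EarlyStopping check at EpochEnd
        return :completed
    catch e
        if e isa CancelEpochSketch || e isa CancelStepSketch
            return :epoch_cancelled
        else
            rethrow()      # a fitting-level cancel escapes to the caller
        end
    end
end

skip_epoch() = throw(CancelEpochSketch())
stop_fit() = throw(CancelFittingSketch("Stop triggered by Patience(1) stopping criterion."))

# Returns the type of whatever escaped run_epoch_sketch, or :no_error.
function escaped_from_epoch(hook)
    try
        run_epoch_sketch(hook)
        return :no_error
    catch e
        return typeof(e)
    end
end
```

Here `escaped_from_epoch(skip_epoch)` returns `:no_error`, while `escaped_from_epoch(stop_fit)` returns `CancelFittingSketch`. With nothing above the epoch level to catch it, the fitting exception reaches top-level scope, which matches the stack trace in the report.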

I guess the appropriate fix would be to adjust the code for fit!:

function fit!(learner, nepochs::Int, (trainiter, validiter))
    for i in 1:nepochs
        try
            epoch!(learner, TrainingPhase(), trainiter)
            epoch!(learner, ValidationPhase(), validiter)
        catch e
            # A fitting-level cancel (e.g. from EarlyStopping) ends training cleanly
            if e isa CancelFittingException
                @debug "Fitting canceled" error = e
                break
            else
                # Any other error is propagated unchanged
                rethrow()
            end
        end
    end
end

Should be a straightforward PR if you want to attempt it.
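To check the proposed `fit!` change in isolation, here is a minimal mock with the same try/catch placement. `MockEarlyStopping`, `check!` and `fit_sketch!` are hypothetical names, and the stopping rule is a simplified "no improvement for `patience` epochs" criterion, not the exact semantics of `Patience` or `NumberSinceBest`.

```julia
# Hypothetical stand-in for CancelFittingException.
struct CancelFittingSketch <: Exception
    msg::String
end

# Simplified early-stopping rule: stop after `patience` epochs without improvement.
mutable struct MockEarlyStopping
    patience::Int
    best::Float64
    since_best::Int
end
MockEarlyStopping(patience) = MockEarlyStopping(patience, Inf, 0)

function check!(es::MockEarlyStopping, loss)
    if loss < es.best
        es.best = loss
        es.since_best = 0
    else
        es.since_best += 1
    end
    es.since_best >= es.patience &&
        throw(CancelFittingSketch("Stop triggered by patience $(es.patience)"))
    return nothing
end

# Sketch of the patched outer loop: the try/catch wraps both phases,
# so a fitting-level cancel ends the loop instead of the script.
function fit_sketch!(es, val_losses, nepochs)
    epochs_run = 0
    for i in 1:nepochs
        try
            epochs_run = i                # stands in for the training phase of epoch i
            check!(es, val_losses[i])     # stands in for the validation EpochEnd hook
        catch e
            if e isa CancelFittingSketch
                @debug "Fitting canceled" error = e
                break
            else
                rethrow()
            end
        end
    end
    return epochs_run
end
```

With validation losses `[1.0, 0.8, 0.9, 0.95, 0.7]` and a patience of 2, `fit_sketch!(MockEarlyStopping(2), losses, 5)` stops after epoch 4 and returns normally, so code after the call (like the `@info "project finished"` line in the report) still runs.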

cirobr commented 7 months ago

A quick-and-dirty test confirmed that the solution works. The test code is below. I will send a PR shortly.

@info "start training ..."
function fit!(learner, nepochs::Int, (trainiter, validiter))
    for i in 1:nepochs
        try
            epoch!(learner, TrainingPhase(), trainiter)
            epoch!(learner, ValidationPhase(), validiter)
        catch e
            if e isa CancelFittingException
                @debug "Fitting canceled" error = e
                break
            else
                rethrow()
            end
        end
    end
end

epochs = 1000
fit!(learner, epochs, (trainset, validset))
@info "project finished"

cirobr commented 7 months ago

Cheers,

I have implemented a slightly different training loop. Instead of using fit!, I use a pair of epoch! calls, as follows:

epoch!(tr_learner, TrainingPhase(), trainset)
epoch!(v_learner, ValidationPhase(), validset)

The training and validation learners are separate objects because they use distinct loss functions.

When this runs, the error message below appears. If I switch back to fit!, there is no error.

I am reopening the issue for further analysis. Thanks in advance.

Epoch 5 TrainingPhase() ...
┌─────────────────┬───────┬────────┐
│           Phase │ Epoch │   Loss │
├─────────────────┼───────┼────────┤
│ TrainingPhase() │   5.0 │ 1129.9 │
└─────────────────┴───────┴────────┘
Epoch 5 ValidationPhase() ...
┌───────────────────┬───────┬─────────┬─────────┬──────────┐
│             Phase │ Epoch │     IoU │    Loss │ Accuracy │
├───────────────────┼───────┼─────────┼─────────┼──────────┤
│ ValidationPhase() │   5.0 │ 0.18293 │ 0.81707 │  0.23597 │
└───────────────────┴───────┴─────────┴─────────┴──────────┘
Epoch 6 TrainingPhase() ...
┌─────────────────┬───────┬─────────┐
│           Phase │ Epoch │    Loss │
├─────────────────┼───────┼─────────┤
│ TrainingPhase() │   6.0 │ 1128.28 │
└─────────────────┴───────┴─────────┘
Epoch 6 ValidationPhase() ...
ERROR: LoadError: CancelFittingException("Stop triggered by NumberSinceBest(5) stopping criterion. ")
Stacktrace:
 [1] on(::FluxTraining.Events.EpochEnd, phase::ValidationPhase, cb::EarlyStopping, learner::FluxTraining.Protected{Learner})
   @ FluxTraining ~/.julia/packages/FluxTraining/1XVfn/src/callbacks/earlystopping.jl:72
 [2] _on(e::FluxTraining.Events.EpochEnd, p::ValidationPhase, cb::EarlyStopping, learner::Learner)
   @ FluxTraining ~/.julia/packages/FluxTraining/1XVfn/src/callbacks/callback.jl:254
 [3] handle(runner::FluxTraining.LinearRunner, event::FluxTraining.Events.EpochEnd, phase::ValidationPhase, learner::Learner)
   @ FluxTraining ~/.julia/packages/FluxTraining/1XVfn/src/callbacks/execution.jl:12
 [4] (::FluxTraining.var"#handlefn#81"{Learner, ValidationPhase})(e::FluxTraining.Events.EpochEnd)
   @ FluxTraining ~/.julia/packages/FluxTraining/1XVfn/src/training.jl:102
 [5] runepoch(epochfn::FluxTraining.var"#71#72"{Learner, ValidationPhase, MLUtils.DataLoader{MLUtils.MappedData{:auto, typeof(gpu), Tuple{Array{Float32, 4}, Array{Bool, 4}}}, Random._GLOBAL_RNG, Val{nothing}}}, learner::Learner, phase::ValidationPhase)
   @ FluxTraining ~/.julia/packages/FluxTraining/1XVfn/src/training.jl:106
 [6] epoch!(learner::Learner, phase::ValidationPhase, dataiter::MLUtils.DataLoader{MLUtils.MappedData{:auto, typeof(gpu), Tuple{Array{Float32, 4}, Array{Bool, 4}}}, Random._GLOBAL_RNG, Val{nothing}})
   @ FluxTraining ~/.julia/packages/FluxTraining/1XVfn/src/training.jl:22
 [7] top-level scope
   @ ~/projects/knowledge-distillation/training.jl:170

darsnack commented 7 months ago

That's expected. Your new training loop must catch the fitting exception, since it is taking over the role of the outer loop from the package. Exceptions that are caught at an epoch! level should only cancel that iteration (epoch) of the outer loop.
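When the outer loop is written by hand, as with the separate `tr_learner` and `v_learner` above, the try/catch has to move into that loop. The following is a mock under that assumption: `epoch_sketch!` stands in for `epoch!` (absorbing only epoch-level cancels), the two learners are abstracted away as phase labels, and all names are hypothetical.

```julia
# Hypothetical stand-ins for the epoch- and fitting-level cancellations.
struct CancelEpochSketch <: Exception end
struct CancelFittingSketch <: Exception end

# Stand-in for epoch!: epoch-level cancels are absorbed here, others are not.
function epoch_sketch!(log, epoch, phase, hook)
    try
        hook(epoch, phase)            # e.g. callbacks firing at EpochEnd
        push!(log, (epoch, phase))    # record a completed phase
    catch e
        e isa CancelEpochSketch || rethrow()
    end
    return nothing
end

# User-owned outer loop: it must catch the fitting-level cancel itself.
function custom_fit_sketch!(hook, nepochs)
    log = Tuple{Int,Symbol}[]
    for epoch in 1:nepochs
        try
            epoch_sketch!(log, epoch, :train, hook)   # role of epoch!(tr_learner, TrainingPhase(), trainset)
            epoch_sketch!(log, epoch, :valid, hook)   # role of epoch!(v_learner, ValidationPhase(), validset)
        catch e
            e isa CancelFittingSketch || rethrow()
            @info "Fitting canceled at epoch $epoch"
            break
        end
    end
    return log
end

# Demo hook: cancel epoch 2's training phase, then stop fitting during epoch 3's validation.
demo_hook(epoch, phase) =
    if epoch == 2 && phase == :train
        throw(CancelEpochSketch())
    elseif epoch == 3 && phase == :valid
        throw(CancelFittingSketch())
    end
```

Running `custom_fit_sketch!(demo_hook, 5)` skips only epoch 2's training phase (its validation phase still runs), and the fitting-level cancel during epoch 3's validation ends the loop without terminating the script.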