AdarshKumar712 opened 4 years ago
There's a trend (#666) to not recommend `train!` and to prefer plain for-loops for advanced usage. Nevertheless, `train!` can still be a convenient API. Could you share what your ideal API would look like in this respect?
What I'd like to see is that even if we wrap everything into a simple `train!`, I can still:
I am thinking of something like `train!(loss, params(model), batch, cb, epochs, verbose)`. This way we would have the freedom to choose whichever training method we want (passing `epochs = 1` otherwise). The main reason I thought of this is that, while writing some callback functions in Flux, I observed that `Flux.stop()` only stops the current call to `train!`, i.e. the current epoch. When the next epoch calls `train!` again, training simply resumes.
For example,
```julia
julia> function terminateOnNaN(x, y)
           if isnan(loss_(x, y))
               @info "NaN loss, Terminating execution!!!!"
               Flux.stop()
           end
       end
terminateOnNaN (generic function with 1 method)

julia> Flux.@epochs 6 Flux.train!(loss_, params(model), [(x, y)], opt, cb = () -> terminateOnNaN(x, y))
[ Info: Epoch 1
[ Info: NaN loss, Terminating execution!!!!
[ Info: Epoch 2
[ Info: NaN loss, Terminating execution!!!!
[ Info: Epoch 3
[ Info: NaN loss, Terminating execution!!!!
[ Info: Epoch 4
[ Info: NaN loss, Terminating execution!!!!
[ Info: Epoch 5
[ Info: NaN loss, Terminating execution!!!!
[ Info: Epoch 6
[ Info: NaN loss, Terminating execution!!!!
```
This is surprising and unwanted. Once a callback decides that training should stop, I want it to terminate execution completely (the same applies to an early-stopping callback).
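The behaviour above can be reproduced without Flux. The sketch below assumes, as in the Flux versions of that time, that `Flux.stop()` throws an exception which `train!` catches inside its loop over the data and answers with `break`. `StopTraining`, `stop!` and `mock_train!` are hypothetical stand-ins, not Flux API. It also shows one possible workaround: the callback records the stop request in a flag that a hand-written epoch loop checks.

```julia
# Hypothetical stand-in for the exception thrown by Flux.stop() (not Flux's own type).
struct StopTraining <: Exception end
stop!() = throw(StopTraining())

# Mimics the shape of an old-style `train!`: the stop request is caught
# inside the loop over `data`, so only this single pass is interrupted.
function mock_train!(step, data; cb = () -> nothing)
    for d in data
        try
            step(d)
            cb()
        catch ex
            if ex isa StopTraining
                break        # leaves this pass over the data only
            else
                rethrow()
            end
        end
    end
    return nothing
end

# Epoch loop in the style of Flux.@epochs: each epoch calls mock_train! again,
# so the callback fires (and "stops") once per epoch.
naive_calls = Ref(0)
for epoch in 1:6
    mock_train!(identity, [1]; cb = () -> (naive_calls[] += 1; stop!()))
end

# Workaround: the callback also sets a flag that the outer epoch loop checks,
# so training ends after the first stop request.
stopped = Ref(false)
flag_calls = Ref(0)
for epoch in 1:6
    mock_train!(identity, [1]; cb = () -> (flag_calls[] += 1; stopped[] = true; stop!()))
    stopped[] && break
end
```

With the naive loop the callback runs in all six epochs, matching the log above; with the flag it runs once.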
Also, regarding the requirements you listed above, I think we could add a progress meter as an option, controlled by `verbose = 0` or `1`. Model checkpointing can, I hope, easily be handled with a callback; I have already written a function for that.
I think that the number of epochs, rather than being handled by a separate macro (`Flux.@epochs`), should be an argument to `Flux.train!`, so that the number of passes over the data can be set directly in the `Flux.train!` call.
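A minimal sketch of the proposed API, in plain Julia so it runs without Flux. `train_epochs!`, `StopAll` and the `step(d)` argument (a stand-in for computing gradients and updating the parameters on batch `d`) are hypothetical; `verbose` here only toggles an `@info` line per epoch rather than a ProgressMeter bar. The design choice that matters is that the stop handler returns from the whole function instead of breaking out of the inner loop, so a stop request ends all remaining epochs.

```julia
# Hypothetical exception a callback throws to end *all* remaining epochs.
struct StopAll <: Exception end

"""
    train_epochs!(step, data; epochs = 1, cb = () -> nothing, verbose = 0)

Sketch of the proposed API: `step(d)` stands in for the gradient/update on
batch `d`. Returns the number of epochs that were started.
"""
function train_epochs!(step, data; epochs::Integer = 1, cb = () -> nothing, verbose::Integer = 0)
    for epoch in 1:epochs
        verbose > 0 && @info "Epoch $epoch"
        for d in data
            step(d)
            try
                cb()
            catch ex
                ex isa StopAll || rethrow()
                return epoch        # leave both loops, not just this epoch
            end
        end
    end
    return epochs
end

# Made-up per-batch losses: the 5th batch overall (2nd batch of epoch 2) is NaN.
losses = [1.0, 0.5, 0.3, 0.2, NaN, 0.1]
batch_count = Ref(0)
mock_step(d) = (batch_count[] += 1)
nan_cb() = isnan(losses[batch_count[]]) && throw(StopAll())

ran = train_epochs!(mock_step, 1:3; epochs = 6, cb = nan_cb)
```

Here `nan_cb` plays the role of `terminateOnNaN`: training stops during the second epoch instead of running all six.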