FluxML / FluxTraining.jl

A flexible neural net training library inspired by fast.ai
https://fluxml.ai/FluxTraining.jl
MIT License

Question regarding ProgressPrinter #139

Open hv10 opened 1 year ago

hv10 commented 1 year ago

Hi, first off: wonderful package :)

I'm having trouble getting the ProgressPrinter to show up, even when using the default callbacks.

learner = Learner(model, loss; optimizer=opt, callbacks=[ToGPU()], usedefaultcallbacks=true)
FluxTraining.fit!(learner, epochs, (dl, dl_val)) # where dl, dl_val are both Flux.DataLoader objects

Is there something specific I need to do when constructing the Learner that I've missed? From the code, it looks like I'd need to give it a Progress object; do I have to construct that myself? What requirements does my data iterator have to fulfill for the progress bar to show up with the default callbacks?

hv10 commented 1 year ago

Oh! I found a workaround for now. I have to pass in the data when the learner is constructed, i.e.:

learner = Learner(model, loss; optimizer=opt, callbacks=[ToGPU()], usedefaultcallbacks=true, data=(dl, dl_val))
FluxTraining.fit!(learner, epochs) 
# where dl, dl_val are both Flux.DataLoader objects

Somewhat unintuitive, though.

lorenzoh commented 1 year ago

Hi Noel!

You can list a Learner's callbacks with learner.callbacks.cbs if you want to check whether a particular one is there.
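As a quick check, here is a minimal sketch of that. It assumes Flux ≥ 0.13 (for the `Dense(4 => 2)` syntax) and a FluxTraining version whose default callbacks include ProgressPrinter; the toy model and loss are placeholders, not taken from the original report.

```julia
using Flux, FluxTraining

# Placeholder model and loss, only used to build a Learner
model = Chain(Dense(4 => 2))
learner = Learner(model, Flux.Losses.mse; usedefaultcallbacks=true)

# Print every callback registered on the learner
foreach(println, learner.callbacks.cbs)

# true if a ProgressPrinter is among them
hasprogress = any(cb -> cb isa ProgressPrinter, learner.callbacks.cbs)
```

Note that this only confirms the callback is registered; it says nothing about whether the callback can find a data iterator to size its progress bar.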

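There are two ways of passing in data:

1. To the Learner constructor, then calling fit! without data:

learner = Learner(model, loss; data=(dl, dl_val))
FluxTraining.fit!(learner, epochs)

2. Directly to fit!:

FluxTraining.fit!(learner, epochs, (dl, dl_val))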

Are you saying that in the latter case, the ProgressPrinter does not work?

hv10 commented 1 year ago

Yes, exactly that. :) When passing the data the first way, the ProgressPrinter works as expected. In the second case, it only prints the epoch, the phase and some dots (which is what I'd expect if the ProgressPrinter had been initialized with nothing, i.e. no data iterator, when the learner was constructed).

hv10 commented 1 year ago

This line doesn't seem to work as I'd expect: https://github.com/FluxML/FluxTraining.jl/blob/f53541f57c91d727387df1d32b8e7e60415a8da0/src/callbacks/callbacks.jl#L20 It appears to always return nothing when the data is passed to fit!() instead of to the Learner.
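To make the suspected failure mode concrete, here is a dependency-free toy sketch. It does not reproduce FluxTraining's code: ToyLearner, progresslength, runphases! and fit_with_data! are hypothetical names. It assumes, as the comment above suggests, that the progress callback looks the data iterator up on the learner and falls back to nothing, so data handed only to fit! is never seen. One possible fix, shown at the end, is for fit! to attach the data to the learner for the duration of the call.

```julia
# Hypothetical toy model of the issue; NOT FluxTraining's actual types or API.
mutable struct ToyLearner
    data::Dict{Symbol,Any}  # phase => data iterator; empty if none given at construction
end

# Assumed callback behavior: look the iterator up on the learner and
# fall back to `nothing` (no known length, so no progress bar).
function progresslength(learner::ToyLearner, phase::Symbol)
    it = get(learner.data, phase, nothing)
    return it === nothing ? nothing : length(it)
end

# Runs each phase over the given iterators and records what the
# progress callback would see at the start of that phase.
function runphases!(learner::ToyLearner, data)
    seen = Any[]
    for (phase, it) in zip((:training, :validation), data)
        push!(seen, progresslength(learner, phase))
        foreach(_ -> nothing, it)  # stand-in for iterating over batches
    end
    return seen
end

# Possible fix: attach the data to the learner for the duration of the
# call, then restore whatever was stored there before.
function fit_with_data!(learner::ToyLearner, data)
    old = copy(learner.data)
    learner.data[:training], learner.data[:validation] = data
    try
        return runphases!(learner, data)
    finally
        empty!(learner.data)
        merge!(learner.data, old)
    end
end

learner = ToyLearner(Dict{Symbol,Any}())
runphases!(learner, (1:10, 1:4))      # like passing data only to fit!: returns [nothing, nothing]
fit_with_data!(learner, (1:10, 1:4))  # data attached for the call: returns [10, 4]
```

Restoring the previous contents in a finally block keeps a Learner that was constructed with its own data unchanged after a fit! call that temporarily overrides it.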

lorenzoh commented 1 year ago

Will look into it and let you know. Thanks for reporting the issue 👍