FluxML / FluxTraining.jl

A flexible neural net training library inspired by fast.ai
https://fluxml.ai/FluxTraining.jl
MIT License
117 stars · 25 forks

Add metadata field to `Learner` #121

Open darsnack opened 2 years ago

darsnack commented 2 years ago

This adds a "metadata" PropDict to Learner for storing information that is required for training but extraneous to the training state or callback state. It is useful for unconventional training methods (related to an issue I am currently dealing with). Just as the loss function is a "parameter" that must be specified for standard supervised training, the metadata field holds parameters that must be specified for unconventional training. Unlike standard training, we can't know in advance what these parameters will be, so instead of explicit named fields, we provide a generic container to hold them.
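As a rough sketch of how such a field might be used (the field name `update_rate` and the exact `Learner` constructor arguments are illustrative assumptions, not part of this PR):

```julia
# Illustrative sketch only: store a custom training hyper-parameter in the
# proposed `metadata` PropDict, then read it back inside a custom step.
learner = Learner(model, lossfn)          # constructor args abbreviated
learner.metadata.update_rate = 0.1        # name chosen purely for illustration

# Inside a custom training step one could then read:
# rate = learner.metadata.update_rate
```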

lorenzoh commented 2 years ago

Looks good to me 👍 Out of curiosity, what kind of training do you want to use this for? This is meant for state that belongs to the training loop and not to any callback, right?

Can you add a short CHANGELOG entry?

darsnack commented 2 years ago

Not state, but a hyper-parameter that belongs to the training loop and not to a callback. I basically have a setup where each batch is a time series, and a loss and pseudo-gradient are computed at each time step. These parameters control how my method updates the weights within this loop.

Normally, hyper-parameters are either part of the loss or optimizer and can be either statically closed over or scheduled. In this case, the hyper-parameter belongs to neither the loss nor the optimizer but the actual training step code.

lorenzoh commented 2 years ago

Are you using a custom training step? Then it's also possible to add a field to the Phase.

Well, this will be useful anyway.
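For context, the Phase-based alternative suggested above might look something like this (the struct name and field are hypothetical; `FluxTraining.Phases.AbstractTrainingPhase` is assumed here to be the abstract training-phase type that step methods dispatch on):

```julia
# Illustrative only: a custom phase carrying a loop-level hyper-parameter,
# so that step methods dispatched on this phase can read it directly.
struct TimeSeriesTrainingPhase <: FluxTraining.Phases.AbstractTrainingPhase
    update_rate::Float64   # hypothetical hyper-parameter
end

# A custom step dispatched on `TimeSeriesTrainingPhase` could then access
# `phase.update_rate`, with no extra Learner-level storage required.
```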

darsnack commented 2 years ago

Ah I didn't think of that. At least for my case, adding a field to the phase will be much more intuitive, so I probably won't use this feature. I'll leave it up to you if you think it is still worth adding.

lorenzoh commented 2 years ago

Hm. I think I'll leave this unmerged until someone comes up with a use case where adding a field to the phase doesn't work. Where possible, that should be the preferred approach.

darsnack commented 2 years ago

I found a potential use case for this: anything stored in the phase struct can't be scheduled as a hyper-parameter. Either the hyper-parameter API should be extended to include the phase, or the learner will need to store this information.

lorenzoh commented 2 years ago

I'd prefer passing Phase information to hyperparameters as you suggest, by turning the signature from

```julia
sethyperparameter!(learner, ::Type{<:HyperParameter}, value)
```

into

```julia
sethyperparameter!(learner, ::Type{<:HyperParameter}, ::Phase, value)
```

with a default method to make it non-breaking:

```julia
sethyperparameter!(learner, T::Type{<:HyperParameter}, ::Phase, value) =
    sethyperparameter!(learner, T, value)
```
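With that fallback in place, a phase-specific override could then be defined per hyper-parameter. A hedged sketch (the hyper-parameter type `MyHP` and the abstract phase type used for dispatch are hypothetical illustrations, not part of this PR):

```julia
# Illustrative only: schedule `MyHP` differently during training phases,
# while the default method above keeps every other case working unchanged.
function sethyperparameter!(learner, ::Type{MyHP},
                            ::FluxTraining.Phases.AbstractTrainingPhase, value)
    # apply `value` in a training-phase-specific way here
end
```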

The only other thing that would need to be changed is this line to add the phase to the call:

https://github.com/FluxML/FluxTraining.jl/blob/e9f8d938df447fff37e900c07a31d9f958f4853c/src/callbacks/scheduler.jl#L56