FluxML / FluxTraining.jl

A flexible neural net training library inspired by fast.ai
https://fluxml.ai/FluxTraining.jl
MIT License

Support for Transfer-Learning/Layer-Freezing #150

Open JoshuaBillson opened 1 year ago

JoshuaBillson commented 1 year ago

Motivation and description

A common practice in machine learning is to take a pre-trained model and fine-tune it on a particular dataset. This typically involves freezing the weights in some layers while fitting the output layer(s) on the new data.

Unfortunately, this functionality appears to be incompatible with the current implementation of the ToDevice callback, as the following code shows:

function on(::EpochBegin, ::Phase, cb::ToDevice, learner)
    model!(learner, cb.movemodelfn(learner.model))
end

function model!(learner, model)
    learner.model = model
    learner.params = setupoptimstate(model, learner.optimizer)
end

setupoptimstate(model, ::Flux.Optimise.AbstractOptimiser) = Flux.params(model)

setupoptimstate(model, optim) = Optimisers.setup(optim, model)

This means that learner.params is rebuilt from the full model at the start of every epoch. Thus, even if we freeze layers manually with Flux.freeze!(learner.params.layers[1:end-1]), ToDevice undoes this at the next EpochBegin: model! calls setupoptimstate, which returns fresh optimiser state in which nothing is frozen.
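To see why the freeze does not survive, here is a minimal sketch that reproduces the mechanism with Optimisers.jl alone, outside FluxTraining. It assumes an Optimisers.jl version that provides freeze!; the toy NamedTuple model and the names rule and grad are illustrative only. The second call to Optimisers.setup stands in for what setupoptimstate does at each EpochBegin.

```julia
using Optimisers

# Toy model: a NamedTuple of two "layers", each holding one weight vector.
model = (layer1 = (weight = [1.0, 2.0],), layer2 = (weight = [3.0, 4.0],))
# Gradient with the same structure as the model (all ones, for simplicity).
grad = (layer1 = (weight = [1.0, 1.0],), layer2 = (weight = [1.0, 1.0],))
rule = Optimisers.Descent(0.1)

# Epoch 1: fresh optimiser state, then freeze everything except layer2.
state = Optimisers.setup(rule, model)
Optimisers.freeze!(state.layer1)
state, model = Optimisers.update!(state, model, grad)  # only layer2 moves

# Epoch 2: rebuilding the state (as setupoptimstate does) yields unfrozen
# leaves, so layer1 is trained again even though we never called thaw!.
state = Optimisers.setup(rule, model)
state, model = Optimisers.update!(state, model, grad)  # both layers move

@show model.layer1.weight
@show model.layer2.weight
```

After both steps, layer1 has taken one gradient step (the one after the state was rebuilt) while layer2 has taken two.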

Possible Implementation

One solution that would work with Flux's new explicit optimizers is a callback that re-freezes the layers after ToDevice has run. For example:

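using Flux, FluxTraining

# Freezes part of the optimiser state at the start of every training epoch.
# `accessor` maps the optimiser state (learner.params) to the subtree to
# freeze, e.g. st -> st.layers[1:end-1] for a Chain.
mutable struct LayerFreezing{F} <: FluxTraining.Callback
    accessor::F
end

# Declare that this callback modifies learner.params.
function FluxTraining.stateaccess(::LayerFreezing)
    return (;params = FluxTraining.Write())
end

function FluxTraining.on(
    event::FluxTraining.EpochBegin,
    phase::FluxTraining.AbstractTrainingPhase,
    freezer::LayerFreezing,
    learner)
    # Re-apply the freeze, since ToDevice rebuilds learner.params each epoch.
    Flux.freeze!(freezer.accessor(learner.params))
end

# Run after ToDevice so the freeze is applied to the freshly built state.
FluxTraining.runafter(::LayerFreezing) = (FluxTraining.ToDevice,)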
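The accessor passed to LayerFreezing receives learner.params, which for an Optimisers.jl rule is the state tree returned by Optimisers.setup. The sketch below runs the body of the proposed on method by hand on a plain Flux Chain, outside FluxTraining, to show what st -> st.layers[1:end-1] selects. The helper name freeze_all_but_last and the two-layer model are hypothetical, and the sketch assumes Flux 0.13 or later (for the Dense(in => out) syntax) plus an Optimisers.jl version that provides freeze!.

```julia
using Flux, Optimisers

# Hypothetical accessor: selects the optimiser state of every layer but the last.
freeze_all_but_last(st) = st.layers[1:end-1]

# Identity activations, so the last layer's gradient is nonzero in practice.
model = Chain(Dense(2 => 3), Dense(3 => 1))
state = Optimisers.setup(Optimisers.Descent(0.1), model)

# What LayerFreezing's `on` method would do to learner.params.
Optimisers.freeze!(freeze_all_but_last(state))

# Record the weights before one training step.
w1 = copy(model[1].weight)
w2 = copy(model[2].weight)

x = rand(Float32, 2, 4)
grad = Flux.gradient(m -> sum(abs2, m(x)), model)[1]
state, model = Optimisers.update!(state, model, grad)

println("first layer changed:  ", model[1].weight != w1)
println("second layer changed: ", model[2].weight != w2)
```

The first layer's weights stay exactly as recorded, while the last layer takes a gradient step. Inside FluxTraining, the same function would presumably be passed as LayerFreezing(freeze_all_but_last). This only applies when the learner uses an Optimisers.jl rule: with a legacy Flux.Optimise optimiser, setupoptimstate returns Flux.params(model), which has no layers field.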

However, perhaps we should consider whether it's necessary for ToDevice to move the model to the GPU at the start of every epoch. Maybe we could extend the Callback interface to allow for some one-time setup code to run before the first epoch is executed?

JoshuaBillson commented 1 year ago

I think this issue may also be related to #148: the memory leak reported there appears to be caused by ToDevice resetting the optimizer state in every epoch. Changing this behaviour could fix both problems at once.

vargonis commented 4 months ago

Any update on this? Also, I'd really appreciate it if the proposed implementation above were turned into a complete example to build upon (for users like me who know nothing about the internals of FluxTraining.jl). Thanks!