mlr-org / mlr3torch

Deep learning framework for the mlr3 ecosystem based on torch
https://mlr3torch.mlr-org.com

unfreezing 🥶 weights callback #297

Open sebffischer opened 1 month ago

sebffischer commented 1 month ago

When fine-tuning a predefined image network on a downstream task, one often wants to freeze some weights for a given number of epochs/steps. As this is relatively common, we should offer a predefined callback ("cb.freeze") to enable this.

The callback should be able to iteratively unfreeze layers after a given number of epochs or batches.

Background:

Each torch module exposes its parameters as a named list:

net = torch::nn_linear(1, 1)

net$parameters
#> $weight
#> torch_tensor
#>  0.3468
#> [ CPUFloatType{1,1} ][ requires_grad = TRUE ]
#> 
#> $bias
#> torch_tensor
#>  0.6796
#> [ CPUFloatType{1} ][ requires_grad = TRUE ]

When we want to unfreeze a specific weight, we can refer to it via its name in this list. Further, we can freeze a parameter in a network by calling its in-place $requires_grad_() method with FALSE:

net$parameters[[1]]$requires_grad
#> [1] TRUE
# freeze the parameter in place
net$parameters[[1]]$requires_grad_(FALSE)
net$parameters[[1]]$requires_grad
#> [1] FALSE

We can unfreeze a parameter the same way:

# unfreeze again, in place
net$parameters[[1]]$requires_grad_(TRUE)
net$parameters[[1]]$requires_grad
#> [1] TRUE
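
Because the parameter names reflect the module hierarchy, freezing or unfreezing whole layers amounts to matching name prefixes. A minimal sketch with a small sequential network (the architecture is just for illustration):

net2 = torch::nn_sequential(
  torch::nn_linear(1, 8),
  torch::nn_relu(),
  torch::nn_linear(8, 1)
)

names(net2$parameters)
#> [1] "0.weight" "0.bias"   "2.weight" "2.bias"

# freeze everything except the output layer (name prefix "2.")
for (nm in names(net2$parameters)) {
  if (!startsWith(nm, "2.")) {
    net2$parameters[[nm]]$requires_grad_(FALSE)
  }
}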

The callback needs to define:

  1. when to unfreeze which layer (the "when" should be definable both in terms of epochs and batches). It should e.g. be possible to unfreeze layer8 after the first epoch, layer7 after the second, and the rest after the third epoch.
  2. which weights are frozen at the start, i.e. which weights are trainable from the start.

I can e.g. imagine this callback having parameters along the following lines:
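A minimal sketch of what that could look like, assuming a callback id "unfreeze" and select_*-style selector helpers, none of which exist yet:

cb = t_clbk("unfreeze",
  # weights that are trainable from the start; everything else starts frozen
  starting_weights = select_name(c("head.weight", "head.bias")),
  # schedule: which weights to unfreeze after which epoch
  unfreeze = data.table::data.table(
    weights = list(select_grep("^layer8\\."), select_grep("^layer7\\.")),
    epoch = c(1, 2)
  )
)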

cxzhang4 commented 6 days ago

We need to implement our own Selector-like class (SelectParamNames?) that allows us to select a subset of the parameters in a neural network. These will be higher-order functions (HOFs).
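
A minimal sketch of such higher-order selectors, reusing net2 from above (the constructor names are hypothetical): each constructor returns a function that maps the full character vector of parameter names to the selected subset.

# returns a function that selects parameter names matching a regex
select_grep = function(pattern) {
  function(param_names) grep(pattern, param_names, value = TRUE)
}

# returns a function that selects exactly the given names
select_name = function(names) {
  function(param_names) intersect(param_names, names)
}

selector = select_grep("^2\\.")
selector(names(net2$parameters))
#> [1] "2.weight" "2.bias"

Because these selectors operate on parameter names only, they could be stored in the callback's parameter set and applied once the network has been constructed.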