google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0

[Question][Feature Request] Support for fine tuning #1284

Status: Open. Sirsirious opened this issue 3 years ago.

Sirsirious commented 3 years ago

Description

After working with Trax for a while, on both custom models and some predefined ones, I started looking into fine-tuning.

As far as I can tell, I can use a pretrained model as part of a bigger model (since models are nothing but stacked layers) and adapt it that way.

But I also understand that there are other ways of fine-tuning, such as freezing some layers and changing others (e.g. the input and output layers).

So far I haven't found a way to do this. Am I missing something, or is this a feature yet to be added?

Environment information

OS: Google Colab (Linux)

$ pip freeze | grep trax
trax==1.3.6
bhuvan1643 commented 3 years ago

I'm also looking for the same thing: fine-tuning by freezing the initial layers.

boathit commented 3 years ago

You can freeze one or more layers simply by wrapping them in tl.Fn. Suppose you have a model m0 = tl.Serial(m1, m2, m3) with three layers m1, m2, m3 and want to freeze m2 during training. This can be done as follows:

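import numpy as np
from jax import random
from trax import layers as tl
from trax.shapes import signature

def FnLayer(m):
    ## Wrap an (already initialized) layer in a weightless tl.Fn layer:
    ## m's weights are captured by the closure, so training cannot update them
    return tl.Fn("FnLayer", lambda x: m(x))

m1 = tl.Dense(3)
m2 = tl.Dense(2)
m3 = tl.Dense(1)

## This is your model
m0 = tl.Serial(m1, m2, m3)
## Freezing weights of m2
m4 = tl.Serial(m1, FnLayer(m2), m3)
## xs is one batch of data, e.g. 8 examples with 4 features each
xs = np.ones((8, 4), dtype=np.float32)
m0.init(signature(xs), rng=random.PRNGKey(0))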

In the last step, m2 is initialized by initializing m0. This is important because calling m4.init() cannot initialize m2: inside the tl.Fn wrapper, m2 is no longer a sublayer. You can then pass m4 to the Trax trainer, which will update only the weights of m1 and m3.

You can compare the weights of m0 and m4 like this:

import jax

## Shapes of all weights in m0 (including m2's)
jax.tree_util.tree_map(lambda w: w.shape, m0.weights)
## The weights of m2 disappear from m4
jax.tree_util.tree_map(lambda w: w.shape, m4.weights)
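Why does wrapping m2 in tl.Fn freeze it? Below is a minimal pure-JAX sketch (not Trax code), assuming the trainer differentiates the loss only with respect to the model's own weights. Parameters captured in a closure are constants to jax.grad, so no gradient, and hence no update, is produced for them. The names init_dense, dense, frozen_m2 and the dict keys are hypothetical stand-ins for m1, m2 and m3.

```python
import jax
import jax.numpy as jnp


def init_dense(key, n_in, n_out):
    """Hypothetical stand-in for a Dense layer's weights: (kernel, bias)."""
    return (jax.random.normal(key, (n_in, n_out)) * 0.1, jnp.zeros((n_out,)))


def dense(params, x):
    w, b = params
    return x @ w + b


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
p1 = init_dense(k1, 4, 3)  # plays the role of m1 (trainable)
p2 = init_dense(k2, 3, 2)  # plays the role of m2 (frozen)
p3 = init_dense(k3, 2, 1)  # plays the role of m3 (trainable)


def frozen_m2(x):
    # p2 is captured by the closure instead of being passed as an argument,
    # so jax.grad treats it as a constant, much like FnLayer(m2).
    return dense(p2, x)


def loss(trainable, x, y):
    h = dense(trainable["m1"], x)
    h = frozen_m2(h)
    out = dense(trainable["m3"], h)
    return jnp.mean((out - y) ** 2)


x = jnp.ones((5, 4))
y = jnp.zeros((5, 1))
trainable = {"m1": p1, "m3": p3}
grads = jax.grad(loss)(trainable, x, y)
print(sorted(grads))  # ['m1', 'm3']: there is no gradient entry for m2 at all
```

The frozen layer's weights never enter the set of differentiated parameters, so there is nothing for an optimizer to update.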
Sirsirious commented 3 years ago

Hi @boathit, I understand what you're proposing, and it's a nice solution for self-made models. But it doesn't help if I want to load a predefined model (such as a Reformer), load pretrained weights, and then freeze some of its layers. I'd expect a more built-in way to do this. For example, if the layers were indexed, we could tell the model to freeze the layers at indices x, y and z (which basically means offering a way to tell those layers' weights not to be adjusted during backprop).
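The index-based freezing described above can be sketched outside Trax as a custom update step. This is a minimal sketch under assumptions: the model's weights are a list of per-layer pytrees (roughly how a tl.Serial's top-level weights are organized), and sgd_step_with_frozen is a hypothetical helper, not a Trax API.

```python
import jax
import jax.numpy as jnp


def sgd_step_with_frozen(layer_weights, layer_grads, frozen, lr=0.1):
    """Hypothetical helper: one plain SGD step over a list of per-layer
    weight pytrees, leaving every layer whose index is in `frozen` unchanged."""
    new_weights = []
    for i, (w, g) in enumerate(zip(layer_weights, layer_grads)):
        if i in frozen:
            new_weights.append(w)  # frozen layer: skip the update
        else:
            new_weights.append(
                jax.tree_util.tree_map(lambda p, dp: p - lr * dp, w, g))
    return new_weights


# Three "layers", each with a (kernel, bias) pair; freeze layer index 1.
weights = [(jnp.ones((2, 2)), jnp.zeros(2)) for _ in range(3)]
grads = [(jnp.ones((2, 2)), jnp.ones(2)) for _ in range(3)]
updated = sgd_step_with_frozen(weights, grads, frozen={1}, lr=0.5)
```

Unlike the tl.Fn trick, this still computes gradients for the frozen layers and simply discards them, which is wasteful but keeps the model structure untouched.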

boathit commented 3 years ago

That is easy. You can decompose a loaded model into a collection of sub-models and reconstruct the model using the tl.Fn layer.

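import numpy as np
from jax import random
from trax import layers as tl
from trax.shapes import signature

def FnLayer(m):
    ## Same helper as before: hides m's weights inside a weightless tl.Fn layer
    return tl.Fn("FnLayer", lambda x: m(x))

xs = np.ones((8, 4), dtype=np.float32)  ## one batch of data

m0 = tl.Serial(
    tl.Dense(3),
    tl.Dense(2),
    tl.Dense(1)
)

## This would be m0.init_from_file(...) in your case
_, _ = m0.init(signature(xs), rng=random.PRNGKey(0))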

## Decomposing m0 
m1 = m0.sublayers[0]
m2 = m0.sublayers[1]
m3 = m0.sublayers[2]

## Reconstructing m0 using tl.Fn
m4 = tl.Serial(m1, FnLayer(m2), m3)

When the model is more complex, like a Transformer, you have to access nested sublayers, e.g. m0.sublayers[i].sublayers[j], to break it down into the parts you wish to freeze.
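To find the right indices for m0.sublayers[i].sublayers[j], it can help to list the whole layer tree first. The sketch below assumes only that each layer exposes name and sublayers attributes, as Trax layers do; list_sublayers and get_by_path are hypothetical helpers, and _Stub is a stand-in so the example runs without Trax.

```python
def list_sublayers(layer, path=()):
    """Yield (index_path, name) for a layer and every nested sublayer."""
    yield path, layer.name
    for i, sub in enumerate(layer.sublayers or ()):
        yield from list_sublayers(sub, path + (i,))


def get_by_path(layer, path):
    """Follow an index path, e.g. (1, 0) -> layer.sublayers[1].sublayers[0]."""
    for i in path:
        layer = layer.sublayers[i]
    return layer


class _Stub:
    """Stand-in for a Trax layer, used only to make this demo runnable."""

    def __init__(self, name, sublayers=()):
        self.name = name
        self.sublayers = list(sublayers)


demo = _Stub("Serial", [
    _Stub("Dense_3"),
    _Stub("Serial", [_Stub("Dense_2"), _Stub("Relu")]),
])
paths = dict(list_sublayers(demo))
for path, name in paths.items():
    print(path, name)
```

With a real model, calling list_sublayers(m0) in the same loop would print every nested layer with its index path, and get_by_path(m0, path) then returns the layer to wrap with FnLayer.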

friesel commented 3 years ago

@boathit, thanks a lot. How would you save m4 now? The only way I can save checkpoints is through the Trainer. However, the Trainer expects an uninitialized model and throws an error when you pass it the model function after initializing it from a file.

boathit commented 3 years ago

@friesel Did you try the save_to_file function?

friesel commented 3 years ago

@boathit I looked into that, but to no avail. I'm working with the latest release (Dec 2020), which doesn't have that function. My main challenge was also that the checkpoint files contain much more than just the model weights, and training.Loop and the Trainer class handle them differently.

Long story short: I hacked into load_trainer_state() in trainer_lib.py. I create a new model with Trainer(), initialize it randomly, and save it to disk. Then, when reopening that checkpoint, I use load_trainer_state() to replace the weights of the layers I care about with pretrained weights from other architectures (as long as the shapes fit), e.g. the embedding and the bottom encoder layers of a pretrained LM. It works like a charm, even though it's certainly not very elegant. Thanks for the sublayers hint; that was key.
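The idea of copying pretrained weights into a freshly initialized model wherever the shapes fit can be sketched independently of the checkpoint machinery. This is a minimal sketch, assuming the two weight pytrees (e.g. the weights of one corresponding sublayer in each model) have the same tree structure; copy_if_shape_matches is a hypothetical helper, not part of Trax.

```python
import jax
import jax.numpy as jnp


def copy_if_shape_matches(target, source):
    """Hypothetical helper: for two weight pytrees with the same structure,
    take each leaf from `source` when its shape matches the corresponding
    leaf in `target`; otherwise keep the target leaf."""
    return jax.tree_util.tree_map(
        lambda t, s: s if t.shape == s.shape else t, target, source)


# e.g. an embedding table that matches, plus a projection whose size differs
target = (jnp.zeros((100, 16)), jnp.zeros((16, 8)))
source = (jnp.ones((100, 16)), jnp.ones((16, 4)))
merged = copy_if_shape_matches(target, source)
```

With Trax layers, this could be applied per sublayer, for example model.sublayers[0].weights = copy_if_shape_matches(model.sublayers[0].weights, pretrained.sublayers[0].weights), assuming both sublayers' weights have the same structure.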

aobaruwa commented 2 years ago

I already tried @boathit's decompose-and-rebuild approach, but once the weights disappear, m2's forward pass has no weights to use because it has been frozen, and the code raises an error.