Open: Sirsirious opened this issue 3 years ago
Yes, I'm also looking for the same thing: fine-tuning by freezing the initial layers.
You can freeze one or more layers simply by using tl.Fn. Suppose you have a model m0 = tl.Serial(m1, m2, m3) with three layers m1, m2, m3 and you wish to freeze m2 during training; this can be achieved like this:
from trax.shapes import signature
from trax import layers as tl
from jax import random
import numpy as np

## Wrap a layer in tl.Fn so its weights are hidden from the enclosing model
def FnLayer(m):
    return tl.Fn("FnLayer", lambda x: m(x))

m1 = tl.Dense(3)
m2 = tl.Dense(2)
m3 = tl.Dense(1)
## This is your model
m0 = tl.Serial(m1, m2, m3)
## Freezing weights of m2
m4 = tl.Serial(m1, FnLayer(m2), m3)
## xs is one batch of data; here just a toy batch of 8 examples with 5 features
xs = np.ones((8, 5), dtype=np.float32)
m0.init(signature(xs), rng=random.PRNGKey(0))
In the last step, we initialize m2 by initializing m0. This is important because calling m4.init() cannot get m2 initialized, since the Fn wrapper hides m2's weights from m4. Now you can hand m4 to the trax trainer, which will only update the weights of m1 and m3.
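For concreteness, here is a rough sketch of handing m4 to training.Loop; the toy data stream, loss, optimizer and step count below are made up for illustration, and the interaction between the Loop's own initialization and the pre-initialized sublayers may need extra care on your trax version.

import numpy as np
from trax import layers as tl, optimizers
from trax.supervised import training

def toy_stream():
    ## Hypothetical generator yielding (inputs, targets, per-example weights)
    while True:
        x = np.random.rand(8, 5).astype(np.float32)
        y = np.random.rand(8, 1).astype(np.float32)
        w = np.ones((8, 1), dtype=np.float32)
        yield (x, y, w)

train_task = training.TrainTask(
    labeled_data=toy_stream(),
    loss_layer=tl.L2Loss(),
    optimizer=optimizers.Adam(0.01),
)
## m4 is the model with m2 wrapped in FnLayer, built above
loop = training.Loop(m4, train_task)
loop.run(10)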
You can observe the difference in weights between m0 and m4 as follows:
from jax import tree_map
tree_map(lambda w: w.shape, m0.weights)
## The weights of m2 disappear from m4
tree_map(lambda w: w.shape, m4.weights)
Hi @boathit, I understand what you're proposing and it's a nice solution for self-made models. But that is of no help if I want to load a predefined model (such as a Reformer), load pretrained weights, and then freeze some of the layers... I'd expect a more sophisticated way to do it. Maybe, if the layers were indexed, we could tell the model to freeze the layers at index x, y and z (which is basically a way to tell those layers' weights not to be adjusted during backprop).
That is easy. You can decompose a loaded model into a collection of sub-models and reconstruct the model using the tl.Fn layer:
m0 = tl.Serial(
tl.Dense(3),
tl.Dense(2),
tl.Dense(1)
)
## This would be m0.init_from_file(...) in your case
_, _ = m0.init(signature(xs), rng=random.PRNGKey(0))
## Decomposing m0
m1 = m0.sublayers[0]
m2 = m0.sublayers[1]
m3 = m0.sublayers[2]
## Reconstructing m0 using tl.Fn
m4 = tl.Serial(m1, FnLayer(m2), m3)
When the model is complicated, like a Transformer, you would have to access nested sublayers as m0.sublayers[i].sublayers[j] to break it down into the parts you wish to freeze.
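To make the nested case concrete, here is a small sketch with a toy nested Serial standing in for something like a Transformer block (the real indices depend on the actual model); it reuses FnLayer, xs, signature and random from the snippets above.

## Toy nested model: the inner block plays the role of, say, a Transformer block
block = tl.Serial(tl.Dense(4), tl.Relu(), tl.Dense(4))
m0 = tl.Serial(tl.Dense(4), block, tl.Dense(2))
m0.init(signature(xs), rng=random.PRNGKey(0))

## Freeze only the first Dense inside the inner block
inner = m0.sublayers[1].sublayers[0]
rest = list(m0.sublayers[1].sublayers[1:])
frozen_block = tl.Serial(FnLayer(inner), *rest)

## Reassemble the full model around the partially frozen block
m4 = tl.Serial(m0.sublayers[0], frozen_block, m0.sublayers[2])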
boathit, thx a bunch. How would you save that m4 now? The only way I've found to save checkpoints is via the Trainer. However, the Trainer expects a non-initialized model and throws an error when you hand over the model function AFTER initializing from file.
@friesel Did you try the save_to_file function?
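For what it's worth, roughly what I have in mind is below; I haven't checked the exact keyword arguments of save_to_file / init_from_file on your release, so treat the calls as an assumption rather than verified API usage.

## Assumed usage: persist the rebuilt model, then restore it into a fresh copy
m4.save_to_file('m4_frozen.pkl.gz')
m5 = tl.Serial(m1, FnLayer(m2), m3)
m5.init_from_file('m4_frozen.pkl.gz')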
@boathit I checked into that, but to no avail. I am currently working on the last release (Dec 20) and there was no such function. Also, my main challenge was that the checkpoint files contain much more than mere model weights, and training.Loop and the Trainer class handle things differently as well. Long story short: I hacked my way into load_trainer_state() in trainer_lib.py. I create a new model with the Trainer(), initialize it randomly, save it to disk, and then, when reopening that checkpoint/model, I use load_trainer_state() to replace the weights of the layers that concern me with pretrained weights from other architectures (as long as the shape fits), i.e. the Embedding and the bottom Encoder layers of an LM pretrain. Works like a charm, even though it's certainly not very elegant. Thx for the sublayers hint. That was key.
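Roughly, the weight-surgery idea looks like this (a simplified sketch using the layer weights setter rather than my trainer_lib hack; the toy layers below are just placeholders for the real architectures):

## Stand-in for a model restored from a pretrained checkpoint
pretrained = tl.Serial(tl.Dense(3), tl.Dense(2), tl.Dense(1))
pretrained.init(signature(xs), rng=random.PRNGKey(0))  # would be init_from_file(...) for real

## Freshly built target model, randomly initialized
new_model = tl.Serial(tl.Dense(3), tl.Dense(2), tl.Dense(1))
new_model.init(signature(xs), rng=random.PRNGKey(1))

## Copy the pretrained weights into the first two sublayers (shapes must match),
## keeping the randomly initialized output layer
new_weights = list(new_model.weights)
new_weights[0] = pretrained.weights[0]
new_weights[1] = pretrained.weights[1]
new_model.weights = tuple(new_weights)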
I already tried the tl.Fn wrapping approach, but when the weights disappear, the forward method of m2 has no weights to use since it has been frozen; the code returns an error.
Description
After venturing for a while into trax for novel and some predefined models, I started looking into fine-tuning.
As far as I've seen, I can use a pretrained model as part of a bigger model (since models are nothing but stacked layers) and could somehow adapt it.
But I also understand that there are other ways of fine-tuning, such as freezing some layers and changing others (e.g. the input and output layers).
So far I haven't found a way to do it. Am I missing something, or is this a feature yet to be included?
Environment information