microsoft / mup

maximal update parametrization (µP)
https://arxiv.org/abs/2203.03466
MIT License

Is it possible to also scale the depth of the model? #54

Open ricomnl opened 1 year ago

ricomnl commented 1 year ago

In the paper and blog post you provide examples of scaling the number of layers / depth of the model, but when I try doing this in the coordinate check function from the mutransformers package, I get the following error. Is there currently a way to make this work?

AssertionError: `base_shapes` has extra names set(). `shapes` has extra names 
{'bert.encoder.layer.4.output.dense.bias', 'bert.encoder.layer.6.attention.output.dense.bias', 
'bert.encoder.layer.7.attention.self.value.bias', 'bert.encoder.layer.4.attention.ln.bias', 
'bert.encoder.layer.5.attention.self.key.weight', 'bert.encoder.layer.7.output.dense.bias', 
'bert.encoder.layer.7.intermediate.dense.bias', 'bert.encoder.layer.6.attention.self.value.weight', 
'bert.encoder.layer.7.ln.weight', 'bert.encoder.layer.4.attention.self.key.bias', 
'bert.encoder.layer.7.attention.self.key.weight', 'bert.encoder.layer.5.output.dense.bias', 
'bert.encoder.layer.5.attention.ln.weight', 'bert.encoder.layer.7.attention.self.query.weight', 
'bert.encoder.layer.5.attention.output.dense.bias', 'bert.encoder.layer.7.attention.output.dense.weight', 
'bert.encoder.layer.5.intermediate.dense.weight', 'bert.encoder.layer.5.attention.self.query.weight', 
'bert.encoder.layer.4.attention.self.value.weight', 'bert.encoder.layer.7.ln.bias', 
'bert.encoder.layer.7.attention.self.value.weight', 'bert.encoder.layer.6.attention.self.key.weight', 
'bert.encoder.layer.4.attention.self.query.weight', 'bert.encoder.layer.6.attention.self.key.bias', 
'bert.encoder.layer.6.output.dense.weight', 'bert.encoder.layer.6.attention.ln.bias', 
'bert.encoder.layer.4.attention.ln.weight', 'bert.encoder.layer.7.attention.self.key.bias', 
'bert.encoder.layer.7.attention.output.dense.bias', 'bert.encoder.layer.6.ln.bias', 
'bert.encoder.layer.7.output.dense.weight', 'bert.encoder.layer.7.attention.self.query.bias', 
'bert.encoder.layer.4.intermediate.dense.weight', 'bert.encoder.layer.6.attention.self.value.bias', 
'bert.encoder.layer.7.intermediate.dense.weight', 'bert.encoder.layer.5.attention.output.dense.weight', 
'bert.encoder.layer.5.ln.weight', 'bert.encoder.layer.5.intermediate.dense.bias', 
'bert.encoder.layer.6.intermediate.dense.weight', 'bert.encoder.layer.4.ln.weight', 
'bert.encoder.layer.4.attention.self.query.bias', 'bert.encoder.layer.5.attention.self.query.bias', 
'bert.encoder.layer.5.attention.self.key.bias', 'bert.encoder.layer.4.output.dense.weight', 
'bert.encoder.layer.4.ln.bias', 'bert.encoder.layer.6.attention.ln.weight', 
'bert.encoder.layer.4.attention.self.key.weight', 'bert.encoder.layer.7.attention.ln.bias', 
'bert.encoder.layer.5.output.dense.weight', 'bert.encoder.layer.4.attention.self.value.bias', 
'bert.encoder.layer.6.attention.self.query.weight', 'bert.encoder.layer.5.attention.ln.bias', 
'bert.encoder.layer.5.attention.self.value.weight', 'bert.encoder.layer.6.ln.weight', 
'bert.encoder.layer.7.attention.ln.weight', 'bert.encoder.layer.4.attention.output.dense.weight', 
'bert.encoder.layer.6.attention.output.dense.weight', 'bert.encoder.layer.5.ln.bias', 
'bert.encoder.layer.6.attention.self.query.bias', 'bert.encoder.layer.4.intermediate.dense.bias', 
'bert.encoder.layer.6.intermediate.dense.bias', 'bert.encoder.layer.4.attention.output.dense.bias', 
'bert.encoder.layer.5.attention.self.value.bias', 'bert.encoder.layer.6.output.dense.bias'}.
ricomnl commented 1 year ago

@thegregyang @edwardjhu I saw in the paper in section H, the following:

Of course, in both scenarios, depth, batch size, and sequence lengths can be scaled up and down as well according to Fig. 19 (though note that currently we require users to recreate the base model shape at new depths, since the number of parameters now change with depth).

I'm interpreting this as:

from transformers import BertConfig, BertForMaskedLM
from mup import make_base_shapes, set_base_shapes

# 1. create a base model at the lowest depth and width I want to use for training the actual model
# (num_attention_heads is set explicitly because hidden_size must be divisible by it)
base_config = BertConfig(
    hidden_size=128,
    num_attention_heads=4,
    num_hidden_layers=4,
)
delta_config = BertConfig(
    hidden_size=256,
    num_attention_heads=4,
    num_hidden_layers=4,
)
base_model = BertForMaskedLM(config=base_config)
delta_model = BertForMaskedLM(config=delta_config)

base_shapes = make_base_shapes(base_model, delta_model)

# 2. Initialize the (narrower) model I want to run experiments with
config = BertConfig(
    hidden_size=64,
    num_attention_heads=4,
    num_hidden_layers=4,
)
model = set_base_shapes(BertForMaskedLM(config=config), base_shapes)

# hyperparameter sweep ...

# 3. For width, simply scale up and get better results
config = BertConfig(
    hidden_size=512,
    num_attention_heads=4,
    num_hidden_layers=4,
)
model = set_base_shapes(BertForMaskedLM(config=config), base_shapes)

# 4. For depth, we'd get the above error if we used a different number of layers with the
# old base_shapes, so instead regenerate the base shapes first with the same settings as
# in 1. but with e.g. `num_hidden_layers=12`. Then initialize the model and get better
# results at higher depth (as shown empirically in the paper, section 7.3, for pre-layernorm BERT)
config = BertConfig(
    hidden_size=512,
    num_attention_heads=4,
    num_hidden_layers=12,
)
model = BertForMaskedLM(config=config)  # then set_base_shapes with the regenerated shapes

Am I missing anything?

edwardjhu commented 10 months ago

The package doesn't handle changes in depth because mup doesn't take depth as a parameter. If the base shape file contains more or fewer layers than the model you are applying it to, you get an error.
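To see why the error takes the form of "extra names": base shapes are recorded per named parameter, so a deeper model simply has parameter names the base-shapes file has never seen. A minimal stand-alone sketch of that comparison (toy parameter names, not mup's actual internals):

```python
# Sketch of the bookkeeping behind the AssertionError: base shapes are keyed
# by parameter name, so extra layers mean extra names in the set difference.

def layer_param_names(num_layers):
    """Parameter names for a toy encoder of the given depth."""
    return {
        f"encoder.layer.{i}.{p}"
        for i in range(num_layers)
        for p in ("attention.self.query.weight", "output.dense.bias")
    }

base_names = layer_param_names(4)   # names stored in the base-shapes file
model_names = layer_param_names(8)  # names of the model being checked

extra_in_model = model_names - base_names  # layers 4..7 only exist here
extra_in_base = base_names - model_names   # empty set

assert extra_in_base == set()
print(sorted(extra_in_model)[:2])
# → ['encoder.layer.4.attention.self.query.weight', 'encoder.layer.4.output.dense.bias']
```

This mirrors the error above, where every extra name belongs to `bert.encoder.layer.4` through `layer.7`.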

My suggestion is to use a shallow and narrow model for HP search. Then, make the model as deep as the target one, save the base shapes, and then apply them to the target model. You should be able to use the HPs found directly on the target model this way.

ricomnl commented 10 months ago

Ah yes, I tried that and it somewhat worked, but the result was the below: 4 layers < 8 < 32 < 16

[Screenshot 2023-08-19 at 15 35 11]
Pixelatory commented 5 months ago

Interesting that the 32-layer model is worse than the 16-layer one in these experiments.

yzlnew commented 4 months ago

You should check out Tensor Programs VI.