microsoft / mup

maximal update parametrization (µP)
https://arxiv.org/abs/2203.03466
MIT License

Are Sequentials with list comprehension handled incorrectly? #43

Open RobertBaruch opened 1 year ago

RobertBaruch commented 1 year ago

I'm asking because of the following model:

```python
import torch.nn as nn

# Block, all_symbols, N_CONTEXT and N_CATEGORIES are defined elsewhere in the
# notebook (not shown in this issue).

class TheModel(nn.Module):
    def __init__(self, n_token_embed, n_layers):
        super().__init__()
        n_heads = n_token_embed // 2  # the number of heads scales with width
        n_key_size = n_token_embed
        self.token_embedding_table = nn.Embedding(len(all_symbols), n_token_embed)
        self.position_embedding_table = nn.Embedding(N_CONTEXT, n_token_embed)
        self.blocks = nn.Sequential(*[Block(n_token_embed, n_key_size, n_heads) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(n_token_embed)  # final layer norm
        self.lm_head = nn.Linear(n_token_embed, N_CATEGORIES * 2)
```
`set_base_shapes` then fails with:

```text
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-21-944f2038ce18> in <cell line: 86>()
     84 delta_model = TheModel(4, N_LAYERS)
     85 model = TheModel(N_TOKEN_EMBED, N_LAYERS)
---> 86 set_base_shapes(model, base_model, delta=delta_model)
     87 
     88 m = model.to(device)

1 frames
/usr/local/lib/python3.9/dist-packages/mup/shape.py in _zip_infshape_dict(base_shapes, shapes)
     93     basenames = set(base_shapes.keys())
     94     names = set(shapes.keys())
---> 95     assert basenames == names, (
     96         f'`base_shapes` has extra names {basenames - names}. '
     97         f'`shapes` has extra names {names - basenames}.'

AssertionError: `base_shapes` has extra names set(). `shapes` has extra names {'blocks.0.sa.heads.1.value.weight', 'blocks.2.sa.heads.1.query.weight', 'blocks.2.sa.heads.1.value.weight', 'blocks.0.sa.heads.1.key.weight', 'blocks.1.sa.heads.1.value.weight', 'blocks.1.sa.heads.1.key.weight', 'blocks.2.sa.heads.1.key.weight', 'blocks.0.sa.heads.1.query.weight', 'blocks.1.sa.heads.1.query.weight'}.
```
thegregyang commented 1 year ago

Hi Robert,

Not sure I understand your problem exactly since I don’t see the rest of your code, but are you creating the base shapes for one depth value L1 but then applying it to a model with depth L2 != L1? This could cause something like the error you are seeing. The depth needs to be the same between them.
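The assertion in the traceback compares the *sets of parameter names* of the two models, so any submodule count that differs between `base_model`, `delta_model`, and `model` will trip it. Note that every extra name in the error is a `heads.1` parameter under `blocks.0` through `blocks.2`, and `n_heads = n_token_embed // 2` ties the head count to width; a width-dependent head count is therefore one reading consistent with that error, alongside the depth mismatch described above. The minimal sketch below reproduces both cases with plain PyTorch. It assumes `Block` holds one submodule per head; `Head`, `MultiHead`, `Block`, `ToyModel`, and `name_diff` are hypothetical stand-ins for the unshown code, and the base width of 2 is an assumption (the issue does not show how `base_model` was built).

```python
from typing import Optional

import torch.nn as nn


class Head(nn.Module):
    """Hypothetical attention head with separate key/query/value projections."""

    def __init__(self, n_embed: int, head_size: int):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)


class MultiHead(nn.Module):
    """Hypothetical multi-head attention: one registered submodule per head."""

    def __init__(self, n_embed: int, n_heads: int):
        super().__init__()
        # Heads are registered as heads.0, heads.1, ..., so the parameter
        # names themselves depend on n_heads.
        self.heads = nn.ModuleList([Head(n_embed, n_embed) for _ in range(n_heads)])


class Block(nn.Module):
    """Hypothetical transformer block holding only the attention part."""

    def __init__(self, n_embed: int, n_heads: int):
        super().__init__()
        self.sa = MultiHead(n_embed, n_heads)


class ToyModel(nn.Module):
    def __init__(self, n_embed: int, n_layers: int, n_heads: Optional[int] = None):
        super().__init__()
        if n_heads is None:
            n_heads = n_embed // 2  # the issue's rule: head count grows with width
        # Built like the issue's model, giving names blocks.0, blocks.1, ...
        self.blocks = nn.Sequential(*[Block(n_embed, n_heads) for _ in range(n_layers)])


def name_diff(a: nn.Module, b: nn.Module):
    """Return (parameter names only in a, parameter names only in b)."""
    names_a = {name for name, _ in a.named_parameters()}
    names_b = {name for name, _ in b.named_parameters()}
    return names_a - names_b, names_b - names_a


if __name__ == "__main__":
    # Width 2 -> 1 head, width 4 -> 2 heads (hypothetical base width).
    # Prints the 9 names blocks.{0,1,2}.sa.heads.1.{key,query,value}.weight.
    print(sorted(name_diff(ToyModel(2, 3), ToyModel(4, 3))[1]))
    # A depth mismatch instead adds whole blocks: all 6 names start with blocks.2.
    print(sorted(name_diff(ToyModel(4, 2), ToyModel(4, 3))[1]))
    # Holding n_heads fixed across widths keeps the names identical.
    print(name_diff(ToyModel(4, 3, n_heads=2), ToyModel(8, 3, n_heads=2)))  # (set(), set())
```

Running a check like `name_diff` on all three models before calling `set_base_shapes` surfaces the mismatch with a readable diff. Fixing the head count (and letting per-head size grow) is just one way to keep names stable; whether it gives the scaling you want under µP is worth checking against the mup documentation.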

From: Robert Baruch; Date: Monday, April 17, 2023 at 4:39 PM; To: microsoft/mup; Subject: [microsoft/mup] Are Sequentials with list comprehension handled incorrectly? (Issue #43)

ricomnl commented 1 year ago

Hi @thegregyang, I ran into the same issue, and it's also due to the depths not matching. In the paper, you showed that this method also works for scaling the depth of the model. How can this be achieved (specifically with the mutransformers submodule)?