microsoft / mup

maximal update parametrization (µP)
https://arxiv.org/abs/2203.03466
MIT License
1.24k stars · 88 forks

MuP Coord Check not Working with Electra Style Model #27

Closed zanussbaum closed 1 year ago

zanussbaum commented 1 year ago

I'm trying to use an ELECTRA-style model (https://github.com/lucidrains/electra-pytorch) with µP but am not able to get the coord plots to work correctly. Currently, I have Readout layers on both the Discriminator and the Generator.

Creating coord checks for the Discriminator and the Generator individually seems to work, but when they are combined, the µP plot does not look as expected.

Generator coord checks (Adam, lr=0.001, 5 seeds):
µP: https://user-images.githubusercontent.com/33707069/200189965-5985e986-4676-46fa-9d1a-79ced3e862b1.jpg
SP: https://user-images.githubusercontent.com/33707069/200189966-3de13deb-84be-42aa-aa6a-7c60dcec5158.jpg

Discriminator coord checks (Adam, lr=0.001, 5 seeds):
µP: https://user-images.githubusercontent.com/33707069/200189979-e6050c63-2dfb-4b51-965c-e23ce451e6bf.jpg
SP: https://user-images.githubusercontent.com/33707069/200189980-31967b38-f2dd-4545-9954-43552b7c9168.jpg

Electra model coord checks (Adam, lr=0.001, 5 seeds):
SP: https://user-images.githubusercontent.com/33707069/200190367-03c6f84a-b336-4fc9-8441-17b59d56eff4.jpg
µP: https://user-images.githubusercontent.com/33707069/200190369-7fb44d98-b0eb-4421-87e9-e175dbbe57cf.jpg

Will µP not work for "multi-task" losses like here where the overall loss is a weighted sum of mlm_loss and disc_loss?
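For context, the kind of multi-task objective in question combines the two losses into a single scalar before backpropagating. The sketch below is illustrative only: the placeholder loss values and the weight `lambda_disc` are assumptions (the λ = 50 weight follows the original ELECTRA paper's convention, not necessarily this model's code).

```python
import torch

# Minimal sketch of a weighted multi-task ELECTRA-style loss.
# Values and `lambda_disc` are placeholders, not this repo's code.
mlm_loss = torch.tensor(2.0)    # generator masked-LM loss
disc_loss = torch.tensor(0.1)   # discriminator replaced-token-detection loss
lambda_disc = 50.0              # the ELECTRA paper uses lambda = 50

total_loss = mlm_loss + lambda_disc * disc_loss
# A single total_loss.backward() then propagates gradients
# through both the generator and the discriminator at once.
```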

thegregyang commented 1 year ago

Can you clarify how you are combining the generator and discriminator to get the 3rd set of plots?


zanussbaum commented 1 year ago

Similar to this (https://github.com/lucidrains/electra-pytorch/blob/master/electra_pytorch/electra_pytorch.py#L190-L218), we use the logits from the Generator to sample/replace tokens for the input to the Discriminator. The Discriminator then tries to predict which tokens have been replaced.
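A rough sketch of that replace step (shapes and names here are assumed for illustration, not taken from the linked code): sample token ids from the generator's logits at the masked positions, splice them into the original input, and record which positions actually changed.

```python
import torch

torch.manual_seed(0)
batch, seq_len, vocab = 2, 8, 100
input_ids = torch.randint(0, vocab, (batch, seq_len))
mask = torch.zeros(batch, seq_len, dtype=torch.bool)
mask[:, :3] = True                                   # positions that were [MASK]ed

gen_logits = torch.randn(batch, seq_len, vocab)      # generator output (placeholder)
sampled = torch.distributions.Categorical(logits=gen_logits).sample()

disc_input = torch.where(mask, sampled, input_ids)   # splice samples into the input
is_replaced = mask & (sampled != input_ids)          # labels for the discriminator
```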

thegregyang commented 1 year ago

Can you tell me which layers in the 3rd set of plots are seeing exploding values (already at initialization)?


zanussbaum commented 1 year ago

It looks like it's the attention layers, such as discriminator.electra.encoder.layer.0.attention.output.dense, discriminator.electra.encoder.layer.0.attention.output.dropout, discriminator.electra.encoder.layer.0.attention.self.key, discriminator.electra.encoder.layer.0.attention.self.query, and discriminator.electra.encoder.layer.0.attention.self.value. The same pattern is present in the generator layers too. However, it seems a bit odd to me that the individual layers blow up while the full layer's output stays constant.

https://docs.google.com/spreadsheets/d/1vd_cVkNAbr0jSLax_IrH4sjjIcanxfOXar6cbYVS3DE/edit?usp=sharing
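For reference, the per-module numbers a coord check records can be reproduced with forward hooks. This is a toy sketch of the idea only (not mup's actual implementation): log the mean absolute output activation of every submodule, which is the quantity the coord-check plots track across widths and steps.

```python
import torch
import torch.nn as nn

# Toy sketch: record the average |activation| per submodule with forward hooks.
records = {}

def make_hook(name):
    def hook(module, inputs, output):
        records[name] = output.abs().mean().item()
    return hook

net = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 4))
for name, module in net.named_modules():
    if name:  # skip the top-level container itself
        module.register_forward_hook(make_hook(name))

net(torch.randn(16, 8))
# `records` now maps each submodule name to its mean absolute output
```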

thegregyang commented 1 year ago

That's strange, because at least the generator should behave the same at initialization whether you combine it with the discriminator or not, since the computation it does is exactly the same. Can you clarify what data is fed in when you combine them and when you don't?


zanussbaum commented 1 year ago

For inputs, the data fed into the model is roughly the same. The main difference I see is which tokens are masked, but I don't imagine that having a large impact on what's fed into the generator. The only difference from before is that masking is now handled within the Electra class instead of within the DataCollator. I'll do some more debugging, because it would make sense that the generator layers should have the same L1 norms.

Another difference is that instead of backpropagating directly with respect to the generator loss, we backprop with respect to the weighted sum of the generator and discriminator losses, but again I agree that this shouldn't affect the coord checks for the generator, at least at t == 1.

zanussbaum commented 1 year ago

@thegregyang sorry for the previous brief comment, I updated with some (hopefully) more useful info

zanussbaum commented 1 year ago

Ok, I found the bug in my code! I was calling _init_weights before set_base_shapes, and came across a comment saying that it needs to be called after set_base_shapes. Thanks for the help pointing me in the right direction 😄 @thegregyang