microsoft / mup

maximal update parametrization (µP)
https://arxiv.org/abs/2203.03466
MIT License

Reproducing the training loss vs learning rates curve on MLP #52

Closed jhj0411jhj closed 12 months ago

jhj0411jhj commented 1 year ago

Hello! We tried to reproduce the experiment in your paper (Figure 3: an MLP with different hidden sizes trained for 20 epochs on CIFAR-10 using SGD). We made some modifications to examples/MLP/main.py:

nonlin = torch.relu
criterion = F.cross_entropy
method = 'mup' if args.load_base_shapes else 'sp'
for width in [64, 512, 4096]:
    for lr in np.logspace(-14, -2, 13, base=2):
        ...

And we ran the following commands:

# sp
python main.py

# mup
python main.py --load_base_shapes width64.bsh 
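# (assuming width64.bsh was generated beforehand, e.g. via python main.py --save_base_shapes width64.bsh)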

However, we didn't observe the shift of the best LR across widths in SP. There doesn't seem to be much difference between SP and MuP. Is there anything wrong with our implementation? Thanks.

(Plots attached: SP, MuP. Attachments: logs_sp.json, logs_mup.json)
jhj0411jhj commented 1 year ago

Reproduced SP using our own MLP code (without mup at all): (plot attached: SP)

jhj0411jhj commented 1 year ago

I noticed in #41 that lines 139 to 141 do manually what mup.init does. With the default args, input_mult = output_mult = init_std = 1.0, and I didn't change these values for either SP or MuP. Does this affect the SP result?

edwardjhu commented 1 year ago

It's important to make sure that the HPs for the narrowest model are optimal.

Can you try setting the HPs for both SP and muP according to the comment in main.py?

We provide below some optimal hyperparameters for different activation/loss function combos:

if nonlin == torch.relu and criterion == F.cross_entropy:
    args.input_mult = 0.00390625
    args.output_mult = 32
elif nonlin == torch.tanh and criterion == F.cross_entropy:
    args.input_mult = 0.125
    args.output_mult = 32
elif nonlin == torch.relu and criterion == MSE_label:
    args.input_mult = 0.03125
    args.output_mult = 32
elif nonlin == torch.tanh and criterion == MSE_label:
    args.input_mult = 8
    args.output_mult = 0.125
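
For context, here is a rough sketch of what these two multipliers control, written against mup's MuReadout API. This is an illustration only, not a copy of examples/MLP/main.py; the exact placement, scaling, and initialization details in that file may differ.

import torch.nn as nn
import torch.nn.functional as F
from mup import MuReadout

class MLP(nn.Module):
    def __init__(self, width=128, num_classes=10, nonlin=F.relu,
                 input_mult=1.0, output_mult=1.0):
        super().__init__()
        self.nonlin = nonlin
        self.input_mult = input_mult
        self.fc_1 = nn.Linear(3072, width, bias=False)  # CIFAR-10 images flattened: 3*32*32
        self.fc_2 = nn.Linear(width, width, bias=False)
        # MuReadout applies the width-dependent muP scaling to the output layer;
        # output_mult is an extra width-independent constant on top of that.
        self.fc_3 = MuReadout(width, num_classes, bias=False, output_mult=output_mult)

    def forward(self, x):
        out = self.nonlin(self.fc_1(x) * self.input_mult)  # input_mult scales the input layer's contribution
        out = self.nonlin(self.fc_2(out))
        return self.fc_3(out)

Both multipliers are width-independent constants, which is why making sure they are optimal for the narrowest model, as suggested above, is what makes the cross-width comparison meaningful.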

jhj0411jhj commented 12 months ago

@edwardjhu Thanks. I reproduced the result with SGD after setting input_mult/output_mult. (plot attached)

ordabayevy commented 6 months ago

Hi @edwardjhu! My understanding is that the MLP in Fig. 3 / Eq. 2 has biases for the input and hidden layers and also doesn't use zero readout initialization, which differs from examples/MLP/main.py, where there are no biases at all and the readout layer weights are initialized to zero. Could you please provide HP values (input_mult/output_mult) for the Fig. 3 model with biases? I have reimplemented mup such that c=0 always for all optimizers, so that I can use regular torch optimizers, and I want to verify my implementation against yours (I can reproduce the curves for the MLP without biases from the demo notebook).

Thank you!
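
(For concreteness, here is a minimal sketch of a biased MLP wired through mup's standard API, i.e. set_base_shapes plus MuSGD rather than a c=0 reimplementation. The widths, learning rate, and flags below are placeholders, and whether this matches the exact Fig. 3 parametrization is exactly the open question above.)

import torch.nn as nn
from mup import MuReadout, MuSGD, set_base_shapes

def make_mlp(width):
    # Biased input/hidden layers; readout_zero_init=False asks for a non-zero readout init.
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(3072, width, bias=True),
        nn.ReLU(),
        nn.Linear(width, width, bias=True),
        nn.ReLU(),
        MuReadout(width, 10, bias=True, readout_zero_init=False),
    )

model = make_mlp(width=4096)
# Base and delta models at two small widths tell mup how each dimension scales with width.
set_base_shapes(model, make_mlp(64), delta=make_mlp(128))
optimizer = MuSGD(model.parameters(), lr=0.1)  # placeholder lr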