microsoft / mup

maximal update parametrization (µP)
https://arxiv.org/abs/2203.03466
MIT License

Reproducing the validation accuracy vs learning rates curve on ResNet #67

Open liulei277 opened 6 months ago

liulei277 commented 6 months ago

Hello! We tried to reproduce the experiment in your paper (Figure 16: ResNet on CIFAR-10 at different widths, compared to a base network). We made some modifications to examples/ResNet/main.py:

for width_mult in [0.5, 1.0, 2.0]:
    for log2lr in np.linspace(-3, 0, 7):  # lr = 2**log2lr, swept from 2**-3 to 2**0
        net = getattr(resnet, args.arch)(wm=width_mult)
        ...
        if args.optimizer == 'musgd':
            optimizer = MuSGD(net.parameters(), lr=2**log2lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
        ...
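
For completeness, this is roughly how we wire each freshly built net to mup before constructing the optimizer (a hedged sketch based on our reading of the mup API; args.load_base_shapes is the path passed via --load_base_shapes, and resnet/args come from examples/ResNet/main.py):

# hedged sketch: attach the saved base shapes before building MuSGD
from mup import MuSGD, set_base_shapes

net = getattr(resnet, args.arch)(wm=width_mult)
set_base_shapes(net, args.load_base_shapes)   # e.g. 'resnet18.bsh'
optimizer = MuSGD(net.parameters(), lr=2**log2lr,
                  momentum=args.momentum,
                  weight_decay=args.weight_decay)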

And we ran the following command:

# mup
python main.py --load_base_shapes resnet18.bsh
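
For reference, we generated resnet18.bsh beforehand roughly as below (a hedged sketch; the base and delta width multipliers are our own choice and may not match the example script's defaults):

# hedged sketch of how resnet18.bsh was produced; wm values are our own choice
from mup import make_base_shapes
import resnet

base_net = resnet.resnet18(wm=1.0)    # base-width model
delta_net = resnet.resnet18(wm=2.0)   # wider copy, used to infer which dimensions scale with width
make_base_shapes(base_net, delta_net, savefile='resnet18.bsh')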

Then we got the following picture: [attached image: validation accuracy vs. learning rate curves for each width]

Is there anything wrong with our implementation? Thanks.

liulei277 commented 6 months ago

What's more, we ran the following command with the default examples/ResNet/main.py:

# mup
python main.py --load_base_shapes resnet18.bsh --lr 0.5 --width_mult 0.5

After running 10 epochs, the validation accuracy we obtain is 82.14%. This differs from the accuracy (92.78%) reported in Table 12 of your paper (ResNet on CIFAR-10).