microsoft / mup

maximal update parametrization (µP)
https://arxiv.org/abs/2203.03466
MIT License

Issue in reproducing the training loss vs learning rates curve #30

Closed. NicolasWinckler closed this issue 1 year ago.

NicolasWinckler commented 1 year ago

Hi, First of all, thanks for sharing your work.

We tried to reproduce the expected behavior of muP using ResNet18 on CIFAR10, as provided in the main script of your repository. The idea was to launch a training run for every combination of learning rate and width_mult and record the minimum loss each time, as you did in your paper, to check that the best learning rate does not shift when width_mult changes.

We slightly modified the main.py script to skip saving/loading the base shape file, as follows:

'''Train CIFAR10 with PyTorch.'''
import argparse
import os
from time import gmtime, strftime
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from mup import MuAdam, MuSGD, get_shapes, set_base_shapes
from copy import deepcopy
from mup.infshape import InfShape
from mup.shape import clear_dims, get_infshapes, zip_infshapes
from torch.utils.tensorboard import SummaryWriter
import resnet

# Training
def train(epoch, net, writer):
#    from utils import progress_bar
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    writer.add_scalar("Train/Loss", train_loss/(batch_idx+1), epoch)
    writer.add_scalar("Train/Acc", 100.*correct/total, epoch)

    return train_loss/len(trainloader)

def test(epoch, net, writer):
#    from utils import progress_bar
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    writer.add_scalar("Test/Loss", test_loss, epoch)
    writer.add_scalar("Test/Acc", 100.*correct/total , epoch)

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc

    return test_loss/len(testloader), best_acc

# Custom helper: build the base InfShapes in memory, skipping the save/load of a .bsh file
def get_base_shapes(base_shapes, delta_shapes):
    model_or_shapes = clear_dims(zip_infshapes(base_shapes, delta_shapes))
    if isinstance(model_or_shapes, nn.Module):
        sh = get_infshapes(model_or_shapes)
    elif isinstance(model_or_shapes, dict):
        sh = deepcopy(model_or_shapes)
    else:
        raise ValueError()
    sh = {k: s.base_shape() for k, s in sh.items()}
    return {k: InfShape.from_base_shape(v) for k, v in sh.items()}

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description=''
    '''
    PyTorch CIFAR10 Training, with μP.
    To save base shapes info, run e.g.
        python main.py --save_base_shapes resnet18.bsh --width_mult 1
    To train using MuAdam (or MuSGD), run
        python main.py --width_mult 2 --load_base_shapes resnet18.bsh --optimizer {muadam,musgd}
    To test coords, run
        python main.py --load_base_shapes resnet18.bsh --optimizer sgd --lr 0.1 --coord_check
        python main.py --load_base_shapes resnet18.bsh --optimizer adam --lr 0.001 --coord_check
    If you don't specify a base shape file, then you are using standard parametrization, e.g.
        python main.py --width_mult 2 --optimizer {muadam,musgd}
    Here muadam (resp. musgd) would have the same result as adam (resp. sgd).
    Note that models of different depths need separate `.bsh` files.
    ''', formatter_class=argparse.RawTextHelpFormatter)

    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
    parser.add_argument('--arch', type=str, default='resnet18')
    parser.add_argument('--optimizer', default='musgd', choices=['sgd', 'adam', 'musgd', 'muadam'])
    parser.add_argument('--epochs', type=int, default=150)
    parser.add_argument('--width_mult', type=float, default=1)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--test_batch_size', type=int, default=128)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--test_num_workers', type=int, default=2)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--seed', type=int, default=1111, help='random seed')

    args = parser.parse_args()

    root_dir = "/out/"

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    # Set the random seed manually for reproducibility.
    torch.manual_seed(args.seed)

    print('==> Preparing data..')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(
        root='../dataset', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    testset = torchvision.datasets.CIFAR10(
        root='../dataset', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=args.test_batch_size, shuffle=False, num_workers=args.test_num_workers)

    classes = ('plane', 'car', 'bird', 'cat', 'deer',
            'dog', 'frog', 'horse', 'ship', 'truck')

    # Model
    print('==> Building model..')
    net = getattr(resnet, args.arch)(wm=args.width_mult)
    net = net.to(device)
    if args.optimizer in ["musgd","muadam"]:
        print(f'using muP Parametrization')
        base_shapes = get_shapes(net)
        delta_shapes = get_shapes(getattr(resnet, args.arch)(wm=args.width_mult/2))
        dict_infshape = get_base_shapes(base_shapes, delta_shapes)

        set_base_shapes(net, dict_infshape)
    else:
        print(f'using Standard Parametrization')
        set_base_shapes(net, None)

    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('./checkpoint/ckpt.pth')
        net.load_state_dict(checkpoint['net'])
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']

    criterion = nn.CrossEntropyLoss()
    if args.optimizer == 'musgd':
        optimizer = MuSGD(net.parameters(), lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    elif args.optimizer == 'muadam':
        optimizer = MuAdam(net.parameters(), lr=args.lr)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(net.parameters(), lr=args.lr)
    else:
        raise ValueError()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    tb_time = strftime("%Y-%m-%d-%H:%M:%S", gmtime())
    sub_dir = root_dir + tb_time + "-" + str(args.arch) +  "-" + str(args.lr) + "-" + str(args.width_mult)

    os.makedirs(sub_dir, exist_ok = True)

    writer = SummaryWriter(sub_dir)

    for epoch in range(start_epoch, start_epoch+args.epochs):
        best_train_loss = train(epoch, net, writer)
        best_test_loss, best_acc = test(epoch, net, writer)
        scheduler.step()

    writer.add_hparams({"Epochs": args.epochs, "Width": args.width_mult, "BatchSize": args.batch_size}, {"Test/Score": best_acc})

Then, for each width multiplier wm from 1 to 5, we launched the following bash scripts (shown here with wm=5), which train the model over a set of learning rates.

In muP mode:

#!/bin/bash
wm=5
cd /exp/ResNetTest
pip3 install mup
for lr in 6.10351562e-05 8.13100441e-05 1.08319921e-04 1.44302039e-04 \
          1.92236834e-04 2.56094789e-04 3.41165320e-04 4.54494899e-04 \
          6.05470725e-04 8.06598270e-04 1.07453712e-03 1.43148090e-03 \
          1.90699561e-03 2.54046859e-03 3.38437101e-03 4.50860411e-03 \
          6.00628919e-03 8.00148094e-03 1.06594430e-02 1.42003369e-02 \
          1.89174582e-02 2.52015307e-02 3.35730700e-02 4.47254987e-02 \
          5.95825832e-02 7.93749499e-02 1.05742019e-01 1.40867801e-01 \
          1.87661798e-01 2.50000000e-01
do
    echo "Training for wm = ${wm} and lr = ${lr}"
    python3 main.py --lr=$lr --epochs=150 --batch_size=128 --num_workers=4 --seed=1111 --width_mult=$wm
done

In SP mode:

#!/bin/bash
wm=5
cd /exp/ResNetTest
pip3 install mup
for lr in 2.56094789e-04 4.54494899e-04 \
          8.06598270e-04 1.43148090e-03 \
          2.54046859e-03 4.50860411e-03 \
          8.00148094e-03 1.42003369e-02 \
          2.52015307e-02 4.47254987e-02 \
          7.93749499e-02 1.40867801e-01 \
          2.50000000e-01
do
    echo "Training for wm = ${wm} and lr = ${lr}"
    python3 main.py --lr=$lr --epochs=150 --batch_size=128 --num_workers=4 --seed=1111 --width_mult=$wm --optimizer='sgd'
done

Then we take the minimum loss of each run and plot the two sets of curves (loss vs. lr): one with muP, one without.
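
For completeness, the aggregation and plotting step looks roughly like the sketch below. The min_losses dictionary and its placeholder values are hypothetical stand-ins for the per-run minimum training losses we extract from the TensorBoard logs; only the plotting logic is meant to be illustrative.

import matplotlib.pyplot as plt

# Hypothetical container: min_losses[width_mult][lr] = minimum training loss
# observed over the 150 epochs of that run (extracted from the TensorBoard logs).
min_losses = {
    1: {6.10351562e-05: 2.1, 2.50000000e-01: 0.9},  # placeholder values
    5: {6.10351562e-05: 1.8, 2.50000000e-01: 1.1},  # placeholder values
}

plt.figure()
for wm, curve in sorted(min_losses.items()):
    lrs = sorted(curve)
    plt.plot(lrs, [curve[lr] for lr in lrs], marker='o', label=f'wm={wm}')
plt.xscale('log')
plt.xlabel('learning rate')
plt.ylabel('minimum training loss')
plt.legend()
plt.savefig('loss_vs_lr.png')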

With muP:

[Figure: minimum training loss vs. learning rate, under muP]

Without muP:

[Figure: minimum training loss vs. learning rate, under SP]

As you can see in the two figures, there is no visible difference between the two scenarios: in both cases the minima are aligned, except for those with wm=1. Do you have an idea why this is happening? Thanks for your help.

edwardjhu commented 1 year ago

Thanks for reaching out, Nicolas!

There are two reasons:

NicolasWinckler commented 1 year ago

Thanks for your reply Edward!

We understand that under SP our width multiplier range may be too small and that the LR alignment in SP may fail at larger widths. However, what we don't understand is the misalignment under muP when comparing the curve for wm=1 with the other wm values. I thought that under muP the LR alignment was guaranteed?

thegregyang commented 1 year ago

Adding to Edward's point: you can see from the plot in our paper that the difference for ResNet on CIFAR10 is not as drastic as for transformers. The width difference between the largest and smallest model here is 16x.

[Figure: loss vs. learning rate for ResNet on CIFAR10, from the paper]

In addition, if you tune the input-weight and/or output-weight learning rate separately from the global learning rate on the smallest model, you will often see muP perform much better and the alignment improve.
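
For example, something along these lines on top of the script above (the net.conv1 / net.linear attribute names are just a guess at the resnet module's structure, so adjust them to your model); MuSGD accepts standard PyTorch parameter groups, so the input and output weights can get their own learning rates:

from mup import MuSGD

# Give the input layer and the output layer their own learning rates,
# tuned separately from the global lr. `net.conv1` / `net.linear` are
# assumed attribute names for the input conv and the classifier head.
input_params = list(net.conv1.parameters())
output_params = list(net.linear.parameters())
special = {id(p) for p in input_params + output_params}
other_params = [p for p in net.parameters() if id(p) not in special]

optimizer = MuSGD(
    [
        {'params': other_params},                        # uses the global lr below
        {'params': input_params,  'lr': args.lr * 1.0},  # sweep this multiplier
        {'params': output_params, 'lr': args.lr * 1.0},  # sweep this multiplier
    ],
    lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
)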

thegregyang commented 1 year ago

Thanks for your reply Edward!

We understand that under SP our width multiplier range may be too small and that the LR alignment in SP may fail at larger widths. However, what we don't understand is the misalignment under muP when comparing the curve for wm=1 with the other wm values. I thought that under muP the LR alignment was guaranteed?

The answer is twofold: 1) the alignment is only approximate and improves with the width of the base model (the analogy is estimating the mean of a population by taking a large sample and computing its average --- this average is only approximately equal to the population mean, and only when the sample size is large enough), and 2) it has to do with insufficient tuning of other hyperparameters, like the input/output LR I mentioned. You can check out my reply here.

In particular,

The true hyperparameter space here is the very high-dimensional space containing [learning rate, initialization] (we can insert multipliers here as well, but, like I said, they are redundant) for every parameter tensor (weights, biases, gains, etc.). If you were to tune all these hyperparameters and obtain the optimal combination, then this combination is guaranteed to be stable in some sense as you vary width (in muP). However, in practice, we may not want to tune that many hyperparameters because of resource constraints. So we combine hyperparameters (e.g., by tying the learning rate for many weights together) until we have only a small number to tune. This essentially means that we are now focusing on a low-dimensional slice of the true hyperparameter space --- one that we guess should contain all the really good hyperparameters. The choices of hyperparameters we tuned in our paper exemplify the “low-dimensional slice” we chose. These choices are based on our empirical experience tuning hyperparameters, but over time people may find better choices.
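
To make the "low-dimensional slice" concrete: instead of tuning a separate [learning rate, initialization] per parameter tensor, one might tie everything to, say, three numbers (a global learning rate plus one multiplier each for the input and output weights) and sweep only those on the small model. A hypothetical sketch of such a grid:

from itertools import product

# A 3-dimensional slice of the full hyperparameter space: one global lr,
# one input-weight lr multiplier, one output-weight lr multiplier.
# (In the full space, every parameter tensor would get its own lr and init scale.)
global_lrs   = [1e-3, 3e-3, 1e-2, 3e-2, 1e-1]
input_mults  = [0.5, 1.0, 2.0]
output_mults = [0.5, 1.0, 2.0]

for lr, in_mult, out_mult in product(global_lrs, input_mults, output_mults):
    # Here one would train the wm=1 model with these three hyperparameters
    # (e.g. via the parameter groups sketched above) and record its best loss.
    print(f'candidate: lr={lr}, input lr={lr * in_mult}, output lr={lr * out_mult}')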

NicolasWinckler commented 1 year ago

Thank you Greg for your detailed explanation!