openai / weak-to-strong


The schedule.step() should be called outside the dataloader loop. #15


wangguanan commented 10 months ago

Thanks to OpenAI and the Superalignment Generalization Team for this awesome work.

While reading the code of the vision part, I found a minor bug related to CosineAnnealingLR. Since the learning rate scheduler is configured with T_max=n_epochs rather than the number of iterations,

schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_epochs)

schedule.step() should be called outside the train_loader loop, correspondingly:

    for epoch in (pbar := tqdm.tqdm(range(n_epochs), desc="Epoch 0")):
        correct, total = 0, 0
        for x, y in train_loader:
            x, y = x.cuda(), y.cuda()
            optimizer.zero_grad()
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            optimizer.step()
            schedule.step() # <-- remove
            if len(y.shape) > 1:
                y = torch.argmax(y, dim=1)
            correct += (torch.argmax(pred, -1) == y).detach().float().sum().item()
            total += len(y)
        schedule.step() # <-- add
        pbar.set_description(f"Epoch {epoch}, Train Acc {correct / total:.3f}")
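For intuition, here is a minimal standalone sketch (not from the repo; n_epochs=20 and batches_per_epoch=100 are made-up values) contrasting the two placements of schedule.step(). With T_max=n_epochs, stepping per batch makes the cosine schedule complete and restart within the first few epochs, while stepping per epoch anneals the learning rate once over the whole run:

import torch

n_epochs, batches_per_epoch = 20, 100  # hypothetical values, for illustration only

def lr_trace(step_per_batch):
    # dummy parameter and optimizer; only the learning rate trajectory matters here
    opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=1.0)
    schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=opt, T_max=n_epochs)
    lrs = []
    for _ in range(n_epochs):
        for _ in range(batches_per_epoch):
            opt.step()
            if step_per_batch:
                schedule.step()  # stepping inside the dataloader loop
        if not step_per_batch:
            schedule.step()  # stepping outside the dataloader loop
        lrs.append(opt.param_groups[0]["lr"])
    return lrs

print(lr_trace(True)[:5])   # oscillates between ~0 and ~1 within the first few epochs
print(lr_trace(False)[:5])  # decays smoothly from ~1.0 toward 0 over the whole run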

After fixing the logic, the final results are as follows:

| Model | Top-1 Accuracy | Top-1 Acc (schedule outside) |
| --- | --- | --- |
| AlexNet | 56.6 | - |
| DINO ResNet50 | 63.7 | - |
| DINO ViT-B/8 | 74.9 | - |
| AlexNet → DINO ResNet50 | 60.7 | 61.9 (+1.2) |
| AlexNet → DINO ViT-B/8 | 64.2 | 67.1 (+2.9) |
WuTheFWasThat commented 10 months ago

@pavel-izmailov

wangguanan commented 10 months ago

One more thing: using multiple worker processes can significantly speed up data loading, i.e. reduce training and inference time. This can be done by setting num_workers > 0 in the DataLoader, correspondingly:

here

def get_imagenet(datapath, split, batch_size, shuffle, transform=TRANSFORM):
    ds = torchvision.datasets.ImageNet(root=datapath, split=split, transform=transform)
    loader = torch.utils.data.DataLoader(ds, shuffle=shuffle, batch_size=batch_size, num_workers=min(batch_size//16, 8)) # <-- add num_workers=min(batch_size//16, 8)
    return ds, loader

here

train_loader = torch.utils.data.DataLoader(train_ds, shuffle=True, batch_size=batch_size, num_workers=min(batch_size//16, 8)) # <-- add num_workers=min(batch_size//16, 8)
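The actual speedup depends on the machine; a rough standalone timing sketch like the one below (not part of the repo; the datapath, split, batch size, and transform are placeholders) can be used to compare a few num_workers settings:

import time
import torch
import torchvision

TRANSFORM = torchvision.transforms.ToTensor()  # placeholder transform for this sketch

def loader_throughput(num_workers, datapath="/path/to/imagenet", batch_size=128, n_batches=20):
    ds = torchvision.datasets.ImageNet(root=datapath, split="val", transform=TRANSFORM)
    loader = torch.utils.data.DataLoader(ds, shuffle=True, batch_size=batch_size, num_workers=num_workers)
    it = iter(loader)
    start = time.time()
    for _ in range(n_batches):
        next(it)  # time how long it takes to fetch n_batches batches
    return n_batches * batch_size / (time.time() - start)  # images per second

for nw in (0, 4, 8):
    print(f"num_workers={nw}: {loader_throughput(nw):.1f} img/s")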
WuTheFWasThat commented 10 months ago

that's great, feel free to make PRs!