The schedule.step() should be called outside the dataloader loop. #15

Open wangguanan opened 10 months ago

wangguanan commented 10 months ago

Thanks to OpenAI and the Superalignment Generalization team for the awesome work.

While reading the code of the vision part, I found a minor bug in how CosineAnnealingLR is used. Since the learning rate schedule is set by n_epochs, not n_iters,

schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_epochs)

schedule.step() should be called outside the train_loader loop, as follows:

    for epoch in (pbar := tqdm.tqdm(range(n_epochs), desc="Epoch 0")):
        correct, total = 0, 0
        for x, y in train_loader:
            x, y = x.cuda(), y.cuda()
            optimizer.zero_grad()
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            optimizer.step()
            schedule.step() # <-- remove
            if len(y.shape) > 1:
                y = torch.argmax(y, dim=1)
            correct += (torch.argmax(pred, -1) == y).detach().float().sum().item()
            total += len(y)
        schedule.step() # <-- add
        pbar.set_description(f"Epoch {epoch}, Train Acc {correct / total:.3f}")
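
To see why the placement matters, here is a minimal sketch in plain Python. It assumes the closed-form cosine annealing expression given in the PyTorch documentation; the library's own update rule may differ in floating-point detail. The helper names `cosine_lr` and `is_non_increasing`, and the sizes `n_epochs = 10` and `iters_per_epoch = 100`, are hypothetical and chosen only for illustration.

```python
import math

def cosine_lr(step, t_max, base_lr=1e-3, eta_min=0.0):
    # Closed-form cosine annealing: base_lr at step 0, eta_min at step t_max.
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2

def is_non_increasing(values):
    # True if the sequence never goes back up.
    return all(a >= b for a, b in zip(values, values[1:]))

n_epochs, iters_per_epoch = 10, 100  # hypothetical sizes

# schedule.step() once per epoch: the step counter stops at T_max = n_epochs.
per_epoch = [cosine_lr(e, n_epochs) for e in range(n_epochs + 1)]

# schedule.step() once per batch: the step counter runs far past T_max.
per_batch = [cosine_lr(s, n_epochs) for s in range(n_epochs * iters_per_epoch + 1)]

print(is_non_increasing(per_epoch))  # one decay from base_lr down to eta_min
print(is_non_increasing(per_batch))  # the cosine keeps oscillating instead
```

Under these assumptions the per-batch variant climbs back to the base rate every `2 * n_epochs` batches, so training never sees one smooth decay over the whole run. An alternative to moving the call would be to keep stepping per batch and set `T_max` to the total number of batches, `n_epochs * len(train_loader)`.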

With this fix, the final results look like this:

| Model | Top-1 Accuracy | Top-1 Acc (schedule outside) |
| --- | --- | --- |
| AlexNet | 56.6 | - |
| Dino ResNet50 | 63.7 | - |
| Dino ViT-B/8 | 74.9 | - |
| AlexNet → DINO ResNet50 | 60.7 | 61.9 (+1.2) |
| AlexNet → DINO ViT-B/8 | 64.2 | 67.1 (+2.9) |

wangguanan commented 10 months ago

One more thing: using multiple worker processes can significantly speed up data loading, which reduces training and inference time. This can be done by setting num_workers > 0, as follows:


def get_imagenet(datapath, split, batch_size, shuffle, transform=TRANSFORM):
    ds = torchvision.datasets.ImageNet(root=datapath, split=split, transform=transform)
    loader = torch.utils.data.DataLoader(ds, shuffle=shuffle, batch_size=batch_size, num_workers=min(batch_size//16, 8)) # <-- add num_workers=min(batch_size//16, 8)
    return ds, loader
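
As a rough illustration of what the `min(batch_size//16, 8)` heuristic yields, the sketch below evaluates it for a few batch sizes. The helper name `suggested_num_workers` and its parameter names are hypothetical; only the formula itself comes from the snippet above.

```python
def suggested_num_workers(batch_size, samples_per_worker=16, max_workers=8):
    # One worker per 16 samples in a batch, capped at 8 workers.
    return min(batch_size // samples_per_worker, max_workers)

for batch_size in (8, 32, 100, 256):
    print(batch_size, suggested_num_workers(batch_size))
```

Note that for batch sizes below 16 the heuristic gives 0, which in PyTorch's `DataLoader` means the data is loaded in the main process, so small batches would see no speedup from this change.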


train_loader = torch.utils.data.DataLoader(train_ds, shuffle=True, batch_size=batch_size, num_workers=min(batch_size//16, 8)) # <-- add num_workers=min(batch_size//16, 8)

That's great, feel free to make PRs!