facebookresearch / schedule_free

Schedule-Free Optimization in PyTorch
Apache License 2.0

Regular AdamW scores just as well on MNIST? + Feature request... #27

Closed: drscotthawley closed this issue 6 months ago

drscotthawley commented 6 months ago

Dear authors, thank you very much for releasing your code. I'm looking forward to using it to achieve better training results on a variety of problems.

Except that... for the MNIST example, I find that if I replace ScheduleFree with either ordinary AdamW (with no scheduling) or AdamW + Cosine Annealing, it scores about the same as ScheduleFree, for a variety of learning rates. And I don't just mean the final scores/states; I mean throughout the whole training run.

It's possible I did something wrong. However, all I did was run your example as is, and then make a version where I removed schedulefree and used ordinary AdamW, first with and then without a cosine schedule. The losses and accuracy scores are generally comparable, and sometimes even better than the results with schedulefree (e.g. 99.3% accuracy for AdamW with no schedule vs. 99.2% for ScheduleFree). I'll include a diff of my "comparable AdamW" code below to clarify what I did.

...Any comments on that? (This is not a "challenge"; it's "I really do want to understand this and take advantage of it".)

I tried learning rates of 0.05, 0.001, and your default rate, and I wasn't trying to fine-tune the learning rate, the number of steps, or anything else. Perhaps MNIST is just too "easy" a problem to really showcase the improvements offered by ScheduleFree? Perhaps the later examples you plan to include will be more illustrative.

Also, a feature request for your subsequent examples: could you make it easy to swap ScheduleFree out for other methods, to better demonstrate its effectiveness? Maybe via a CLI argument that defaults to your method but can otherwise select, for example, AdamW?
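A minimal sketch of what such a switch might look like, assuming an argparse-based `main.py` like the MNIST example. The `--optimizer` flag, its choice names, and the helpers `build_parser`, `build_optimizer`, and `set_optimizer_mode` are hypothetical, not part of the schedulefree repository. Plain `torch.optim.AdamW` has no `train()`/`eval()` methods, so the mode helper only calls them when the optimizer defines them; that is the step that otherwise has to be commented out by hand when switching to AdamW.

```python
import argparse


def build_parser():
    # Hypothetical flag: --optimizer picks the method and defaults to Schedule-Free.
    parser = argparse.ArgumentParser(description="MNIST optimizer comparison (sketch)")
    parser.add_argument(
        "--optimizer",
        choices=["schedulefree", "adamw", "adamw-cosine"],
        default="schedulefree",
    )
    return parser


def build_optimizer(name, params, lr, total_steps):
    """Return (optimizer, scheduler_or_None).

    Imports are deferred so the argument parser works even without torch installed.
    """
    if name == "schedulefree":
        import schedulefree
        return schedulefree.AdamWScheduleFree(params, lr=lr), None
    import torch
    optimizer = torch.optim.AdamW(params, lr=lr)
    if name == "adamw-cosine":
        # Cosine decay over the whole run; step it once per batch.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)
        return optimizer, scheduler
    return optimizer, None


def set_optimizer_mode(optimizer, training):
    """Call optimizer.train()/optimizer.eval() only if the optimizer defines them."""
    method = getattr(optimizer, "train" if training else "eval", None)
    if callable(method):
        method()
```

In the training loop you would call `set_optimizer_mode(optimizer, True)` wherever `model.train()` is called, `set_optimizer_mode(optimizer, False)` wherever `model.eval()` is called, and `scheduler.step()` after `optimizer.step()` only when a scheduler was returned.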

Here's a diff of my non-schedulefree example code. It's just your main.py with a few lines changed:

$ diff main.py adam.py
37c37
< def train(args, model, device, train_loader, optimizer, epoch):
---
> def train(args, model, device, train_loader, optimizer, scheduler, epoch):
39c39
<     optimizer.train()
---
>     #optimizer.train()
48c48
<             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
---
>             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, \tlr = {:,.3g}'.format(
50c50
<                 100. * batch_idx / len(train_loader), loss.item()))
---
>                 100. * batch_idx / len(train_loader), loss.item(), optimizer.param_groups[0]["lr"]))
52a53
>         scheduler.step()
57c58
<     optimizer.eval()
---
>     #optimizer.eval()
132c133,138
<     optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=args.lr)
---
>    # optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=args.lr)
>     optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
> 
>     steps = int(60000*args.epochs/args.batch_size)
>     print("steps = ",steps)
>     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, steps)
135c141
<         train(args, model, device, train_loader, optimizer, epoch)
---
>         train(args, model, device, train_loader, optimizer, scheduler, epoch)
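One detail in the diff worth checking: `steps = int(60000*args.epochs/args.batch_size)` rounds down, but a PyTorch `DataLoader` with the default `drop_last=False` also yields the final partial batch of each epoch. Unless the batch size divides 60,000 evenly, `scheduler.step()` (called once per batch) therefore runs a few more times than `T_max`, and PyTorch's `CosineAnnealingLR` keeps following the cosine past `T_max`, so the learning rate starts to climb back up. The sketch below uses illustrative values (batch size 64 and 14 epochs are assumptions, not necessarily the example's defaults) and the closed-form cosine formula; the helper names are hypothetical.

```python
import math


def total_scheduler_steps(num_examples, batch_size, epochs, drop_last=False):
    """Number of scheduler.step() calls when stepping once per batch."""
    # With drop_last=False the last, smaller batch of each epoch is still yielded.
    if drop_last:
        per_epoch = num_examples // batch_size
    else:
        per_epoch = math.ceil(num_examples / batch_size)
    return per_epoch * epochs


def cosine_lr(step, t_max, base_lr, eta_min=0.0):
    """Closed-form cosine annealing; past t_max the curve rises again."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


# Illustrative values only.
batch_size, epochs = 64, 14
t_max_from_formula = int(60000 * epochs / batch_size)  # what the diff computes
actual_steps = total_scheduler_steps(60000, batch_size, epochs)
print(t_max_from_formula, actual_steps)  # 13125 13132
print(cosine_lr(actual_steps, t_max_from_formula, base_lr=1.0))  # small but > 0
```

With these numbers the final learning rate is still tiny, so this is unlikely to affect the comparison; using `len(train_loader) * args.epochs` as `T_max` simply makes the schedule end exactly at its minimum.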

adefazio commented 6 months ago

Yes, I'm aware of this. The MNIST example is just there to show how to use the optimizer. MNIST is a terrible test for optimizers: it's too easy a problem, and basically everything works on it. We are looking at adding some more self-contained examples; we just haven't had the bandwidth yet.
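Part of "how to use the optimizer" is the `optimizer.train()` / `optimizer.eval()` pair that the AdamW diff had to comment out. Schedule-Free keeps separate points for taking gradients and for evaluation, and in the schedulefree package those two calls are what switch the model's parameters between them (stated here as an assumption based on the package's documentation). The following is a simplified scalar sketch of the Schedule-Free SGD recurrence from the Schedule-Free paper ("The Road Less Scheduled"), without warmup or weight decay, and not the AdamW variant used in the example; `lr`, `beta`, `steps`, and the quadratic objective are illustrative choices.

```python
def schedule_free_sgd(grad, w0, lr=0.1, beta=0.9, steps=1000):
    """Simplified Schedule-Free SGD on one scalar parameter (no warmup, no weight decay).

    z: base SGD sequence
    x: uniform running average of z (the point you would evaluate)
    y: interpolation of z and x (the point where gradients are taken)
    """
    z = x = y = w0
    for t in range(1, steps + 1):
        y = (1 - beta) * z + beta * x  # gradients are evaluated here ("train" mode)
        z = z - lr * grad(y)           # plain SGD step on the base sequence
        c = 1.0 / (t + 1)              # uniform averaging weight, no LR schedule
        x = (1 - c) * x + c * z        # averaged iterate ("eval" mode)
    return x, y


# Minimize f(w) = w**2 / 2, whose gradient is w.
x, y = schedule_free_sgd(grad=lambda w: w, w0=5.0)
print(x, y)
```

Gradients come from `y`, while `x` is the point you would validate or save; plain AdamW has only one iterate, so there is nothing to switch, which is why removing those calls was harmless in the comparison above.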

drscotthawley commented 6 months ago

Thanks for your reply! I understand. Looking forward to learning more.