facebookresearch / schedule_free

Schedule-Free Optimization in PyTorch
Apache License 2.0
1.9k stars · 64 forks

Add foreach to optimizers #10

Closed drhead closed 7 months ago

drhead commented 7 months ago

I rewrote adamw_schedulefree to support using torch._foreach operations as an option. According to my profiler, this brought the optimizer step from ~85ms down to ~65ms (the model used was the SD1.5 UNet). In theory, this should also increase peak memory usage during the optimizer step (according to PyTorch's documentation on the built-in optimizers) -- I have left it enabled by default regardless, since that is the default behavior for the built-in optimizers.
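For reference, the speedup comes from horizontal fusion: instead of launching one small kernel per parameter tensor, a single `torch._foreach_*` call updates the whole list at once. A minimal sketch of the idea (not the PR's actual code; the function names and the plain-SGD update here are illustrative only):

```python
import torch

def step_loop(params, grads, lr):
    # Baseline: one in-place update (and one kernel launch) per tensor.
    for p, g in zip(params, grads):
        p.add_(g, alpha=-lr)

def step_foreach(params, grads, lr):
    # Fused: a single _foreach call applies the same in-place update
    # to every tensor in the list, cutting Python and launch overhead.
    torch._foreach_add_(params, grads, alpha=-lr)

params_a = [torch.ones(4) for _ in range(3)]
params_b = [torch.ones(4) for _ in range(3)]
grads = [torch.full((4,), 2.0) for _ in range(3)]

step_loop(params_a, grads, lr=0.1)
step_foreach(params_b, grads, lr=0.1)
# Both paths compute 1 - 0.1 * 2 = 0.8 for every element.
```

The memory caveat mentioned above follows from the same fusion: intermediate results for the whole tensor list can be materialized at once rather than one tensor at a time.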

I think it would obviously be best to add this to all included optimizers, but before I do that I would like to hear some feedback on code style to make sure everything I add is up to standards.

adefazio commented 7 months ago

Nice! I will take a look on Monday.

adefazio commented 7 months ago

Have you tried running https://github.com/facebookresearch/schedule_free/blob/main/schedulefree/test_schedulefree.py to make sure it gives identical outputs to the closure version? That will help verify if it's correct.
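The kind of check that test file performs can be sketched as follows: run the same tiny training loop with the foreach path on and off and compare the resulting parameters. This uses torch.optim.SGD (which exposes a `foreach` flag) purely as a stand-in for the schedule-free optimizers; the repo's actual test compares against the closure variants.

```python
import torch

def train(foreach: bool) -> torch.Tensor:
    # Deterministic toy problem: minimize sum(p**2) for a few steps.
    p = torch.nn.Parameter(torch.ones(3))
    opt = torch.optim.SGD([p], lr=0.1, foreach=foreach)
    for _ in range(5):
        opt.zero_grad()
        (p ** 2).sum().backward()
        opt.step()
    return p.detach().clone()

# The fused and unfused paths should produce identical trajectories.
result_foreach = train(foreach=True)
result_loop = train(foreach=False)
```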

drhead commented 7 months ago

I implemented it in all of the main optimizers, both the closure and non-closure versions. I skipped some of the renaming in the foreach versions to keep the in-place operations simple. The code passes the test suite just fine in its current state.

adefazio commented 7 months ago

Ok, we will look this over during the next few days. Sorry for the delay; things are hectic at the moment.

adefazio commented 7 months ago

This looks to be working correctly, thanks!