fidelity / stoke

A lightweight wrapper for PyTorch that provides a simple declarative API for context switching between devices, distributed modes, mixed-precision, and PyTorch extensions.
https://fidelity.github.io/stoke/
Apache License 2.0

Add StokeScheduler #20

Open zaksemenov opened 2 years ago

zaksemenov commented 2 years ago

Feature

Add a class that wraps a torch.optim.lr_scheduler in the same way that StokeOptimizer wraps the torch.optim optimizer.

Motivation

Although a scheduler can already be added and used in tandem with the StokeOptimizer, the stoke_obj.step() method wraps several other calls, so it would be cleaner for the end user if scheduler.step() were also encapsulated in stoke_obj.step().

Proposal

Extend the API to wrap a torch.scheduler instance and, if one was supplied, call its step() from inside stoke_obj.step(). A rough sketch follows.
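
A minimal sketch of what this could look like, mirroring the existing StokeOptimizer pattern; the StokeScheduler name and its fields below are hypothetical, not part of the current stoke API:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Type

from torch.optim.lr_scheduler import _LRScheduler


@dataclass
class StokeScheduler:
    """Hypothetical wrapper: holds a scheduler class plus its kwargs so Stoke
    can build it against its own optimizer and step it inside stoke_obj.step()."""

    scheduler: Type[_LRScheduler]
    scheduler_kwargs: Dict[str, Any] = field(default_factory=dict)
```

Stoke would then instantiate the scheduler once its optimizer exists and, if one was provided, call its step() at the end of stoke_obj.step().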

ncilfone commented 2 years ago

Current way to handle an LR scheduler:

```python
from torch import optim
from torch.optim import AdamW

from stoke import Stoke, StokeOptimizer

stoke_optimizer = StokeOptimizer(
    optimizer=AdamW,
    optimizer_kwargs={
        "lr": 1e-3,
        "betas": (0.9, 0.99),
        "eps": 1e-8,
        "weight_decay": 1e-4,
    },
)

stoke_model = Stoke(model, stoke_optimizer, ...)  # other Stoke args elided

scheduler = optim.lr_scheduler.OneCycleLR(
    stoke_model.optimizer,
    max_lr=0.001,
    pct_start=0.9,
    steps_per_epoch=len(train_dataloader),
    epochs=epochs,
)


def train():
    ...
    # PyTorch >= 1.1.0 -- the required order changed: optimizer step first
    stoke_model.step()
    scheduler.step()

    # PyTorch < 1.1.0 -- scheduler step came first
    ...
    scheduler.step()
    stoke_model.step()
```
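
Until something like StokeScheduler lands, one way to keep the training loop version-agnostic is a small helper that picks the step order from the installed PyTorch; this is a sketch, not part of stoke, and it assumes the packaging package is available:

```python
import torch
from packaging import version

# Since PyTorch 1.1.0, scheduler.step() must be called after optimizer.step();
# older releases expected the reverse order.
_OPTIMIZER_FIRST = version.parse(torch.__version__.split("+")[0]) >= version.parse("1.1.0")


def step_model_and_scheduler(stoke_model, scheduler):
    """Step the Stoke object and the LR scheduler in the order PyTorch expects."""
    if _OPTIMIZER_FIRST:
        stoke_model.step()
        scheduler.step()
    else:
        scheduler.step()
        stoke_model.step()
```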