Open ErikHartman opened 2 months ago
Since `LR_SchedulerInterface` is just an interface, you don't need to call `super().__init__()` in your implementation. The `NotImplementedError` occurs because that call tries to initialize the interface class itself.
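To illustrate the failure mode, here is a minimal, dependency-free sketch. The interface name and the assumption that its `__init__` raises `NotImplementedError` mirror the behavior described above; this is not the actual peptdeep source:

```python
# Sketch of an interface whose __init__ raises NotImplementedError
# (assumed shape, mirroring LR_SchedulerInterface's behavior).
class SchedulerInterfaceSketch:
    def __init__(self, *args, **kwargs):
        raise NotImplementedError

    def step(self, loss, epoch=None):
        raise NotImplementedError


class WorkingScheduler(SchedulerInterfaceSketch):
    def __init__(self, lr=1e-3):
        # Deliberately NOT calling super().__init__():
        # doing so would raise NotImplementedError.
        self.lr = lr

    def step(self, loss, epoch=None):
        pass  # real logic would go here
```

Instantiating `WorkingScheduler` succeeds, whereas any subclass that calls `super().__init__()` hits the interface's `NotImplementedError`.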
Additionally, two other changes are needed, both reflected in the corrected code below. Sorry about the docstring incorrectly stating that the return value should be a float; it will be updated. This version of your scheduler should work fine:
```python
from torch.optim.lr_scheduler import ReduceLROnPlateau


class CustomReduceLROnPlateau(LR_SchedulerInterface):
    def __init__(self, optimizer, num_warmup_steps, num_training_steps,
                 patience=10, factor=0.1, mode='min', threshold=1e-4,
                 threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-8,
                 **kwargs):
        # Note: no super().__init__() call, since the interface's
        # __init__ raises NotImplementedError.
        self.scheduler = ReduceLROnPlateau(
            optimizer,
            mode=mode,
            factor=factor,
            patience=patience,
            threshold=threshold,
            threshold_mode=threshold_mode,
            cooldown=cooldown,
            min_lr=min_lr,
            eps=eps,
            **kwargs,
        )

    def step(self, loss, epoch=None):
        self.scheduler.step(loss, epoch)

    def get_last_lr(self):
        return [self.scheduler.optimizer.param_groups[0]['lr']]
```
For a very similar implementation, you can check out the scheduler used in AlphaDIA (Alphadia transfer-learning).
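For readers unfamiliar with the plateau heuristic being wrapped above, here is a dependency-free sketch of the core idea behind `ReduceLROnPlateau` (simplified: 'min' mode only, no threshold or cooldown handling; this is an illustration, not the torch implementation):

```python
class TinyPlateau:
    """Simplified reduce-on-plateau heuristic: if the monitored loss
    fails to improve for more than `patience` consecutive steps,
    multiply the learning rate by `factor` (floored at `min_lr`)."""

    def __init__(self, lr, patience=10, factor=0.1, min_lr=0.0):
        self.lr = lr
        self.patience = patience
        self.factor = factor
        self.min_lr = min_lr
        self.best = float("inf")
        self.num_bad_steps = 0

    def step(self, loss):
        if loss < self.best:
            # Improvement: reset the bad-step counter.
            self.best = loss
            self.num_bad_steps = 0
        else:
            self.num_bad_steps += 1
            if self.num_bad_steps > self.patience:
                # Plateau detected: decay the learning rate.
                self.lr = max(self.lr * self.factor, self.min_lr)
                self.num_bad_steps = 0
        return self.lr
```

With `patience=2`, three consecutive non-improving steps trigger a decay; this mirrors the patience/factor semantics of the torch scheduler wrapped in `CustomReduceLROnPlateau`.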
Bug description
Trying to set a custom learning rate scheduler with `set_lr_scheduler_class` throws a `NotImplementedError`.

To Reproduce

Expected behavior
No error.
Version (please complete the following information):
Additional context
I see in the source code that this isn't implemented. It would be nice if it were.