Open crapthings opened 1 month ago
This code was suggested by Claude 3.5 and is supposed to support resuming training, but I don't know whether it actually works.
from torch.optim.lr_scheduler import _LRScheduler
class PolyLr(_LRScheduler):
    """Polynomial learning-rate decay with an optional linear warmup.

    Let ``t = self.last_epoch`` and ``base_lr`` be a param group's initial LR.

    * ``t >= warmup_iteration``::

          lr = (base_lr - minimum_lr)
               * (1 - min(t, max_iteration) / max_iteration) ** gamma
               + minimum_lr

      so the rate reaches ``minimum_lr`` at ``max_iteration`` and stays there.
    * ``t < warmup_iteration``: the rate ramps linearly from ``base_lr / 10``
      toward ``base_lr``, but is capped by the polynomial curve above. This
      means it joins the decay curve without a jump when warmup ends.

    ``last_epoch`` is compared directly with ``warmup_iteration`` and
    ``max_iteration``. Call ``step()`` once per unit those are expressed in
    (presumably once per training iteration).

    Resuming training:

    1. Build the scheduler with the same arguments on the optimizer.
    2. Restore the optimizer with its own ``load_state_dict``.
    3. Call ``scheduler.load_state_dict(saved)`` with the dict previously
       obtained from ``scheduler.state_dict()``.

    The saved state includes ``last_epoch`` and ``base_lrs``, so the schedule
    continues where it stopped. Passing ``last_epoch != -1`` instead requires
    every optimizer param group to already contain ``'initial_lr'``; PyTorch
    raises ``KeyError`` otherwise.

    Args:
        optimizer: Wrapped optimizer.
        gamma: Exponent of the polynomial decay.
        max_iteration: Step at which the rate reaches ``minimum_lr``;
            must be > 0.
        minimum_lr: Floor of the decayed learning rate.
        warmup_iteration: Number of warmup steps (0 disables warmup);
            must be >= 0.
        last_epoch: Index of the last completed step; -1 starts from scratch.

    Raises:
        ValueError: If ``max_iteration <= 0`` or ``warmup_iteration < 0``.
    """

    def __init__(self, optimizer, gamma, max_iteration, minimum_lr=0, warmup_iteration=0, last_epoch=-1):
        # Validate before super().__init__(), whose initial step() already
        # calls get_lr(); max_iteration == 0 would otherwise divide by zero.
        if max_iteration <= 0:
            raise ValueError(f"max_iteration must be positive, got {max_iteration}")
        if warmup_iteration < 0:
            raise ValueError(f"warmup_iteration must be non-negative, got {warmup_iteration}")
        # Must be set before super().__init__() for the same reason.
        self.gamma = gamma
        self.max_iteration = max_iteration
        self.minimum_lr = minimum_lr
        self.warmup_iteration = warmup_iteration
        # base_lrs and last_epoch are initialised by the parent constructor.
        super().__init__(optimizer, last_epoch)

    def poly_lr(self, base_lr, step):
        """Return the polynomially decayed learning rate at ``step``.

        ``step`` is clamped to ``[0, max_iteration]``, so the decay factor
        stays within ``[0, 1]``.
        """
        clamped = min(max(float(step), 0.0), self.max_iteration)
        progress = clamped / self.max_iteration
        return (base_lr - self.minimum_lr) * (1.0 - progress) ** self.gamma + self.minimum_lr

    def warmup_lr(self, base_lr, alpha):
        """Return the linear-warmup learning rate at warmup progress ``alpha``.

        ``alpha`` is clamped to ``[0, 1]``. The result goes from
        ``base_lr / 10`` at 0 to ``base_lr`` at 1.
        """
        alpha = max(0.0, min(1.0, float(alpha)))
        return base_lr * (1 / 10.0 * (1 - alpha) + alpha)

    def get_lr(self):
        """Compute the learning rate of every param group for ``self.last_epoch``."""
        step = self.last_epoch
        if self.warmup_iteration > 0 and step < self.warmup_iteration:
            alpha = step / self.warmup_iteration
            # Cap the warmup ramp by the decay curve so the rate does not
            # jump down at the first post-warmup step.
            return [
                min(self.warmup_lr(base_lr, alpha), self.poly_lr(base_lr, step))
                for base_lr in self.base_lrs
            ]
        return [self.poly_lr(base_lr, step) for base_lr in self.base_lrs]
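# Previous version of PolyLr, kept for reference. Its get_lr() caps the warmup
# rate with min(warmup_lr, poly_lr), and it assigns last_epoch/base_lrs before
# calling the parent constructor.
# from torch.optim.lr_scheduler import _LRScheduler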
# class PolyLr(_LRScheduler):
# def __init__(self, optimizer, gamma, max_iteration, minimum_lr=0, warmup_iteration=0, last_epoch=-1):
# self.gamma = gamma
# self.max_iteration = max_iteration
# self.minimum_lr = minimum_lr
# self.warmup_iteration = warmup_iteration
# self.last_epoch = None
# self.base_lrs = []
# super(PolyLr, self).__init__(optimizer, last_epoch)
# def poly_lr(self, base_lr, step):
# return (base_lr - self.minimum_lr) * ((1 - (step / self.max_iteration)) ** self.gamma) + self.minimum_lr
# def warmup_lr(self, base_lr, alpha):
# return base_lr * (1 / 10.0 * (1 - alpha) + alpha)
# def get_lr(self):
# if self.last_epoch < self.warmup_iteration:
# alpha = self.last_epoch / self.warmup_iteration
# lrs = [min(self.warmup_lr(base_lr, alpha), self.poly_lr(base_lr, self.last_epoch)) for base_lr in
# self.base_lrs]
# else:
# lrs = [self.poly_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs]
# return lrs