facebookresearch / schedule_free

Schedule-Free Optimization in PyTorch

Tracking effective learning rate #6

Closed: drhead closed this issue 7 months ago

drhead commented 7 months ago

I've been trying this optimizer out lately and I am very happy with the results I'm getting so far. However, I'd like to track what the model is doing more closely, and there doesn't seem to be an obvious way to get the effective current LR for a given parameter group (the equivalent of d*lr in DAdaptation, for example), assuming a direct parallel even exists. My best guess is to derive ckp1 from accessible values, but I'm not sure whether that is the direct equivalent or even close. Recommendations for other useful metrics from the optimizer to track would also be helpful.
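
For concreteness, here is a rough sketch of how I'd try to reconstruct ckp1 from values the optimizer keeps on each param group; it assumes the group carries k, lr_max, weight_sum, betas, r, and weight_lr_power the way the reference AdamWScheduleFree step appears to, so treat the key names as guesses rather than a stable API:

import math

def estimate_ckp1(group):
    # Sketch only: recompute the interpolation weight ckp1 for the
    # next step from per-group state. The dictionary keys below are
    # my reading of the reference implementation, not a stable API.
    k = group['k']
    _, beta2 = group['betas']
    lr = group['lr'] * math.sqrt(1 - beta2 ** (k + 1))
    lr_max = max(group['lr_max'], lr)
    weight = ((k + 1) ** group['r']) * lr_max ** group['weight_lr_power']
    weight_sum = group['weight_sum'] + weight
    return weight / weight_sum if weight_sum > 0 else 0.0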

staghado commented 7 months ago

I have been trying to do something similar; I was tracking the max lr and the step size (effective lr).

Here is a code snippet:

import math
import matplotlib.pyplot as plt

# hyperparameters (defaults used by the schedule-free AdamW variant)
base_lr = 0.0025
beta1, beta2 = 0.9, 0.999
r = 0                 # power on the step count in the averaging weight
weight_lr_power = 2   # power on lr_max in the averaging weight

def lr_scheduler(k, lr_max, weight_sum):
    # bias-corrected lr at step k (Adam second-moment correction)
    lr = base_lr * math.sqrt(1 - beta2 ** (k + 1))

    # running maximum of the bias-corrected lr
    lr_max = max(lr, lr_max)

    # averaging weight for step k
    weight = ((k + 1) ** r) * lr_max ** weight_lr_power

    # cumulative sum of weights; ckp1 is the interpolation weight
    weight_sum += weight
    ckp1 = weight / weight_sum

    # signed coefficient applied to the gradient when y is updated;
    # its magnitude is the "effective lr"
    lr_eff = lr * (beta1 * (1 - ckp1) - 1)

    return lr_eff, lr_max, weight, weight_sum

steps = range(100)
lrs, lrs_max, weights = [], [], []

weight_sum = 0.0
lr_max = -1.0

for k in steps:
    lr_eff, lr_max, weight, weight_sum = lr_scheduler(k, lr_max, weight_sum)
    lrs.append(lr_eff)
    lrs_max.append(lr_max)
    weights.append(weight)

plt.plot(steps, lrs, label='effective lr')
plt.plot(steps, lrs_max, label='max lr')
plt.xlabel('step')
plt.legend()
plt.show()
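
Note that the value this returns is negative: beta1 * (1 - ckp1) is always below 1, so what you get is the signed coefficient on the gradient in the y update, and the magnitude is what I'd log as the effective step size. With r = 0 and a constant base lr, the weights become constant once the bias correction saturates, so ckp1 decays roughly like 1/(k + 1).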
adefazio commented 7 months ago

There really isn't an "effective" step-size. The way that averaging replaces a schedule is that it behaves as if you had used a linear-decay schedule up to the current point in time, ending right at the current point. It's like you're always at the end of a schedule.
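
A tiny sketch of what that means (my toy example, not library code): with a constant lr and r = 0, the interpolation weight reduces to ckp1 = 1/t, and the recursion x_t = (1 - ckp1) * x_{t-1} + ckp1 * z_t makes x exactly the uniform average of the z iterates, which is the quantity that behaves like the endpoint of a linear-decay run:

import numpy as np

rng = np.random.default_rng(0)
z = rng.standard_normal(1000)  # stand-in for the optimizer's z sequence

# schedule-free interpolation with ckp1 = 1/t
x = 0.0
for t, z_t in enumerate(z, start=1):
    x = (1 - 1 / t) * x + (1 / t) * z_t

# x is exactly the running uniform average of z_1..z_T
assert np.isclose(x, z.mean())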