markschoene closed this issue 8 months ago.
Hi, thanks for the feedback and suggestion! Agreed, this is a simple fix that should resolve the issue.
At the time we wrote this code, Optax did not have a good cosine scheduler. I would now recommend using `optax.warmup_cosine_decay_schedule` instead; we have an example in the development branch here: https://github.com/lindermanlab/S5/blob/008bd547890a17d6fce059f5de104c0d578b101b/train_utils.py#L90
Hi, we had the pleasure of working with S5 in our research project. Most of the code works as expected. However, the cosine annealing learning-rate schedule doesn't look like what I expected. Digging into the details, I found that the `step` count already progresses during the warmup, which results in $\frac{\texttt{args.warmup\_end}}{\texttt{num\_epochs - args.warmup}}$ instead of $0$.
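A short derivation of where that fraction comes from. It assumes (not confirmed by anything quoted in this thread) that the cosine phase uses the standard annealing form with horizon $T = S\,(E - W)$ and clamps the progress at 1, where $S$ is the number of steps per epoch, $E$ is `num_epochs`, $W$ is `args.warmup_end`, and $t$ is the global step counter:

```math
\eta(t) = \eta_{\min} + \tfrac{1}{2}\left(\eta_{\max} - \eta_{\min}\right)\left(1 + \cos\left(\pi \, \frac{\min(t, T)}{T}\right)\right), \qquad T = S\,(E - W)
```

Without a reset, the first annealing step is $t = SW$, so the cosine progress starts at

```math
\frac{SW}{S\,(E - W)} = \frac{W}{E - W} \neq 0,
```

and $t$ reaches $T$ at epoch $E - W$, so (when $W < E - W$) the last $W$ epochs sit flat at $\eta_{\min}$. For example, with $E = 30$ and $W = 10$ the annealing starts halfway down the cosine, at $(\eta_{\max} + \eta_{\min})/2$. Resetting $t$ to 0 at epoch $W$ makes the progress run from 0 to 1 over the remaining $E - W$ epochs.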
The result is the strange curve shown in the image below. A quick fix would be to reset `step = 0` if `epoch == args.warmup_end`.
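A self-contained sketch that reproduces the effect and the proposed fix in plain Python. The helpers `linear_warmup`, `cosine_annealing` and `lr_per_epoch`, their signatures, and the clamp of the cosine progress at 1 are hypothetical stand-ins written for illustration; they are not S5's training code. `reset_step=True` applies the suggested `step = 0` reset at `epoch == warmup_end`.

```python
import math


def linear_warmup(step, base_lr, end_step, lr_min=0.0):
    # Hypothetical helper: ramp linearly from lr_min to base_lr over end_step steps.
    return lr_min + (base_lr - lr_min) * min(step, end_step) / end_step


def cosine_annealing(step, base_lr, end_step, lr_min=0.0):
    # Hypothetical helper: half cosine from base_lr (step 0) to lr_min (step >= end_step).
    progress = min(step, end_step) / end_step  # assumed clamp at 1
    return lr_min + 0.5 * (base_lr - lr_min) * (1.0 + math.cos(math.pi * progress))


def lr_per_epoch(num_epochs, warmup_end, steps_per_epoch,
                 base_lr=1e-3, lr_min=0.0, reset_step=True):
    """Return the learning rate at the first optimizer step of each epoch."""
    lrs = []
    step = 0
    for epoch in range(num_epochs):
        if reset_step and epoch == warmup_end:
            step = 0  # proposed fix: restart the counter for the cosine phase
        if epoch < warmup_end:
            lr = linear_warmup(step, base_lr, steps_per_epoch * warmup_end, lr_min)
        else:
            # Cosine horizon covers only the post-warmup epochs.
            end_step = steps_per_epoch * (num_epochs - warmup_end)
            lr = cosine_annealing(step, base_lr, end_step, lr_min)
        lrs.append(lr)
        step += steps_per_epoch  # one epoch's worth of optimizer steps
    return lrs


buggy = lr_per_epoch(30, warmup_end=10, steps_per_epoch=100, base_lr=1.0, reset_step=False)
fixed = lr_per_epoch(30, warmup_end=10, steps_per_epoch=100, base_lr=1.0, reset_step=True)
print(f"first annealing epoch: buggy={buggy[10]:.3f}, fixed={fixed[10]:.3f}")
# first annealing epoch: buggy=0.500, fixed=1.000
```

With 10 warmup epochs out of 30, the unreset schedule starts the cosine halfway down and is already flat at `lr_min` from epoch 20 on, while the reset version decays over all 20 remaining epochs; the warmup phase is identical in both.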