gmftbyGMFTBY opened this issue 9 months ago
Good ISSUE!
Similar to `learning_rate_target` in the trlx config: https://trlx.readthedocs.io/en/docs/configs.html#trlx.data.configs.TrainConfig

- `learning_rate_init` (float): initial learning rate after ramp up
- `learning_rate_target` (float): target learning rate after decay
Currently the minimum learning rate cannot be configured in transformers: schedules such as the cosine and linear ones hard-code it to 0.

Cosine schedule (https://github.com/huggingface/transformers/blob/main/src/transformers/optimization.py#L140):

```python
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
```

Linear schedule (https://github.com/huggingface/transformers/blob/main/src/transformers/optimization.py#L104):

```python
return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
```
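For a concrete illustration of the current behaviour, here is a minimal sketch (the dummy parameter, AdamW optimizer, and step counts are illustrative choices, not from the original report) showing that the stock cosine schedule decays the learning rate all the way to 0:

```python
import torch
from transformers import get_cosine_schedule_with_warmup

# Dummy one-parameter optimizer, used only to inspect the schedule.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=3e-4)
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=1000
)

for _ in range(1000):
    optimizer.step()
    scheduler.step()

print(scheduler.get_last_lr())  # ~[0.0]: the floor is hard-coded, not configurable
```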
```python
import math

import transformers

lr_decay_steps = 1500  # maximum number of steps over which the learning rate decays
min_lr_ratio = 0.1  # ratio of the final constant learning rate to TrainingArguments.learning_rate


def _get_cosine_schedule_with_warmup_lr_lambda(
    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float
):
    if current_step < num_warmup_steps:
        # Linear warmup.
        return float(current_step) / float(max(1, num_warmup_steps))
    if current_step > lr_decay_steps:
        # Past the decay horizon: hold the learning rate at min_lr_ratio * learning_rate.
        return min_lr_ratio
    progress = float(current_step - num_warmup_steps) / float(max(1, lr_decay_steps - num_warmup_steps))
    coefficient = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
    # Rescale the cosine coefficient so it decays from 1.0 to min_lr_ratio instead of to 0.
    return min_lr_ratio + coefficient * (1.0 - min_lr_ratio)


def add_lr_decay_limit_for_cosine_schedule():
    transformers.optimization._get_cosine_schedule_with_warmup_lr_lambda = _get_cosine_schedule_with_warmup_lr_lambda
```
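For reference, a minimal usage sketch (the optimizer, learning rate, and step counts below are illustrative assumptions) showing how the monkey patch plugs into the existing `"cosine"` scheduler, provided it is applied before the scheduler is built:

```python
import torch
from transformers import get_scheduler

add_lr_decay_limit_for_cosine_schedule()  # patch before the scheduler is constructed

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=3e-4)
scheduler = get_scheduler(
    "cosine", optimizer=optimizer, num_warmup_steps=100, num_training_steps=3000
)

for step in range(3000):
    optimizer.step()
    scheduler.step()
    if step in (99, 1500, 2999):
        # After lr_decay_steps the learning rate holds at min_lr_ratio * 3e-4 = 3e-5.
        print(step, scheduler.get_last_lr())
```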
Our implementation builds upon the cosine scheduler provided by `transformers.optimization`. We have introduced two new parameters:

- `lr_decay_steps`: the maximum number of iterations over which the learning rate decays.
- `min_lr_ratio`: the ratio of the constant learning rate to `TrainingArguments.learning_rate` after `lr_decay_steps` is reached.

The practical impact of the scheduler above is illustrated in the following figure.
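The curve is just the lambda above evaluated over training steps, so it can be reproduced with a short plotting sketch (matplotlib and the warmup/step values here are illustrative assumptions, reusing `_get_cosine_schedule_with_warmup_lr_lambda` as defined above):

```python
import matplotlib.pyplot as plt

num_warmup_steps, num_training_steps, num_cycles = 100, 3000, 0.5
steps = list(range(num_training_steps))
multipliers = [
    _get_cosine_schedule_with_warmup_lr_lambda(
        step,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=num_cycles,
    )
    for step in steps
]

plt.plot(steps, multipliers)
plt.xlabel("training step")
plt.ylabel("learning-rate multiplier")
plt.title("Cosine schedule with warmup, decaying to min_lr_ratio")
plt.show()
```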
Alternatively, we can adopt the following implementation, which features fewer parameters and a slightly different curve.
```python
def _get_cosine_schedule_with_warmup_lr_lambda(
    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float
):
    # min_lr_ratio is the same module-level constant as above; lr_decay_steps is no longer needed.
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    # Clamp the cosine coefficient at min_lr_ratio instead of rescaling it.
    return max(min_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
```
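The main design difference: the first variant rescales the cosine so the learning rate reaches exactly `min_lr_ratio` at `lr_decay_steps`, whereas this variant clips the unmodified cosine at `min_lr_ratio`, so the point where the floor kicks in is implied by `min_lr_ratio` rather than set explicitly. A quick back-of-the-envelope check of when the clamp activates (illustrative values, not part of the proposal):

```python
import math

min_lr_ratio, num_cycles = 0.1, 0.5
# 0.5 * (1 + cos(pi * progress)) falls below 0.1 once progress exceeds roughly 0.795,
# so with these settings the floor applies over the last ~20% of training.
for progress in (0.5, 0.795, 0.9, 1.0):
    print(progress, max(min_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))))
```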
cc @muellerzr sounds good no?
@muellerzr Hello, could you please review our issue and contribution?
### Feature request
We propose adding a new, widely adopted scheduler strategy for language model pretraining to the Transformers repository. Reviewing the schedulers currently available in the Transformers optimization module, there appears to be no out-of-the-box implementation of a scheduler that is prevalent in recent pretraining recipes: one with warmup and decay that also maintains a bounded minimum learning rate after the maximum iteration steps.
This scheduling approach has seen extensive use in several prominent pre-trained large language models (LLMs), including:
Introducing this scheduler into the Transformers library would not only round out the suite of existing scheduling strategies but also give practitioners a tool that has already proven its efficacy in recent LLM training methodologies. We believe its inclusion will benefit the community, fostering more efficient and effective pretraining.
### Motivation
This issue aims to introduce a new scheduler into the current Transformers library. The proposed scheduler combines warmup and decay with a distinctive feature: a bounded minimum learning rate that is maintained beyond the maximum iteration steps.
### Your contribution
Yes, we could submit a PR as soon as possible if the Hugging Face maintainers think this contribution is worthwhile.