huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
Apache License 2.0

WSD Scheduler to auto infer training steps #34905

Open wheynelau opened 5 days ago

wheynelau commented 5 days ago

Feature request

The WSD scheduler should calculate the number of stable steps itself. In addition, if num_warmup_steps is provided in the kwargs, schedule_func should respect it.

My guess is that the intention is for the learning rate to decay to the minimum and stay there until the end of training. But since min_lr_ratio defaults to 0, wouldn't the learning rate always be 0? I would appreciate some insight on this if possible.

TypeError: get_wsd_schedule() missing 1 required positional argument: 'num_stable_steps'

Additionally, trying to pass in num_warmup_steps in lr_scheduler_kwargs will result in duplicate keys:

     return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **scheduler_specific_kwargs)
TypeError: transformers.optimization.get_wsd_schedule() got multiple values for keyword argument 'num_warmup_steps'
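
The duplicate-key failure does not depend on transformers itself. Below is a minimal sketch in plain Python that mirrors the call shape in the traceback above. The functions `schedule_func`, `get_scheduler_sketch`, and `get_scheduler_merged` are hypothetical stand-ins written for illustration, not the library's code, and the merge strategy shown is only one possible way to let the kwargs take precedence.

```python
def schedule_func(optimizer, num_warmup_steps, num_stable_steps=0):
    """Hypothetical stand-in for a scheduler factory (not the transformers function)."""
    return (optimizer, num_warmup_steps, num_stable_steps)


def get_scheduler_sketch(optimizer, num_warmup_steps, scheduler_specific_kwargs):
    # Same call shape as in the traceback: the explicit keyword and the
    # unpacked dict can both carry 'num_warmup_steps', which raises TypeError.
    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **scheduler_specific_kwargs)


def get_scheduler_merged(optimizer, num_warmup_steps, scheduler_specific_kwargs):
    # Assumed fix: build one dict first, so a user-supplied value overrides the default.
    kwargs = {"num_warmup_steps": num_warmup_steps, **scheduler_specific_kwargs}
    return schedule_func(optimizer, **kwargs)


try:
    get_scheduler_sketch(None, 10, {"num_warmup_steps": 20})
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# With the merged dict, the value from the kwargs wins and no error is raised.
merged = get_scheduler_merged(None, 10, {"num_warmup_steps": 20, "num_stable_steps": 70})
```

Merging into a single dict before the call keeps one source of truth for each argument; whether the kwargs or the Trainer-computed value should win is a design choice for the maintainers.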


I want to use the WSD scheduler for my training, but I do not want to have to calculate the stable steps myself.

Your contribution

I can contribute to this, but I would first like to hear from the maintainers about any edge cases or scenarios I might have missed. In the meantime, here is my current workaround:

def get_wsd_schedule(
    + num_training_steps: int = 0,

    assert num_stable_steps or num_training_steps, "One of either stable steps or training steps must be provided"
    if not num_stable_steps:
        num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
    if name == SchedulerType.WARMUP_STABLE_DECAY:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **scheduler_specific_kwargs)
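
To make the inference step concrete, here is a self-contained sketch of the same idea in plain Python. The helper names `infer_num_stable_steps` and `wsd_multiplier` are hypothetical, the decay is simplified to a linear ramp, and none of this is the library's implementation. It assumes the learning-rate multiplier settles at `min_lr_ratio` once the decay phase ends, and it uses `None` rather than 0 to mean "not provided", so an explicit 0 stable steps remains possible.

```python
def infer_num_stable_steps(num_training_steps, num_warmup_steps, num_decay_steps, num_stable_steps=None):
    """Return num_stable_steps, inferring it from the total step count when omitted."""
    if num_stable_steps is not None:
        return num_stable_steps
    if not num_training_steps:
        raise ValueError("Either num_stable_steps or num_training_steps must be provided")
    inferred = num_training_steps - num_warmup_steps - num_decay_steps
    if inferred < 0:
        # Edge case: warmup and decay together exceed the training budget.
        raise ValueError("num_warmup_steps + num_decay_steps exceeds num_training_steps")
    return inferred


def wsd_multiplier(step, num_warmup_steps, num_stable_steps, num_decay_steps, min_lr_ratio=0.0):
    """Simplified warmup-stable-decay multiplier with a linear decay (illustrative only)."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    if step < num_warmup_steps + num_stable_steps:
        return 1.0
    decay_step = step - num_warmup_steps - num_stable_steps
    if decay_step < num_decay_steps:
        progress = decay_step / max(1, num_decay_steps)
        return min_lr_ratio + (1.0 - min_lr_ratio) * (1.0 - progress)
    # Past the decay phase the multiplier stays at the floor.
    return min_lr_ratio


stable = infer_num_stable_steps(num_training_steps=100, num_warmup_steps=10, num_decay_steps=20)
schedule = [wsd_multiplier(s, 10, stable, 20) for s in (5, 50, 90, 100)]
```

When the stable steps are inferred this way, the three phases add up to exactly `num_training_steps`, so the floor is only reached on the final step rather than partway through training.
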
Rocketknight1 commented 5 days ago

cc @muellerzr @sunmarc for Trainer