Feature request
Adding generation configurations to the parameters that can be tuned in a Trainer.
Motivation
When defining the Optuna hyperparameter space, I would like to investigate whether different generation configurations affect performance. For example, something as simple as: is diverse (group) beam search better than standard beam search?
Example implementation:
from transformers import GenerationConfig


def optuna_hp_space(trial):
    # Define default generation parameters
    generation_params = {
        "max_length": 512,
        "max_new_tokens": 512,
        "top_k": 20,
    }
    # Define the generation strategies and pick one with Optuna
    # REF: https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/generation/configuration_utils.py#L71
    generation_strategy_params = {
        "greedy": {"num_beams": 1, "do_sample": False},
        "contrastive_search": {"penalty_alpha": 0.1, "top_k": 10},
        "multinomial_sampling": {"num_beams": 1, "do_sample": True},
        "beam_search_decoding": {"num_beams": 5, "do_sample": False},
        "beam_search_multinomial_sampling": {"num_beams": 5, "do_sample": True},
        "diverse_beam_search_decoding": {"num_beams": 5, "num_beam_groups": 5, "diversity_penalty": 1.0},
    }
    gen_strategy = trial.suggest_categorical("generation_strategy", list(generation_strategy_params.keys()))
    generation_params.update(generation_strategy_params[gen_strategy])
    # Update the generation params with the temperature
    temperature = trial.suggest_float("temperature", 0.5, 1.1, log=False)
    generation_params["temperature"] = temperature
    # Instantiate a GenerationConfig object to pass to the Trainer arguments
    generation_config = GenerationConfig(**generation_params)
    # Setup learning rate warmup ratio
    warmup_ratio = trial.suggest_float("warmup_ratio", 0.0, 0.1, step=0.01)
    # Setup learning rate scheduler type and its fixed kwargs
    lr_scheduler_type = trial.suggest_categorical("lr_scheduler_type", ["cosine", "cosine_with_restarts", "reduce_lr_on_plateau"])  # "cosine_with_min_lr", "polynomial"
    if lr_scheduler_type == "cosine":
        lr_scheduler_kwargs = {}
    elif lr_scheduler_type == "cosine_with_restarts":
        lr_scheduler_kwargs = {"num_cycles": 5}
    elif lr_scheduler_type == "cosine_with_min_lr":
        lr_scheduler_kwargs = {"min_lr": 1e-6}
    elif lr_scheduler_type == "polynomial":
        lr_scheduler_kwargs = {"power": 1.0}
    elif lr_scheduler_type == "reduce_lr_on_plateau":
        lr_scheduler_kwargs = {"min_lr": 1e-6}
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-3, log=True),
        "lr_scheduler_type": lr_scheduler_type,
        "lr_scheduler_kwargs": lr_scheduler_kwargs,
        "warmup_ratio": warmup_ratio,
        # "generation_config": generation_params,  # <-- BREAKING: PASSING THE KWARGS
        # "generation_config": generation_config,  # <-- BREAKING: PASSING THE INSTANTIATED OBJECT
        # **{f"generation_{k}": v for k, v in generation_params.items()},  # <-- NOT BREAKING, BUT ORIGINAL VALUES ARE USED INSTEAD OF THESE
        **generation_params,  # <-- NOT BREAKING, BUT ORIGINAL VALUES ARE USED INSTEAD OF THESE
    }
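For context, the space above is consumed by Trainer.hyperparameter_search roughly as sketched below. This is a minimal sketch, not the original training script: trainer is assumed to be an already configured Seq2SeqTrainer built with model_init (the model must be a callable so each trial starts from a fresh instance), and the n_trials/direction values are only illustrative.

# `trainer` is assumed to be an existing Seq2SeqTrainer created with model_init=...
best_trials = trainer.hyperparameter_search(
    hp_space=optuna_hp_space,   # the search space defined above
    backend="optuna",
    n_trials=20,                # illustrative value
    direction="minimize",       # e.g. minimizing the evaluation loss
)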
Your contribution
Currently I'm experiencing the following error:
Traceback (most recent call last):
  File "/cephyr/users/ribes/Alvis/PROTAC-Splitter/src/train_model.py", line 18, in <module>
    CLI([train_model, train_ppo_model])
  File "/opt/conda/lib/python3.10/site-packages/jsonargparse/_cli.py", line 119, in CLI
    return _run_component(component, init.get(subcommand))
  File "/opt/conda/lib/python3.10/site-packages/jsonargparse/_cli.py", line 204, in _run_component
    return component(**cfg)
  File "/cephyr/users/ribes/Alvis/PROTAC-Splitter/protac_splitter/llms/training.py", line 277, in train_model
    best_trials = trainer.hyperparameter_search(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 3217, in hyperparameter_search
    best_run = backend_obj.run(self, n_trials, direction, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/hyperparameter_search.py", line 72, in run
    return run_hp_search_optuna(trainer, n_trials, direction, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/integrations/integration_utils.py", line 260, in run_hp_search_optuna
    study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs, gc_after_trial=gc_after_trial)
  File "/opt/conda/lib/python3.10/site-packages/optuna/study/study.py", line 475, in optimize
    _optimize(
  File "/opt/conda/lib/python3.10/site-packages/optuna/study/_optimize.py", line 63, in _optimize
    _optimize_sequential(
  File "/opt/conda/lib/python3.10/site-packages/optuna/study/_optimize.py", line 160, in _optimize_sequential
    frozen_trial = _run_trial(study, func, catch)
  File "/opt/conda/lib/python3.10/site-packages/optuna/study/_optimize.py", line 248, in _run_trial
    raise func_err
  File "/opt/conda/lib/python3.10/site-packages/optuna/study/_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
  File "/opt/conda/lib/python3.10/site-packages/transformers/integrations/integration_utils.py", line 247, in _objective
    trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1889, in train
    self._hp_search_setup(trial)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1517, in _hp_search_setup
    value = type(old_attr)(value)
TypeError: GenerationConfig.__init__() takes 1 positional argument but 2 were given
This makes me suspect that a single GenerationConfig object is created once and reused for all trials. This is in contrast to the model instantiation, which must be a Callable, as specified in the documentation for the hyperparameter_search method.
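For what it's worth, the failing line value = type(old_attr)(value) in _hp_search_setup seems consistent with GenerationConfig accepting keyword arguments only, so rebuilding it from a positional value fails. A minimal sketch of my reading of the error:

from transformers import GenerationConfig

GenerationConfig(max_new_tokens=512)       # OK: keyword arguments only
GenerationConfig({"max_new_tokens": 512})  # TypeError: __init__() takes 1 positional argument but 2 were given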