huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Tuning generation_config in Trainer hyperparameter_search (Optuna backend) #33755

Open · ribesstefano opened this issue 1 month ago

ribesstefano commented 1 month ago

Feature request

Add the generation configuration to the set of parameters that can be tuned in a Trainer's hyperparameter search.

Motivation

When defining the Optuna hyperparameter space, I would like to investigate whether different generation configurations affect performance. For example, something as simple as: is diverse (group) beam search better than standard beam search?

Example of implementation:

from transformers import GenerationConfig


def optuna_hp_space(trial):
    # Define default generation parameters
    generation_params = {
        "max_length": 512,
        "max_new_tokens": 512,
        "top_k": 20,
    }

    # Define the generation strategies and pick one with Optuna
    # REF: https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/generation/configuration_utils.py#L71
    generation_strategy_params = {
        "greedy": {"num_beams": 1, "do_sample": False},
        "contrastive_search": {"penalty_alpha": 0.1, "top_k": 10},
        "multinomial_sampling": {"num_beams": 1, "do_sample": True},
        "beam_search_decoding": {"num_beams": 5, "do_sample": False},
        "beam_search_multinomial_sampling": {"num_beams": 5, "do_sample": True},
        "diverse_beam_search_decoding": {"num_beams": 5, "num_beam_groups": 5, "diversity_penalty": 1.0},
    }
    gen_strategy = trial.suggest_categorical("generation_strategy", list(generation_strategy_params.keys()))
    generation_params.update(generation_strategy_params[gen_strategy])

    # Update the generation params with the temperature
    temperature = trial.suggest_float("temperature", 0.5, 1.1, log=False)
    generation_params["temperature"] = temperature

    # Instantiate a GenerationConfig object to pass to the Trainer arguments
    # (only used in the commented-out attempts in the returned dict below)
    generation_config = GenerationConfig(**generation_params)

    # Setup learning rate warmup ratio
    warmup_ratio = trial.suggest_float("warmup_ratio", 0.0, 0.1, step=0.01)

    # Setup learning rate scheduler type and its fixed kwargs
    lr_scheduler_type = trial.suggest_categorical("lr_scheduler_type", ["cosine", "cosine_with_restarts", "reduce_lr_on_plateau"])  # other options handled below: "cosine_with_min_lr", "polynomial"
    if lr_scheduler_type == "cosine":
        lr_scheduler_kwargs = {}
    elif lr_scheduler_type == "cosine_with_restarts":
        lr_scheduler_kwargs = {"num_cycles": 5}
    elif lr_scheduler_type == "cosine_with_min_lr":
        lr_scheduler_kwargs = {"min_lr": 1e-6}
    elif lr_scheduler_type == "polynomial":
        lr_scheduler_kwargs = {"power": 1.0}
    elif lr_scheduler_type == "reduce_lr_on_plateau":
        lr_scheduler_kwargs = {"min_lr": 1e-6}

    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-3, log=True),
        "lr_scheduler_type": lr_scheduler_type,
        "lr_scheduler_kwargs": lr_scheduler_kwargs,
        "warmup_ratio": warmup_ratio,
        # "generation_config": generation_params, # <-- BREAKING: PASSING THE KWARGS
        # "generation_config": generation_config, # <-- BREAKING: PASSING THE INSTANTIATED OBJECT
        # **{f"generation_{k}": v for k, v in generation_params.items()}, # <-- NOT BREAKING, BUT ORIGINAL VALUES ARE USED INSTEAD OF THESE
        **generation_params # <-- NOT BREAKING, BUT ORIGINAL VALUES ARE USED INSTEAD OF THESE
    }
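
As a possible interim workaround (a sketch rather than an official API; the checkpoint name and the module-level GENERATION_STRATEGY_PARAMS dict are assumptions for illustration): since passing the generation parameters through the returned dict either breaks or is silently ignored, they could instead be applied inside model_init, which the Trainer calls once per trial and to which it passes the Optuna trial object when the function takes exactly one argument. The generation entries would then be dropped from the dict returned by optuna_hp_space.

from transformers import AutoModelForSeq2SeqLM, GenerationConfig

# Same dict as in optuna_hp_space, hoisted to module scope so both functions share it.
GENERATION_STRATEGY_PARAMS = {
    "greedy": {"num_beams": 1, "do_sample": False},
    "contrastive_search": {"penalty_alpha": 0.1, "top_k": 10},
    "multinomial_sampling": {"num_beams": 1, "do_sample": True},
    "beam_search_decoding": {"num_beams": 5, "do_sample": False},
    "beam_search_multinomial_sampling": {"num_beams": 5, "do_sample": True},
    "diverse_beam_search_decoding": {"num_beams": 5, "num_beam_groups": 5, "diversity_penalty": 1.0},
}

def model_init(trial):
    # "t5-small" is just a placeholder checkpoint.
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    if trial is not None:
        # Re-suggesting a parameter with the same name and distribution inside the
        # same trial returns the value Optuna already drew for optuna_hp_space.
        gen_strategy = trial.suggest_categorical(
            "generation_strategy", list(GENERATION_STRATEGY_PARAMS.keys())
        )
        params = {"max_new_tokens": 512, "top_k": 20}
        params.update(GENERATION_STRATEGY_PARAMS[gen_strategy])
        params["temperature"] = trial.suggest_float("temperature", 0.5, 1.1, log=False)
        # generate() falls back to model.generation_config for anything not explicitly
        # overridden, so the per-trial strategy takes effect at evaluation time.
        model.generation_config = GenerationConfig(**params)
    return model

The trainer would then be constructed with model_init=model_init as usual before calling hyperparameter_search.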

Your contribution

Currently I'm experiencing the following error:

Traceback (most recent call last):
  File "/cephyr/users/ribes/Alvis/PROTAC-Splitter/src/train_model.py", line 18, in <module>
    CLI([train_model, train_ppo_model])
  File "/opt/conda/lib/python3.10/site-packages/jsonargparse/_cli.py", line 119, in CLI
    return _run_component(component, init.get(subcommand))
  File "/opt/conda/lib/python3.10/site-packages/jsonargparse/_cli.py", line 204, in _run_component
    return component(**cfg)
  File "/cephyr/users/ribes/Alvis/PROTAC-Splitter/protac_splitter/llms/training.py", line 277, in train_model
    best_trials = trainer.hyperparameter_search(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 3217, in hyperparameter_search
    best_run = backend_obj.run(self, n_trials, direction, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/hyperparameter_search.py", line 72, in run
    return run_hp_search_optuna(trainer, n_trials, direction, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/integrations/integration_utils.py", line 260, in run_hp_search_optuna
    study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs, gc_after_trial=gc_after_trial)
  File "/opt/conda/lib/python3.10/site-packages/optuna/study/study.py", line 475, in optimize
    _optimize(
  File "/opt/conda/lib/python3.10/site-packages/optuna/study/_optimize.py", line 63, in _optimize
    _optimize_sequential(
  File "/opt/conda/lib/python3.10/site-packages/optuna/study/_optimize.py", line 160, in _optimize_sequential
    frozen_trial = _run_trial(study, func, catch)
  File "/opt/conda/lib/python3.10/site-packages/optuna/study/_optimize.py", line 248, in _run_trial
    raise func_err
  File "/opt/conda/lib/python3.10/site-packages/optuna/study/_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
  File "/opt/conda/lib/python3.10/site-packages/transformers/integrations/integration_utils.py", line 247, in _objective
    trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1889, in train
    self._hp_search_setup(trial)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1517, in _hp_search_setup
    value = type(old_attr)(value)
TypeError: GenerationConfig.__init__() takes 1 positional argument but 2 were given

This makes me suspect that a single GenerationConfig object is created once and then reused across all trials. This is in contrast to the model instantiation, which must be a Callable, as specified in the documentation of the hyperparameter_search method.
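
For reference, the failing step can be reproduced in isolation: GenerationConfig.__init__ only accepts keyword arguments, so the generic re-casting in _hp_search_setup (value = type(old_attr)(value), trainer.py:1517 in the traceback above) breaks as soon as the existing attribute is a GenerationConfig. A minimal sketch (old_attr and value are just stand-ins mirroring the traceback, not the actual Trainer code):

from transformers import GenerationConfig

# Stand-in for the GenerationConfig already stored on the (Seq2Seq)TrainingArguments.
old_attr = GenerationConfig()
# Stand-in for the value returned by the hp_space for the "generation_config" key.
value = {"num_beams": 5, "do_sample": False}

try:
    # Mirrors the cast in Trainer._hp_search_setup shown in the traceback.
    type(old_attr)(value)
except TypeError as err:
    print(err)  # GenerationConfig.__init__() takes 1 positional argument but 2 were given

# Expanding the dict as keyword arguments works fine:
print(type(old_attr)(**value))

So supporting this would presumably require either expanding dict values as keyword arguments in that cast or special-casing GenerationConfig there (e.g. via something like GenerationConfig.from_dict).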

LysandreJik commented 1 month ago

cc @gante regarding the generation config :)