Eladlev / AutoPrompt

A framework for prompt tuning using Intent-based Prompt Calibration
Apache License 2.0
1.86k stars 149 forks

Classification metrics can't handle a mix of unknown and multiclass targets #32

Closed alena-m closed 4 months ago

alena-m commented 4 months ago

Hi, could you provide an example of what config_ranking.yml should look like for generation tasks? With the default config_ranking.yml (as of now in the main branch), it fails with:

Traceback (most recent call last):
  File "path_to_repo/AutoPrompt/run_generation_pipeline.py", line 64, in <module>
    best_prompt = ranker_pipeline.run_pipeline(opt.num_ranker_steps)
  File "path_to_repo/AutoPrompt/optimization_pipeline.py", line 272, in run_pipeline
    stop_criteria = self.step()
  File "path_to_repo/AutoPrompt/optimization_pipeline.py", line 252, in step
    self.eval.add_history(self.cur_prompt, self.task_description)
  File "path_to_repo/AutoPrompt/eval/evaluator.py", line 112, in add_history
    conf_matrix = confusion_matrix(self.dataset['annotation'],
  File "path_to_env/miniconda3/envs/AutoPrompt/lib/python3.10/site-packages/sklearn/metrics/_classification.py", line 317, in confusion_matrix
    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
  File "path_to_env/miniconda3/envs/AutoPrompt/lib/python3.10/site-packages/sklearn/metrics/_classification.py", line 95, in _check_targets
    raise ValueError(
ValueError: Classification metrics can't handle a mix of unknown and multiclass targets

If I add an eval section to config_ranking.yml with function_name: 'ranking':

eval:
    function_name: 'ranking'
    error_threshold: 4

then it fails with:

Traceback (most recent call last):
  File "path_to_repo/AutoPrompt/run_generation_pipeline.py", line 53, in <module>
    ranker_pipeline = OptimizationPipeline(ranker_config_params, output_path=os.path.join(opt.output_dump, 'ranker'))
  File "path_to_repo/AutoPrompt/optimization_pipeline.py", line 58, in __init__
    self.eval = Eval(config.eval, self.meta_chain.error_analysis, self.dataset.label_schema)
  File "path_to_repo/AutoPrompt/eval/evaluator.py", line 19, in __init__
    self.score_func = self.get_eval_function(config)
  File "path_to_repo/AutoPrompt/eval/evaluator.py", line 39, in get_eval_function
    return utils.set_ranking_function(config.function_params)
AttributeError: 'EasyDict' object has no attribute 'function_params'. Did you mean: 'function_name'?
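From the traceback, the 'ranking' evaluator apparently expects a function_params entry next to function_name (utils.set_ranking_function receives config.function_params). A hedged guess at the shape; the keys under function_params below are illustrative assumptions, not taken from the repo:

```yaml
eval:
    function_name: 'ranking'
    function_params:
        # key names here are guesses; check utils.set_ranking_function
        # for the parameters it actually reads
        label_schema: ['1', '2', '3', '4', '5']
    error_threshold: 4
```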

Eladlev commented 4 months ago

Hi, the default config files for the generation pipeline (as used in run_generation_pipeline.py) should work well. The first failure you get ('mix of unknown and multiclass targets') is due to a bug when saving and loading the dataset: it converts the annotation ranks from strings to floats. I will upload a fix for this bug in the coming days. Meanwhile, if you avoid loading dumps (load_dump should be an empty string), it should work.