When I launch a run with `cfg.dataset.task_type = 'regression'`, the code crashes at the end of the run with the following error:
```
Traceback (most recent call last):
  File "main_pyg.py", line 55, in <module>
    agg_runs(cfg.out_dir, cfg.metric_best)
  File "~/Code/GraphGym/graphgym/utils/agg_runs.py", line 100, in agg_runs
    [stats[metric] for stats in stats_list])
  File "~/Code/GraphGym/graphgym/utils/agg_runs.py", line 100, in <listcomp>
    [stats[metric] for stats in stats_list])
KeyError: 'accuracy'
```
The problem seems to be that `accuracy` is not a metric that is logged for regression tasks. Here are the relevant lines in `agg_runs.py`:

```python
if metric_best == 'auto':
    metric = 'auc' if 'auc' in stats_list[0] else 'accuracy'
```
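To make the failure concrete, here is a minimal, self-contained reproduction. The per-epoch stats dictionaries are hypothetical stand-ins for what a regression run logs (assumed here to contain keys such as `'loss'` and `'mse'` but no `'accuracy'`); only the selection rule mirrors the two lines quoted above, and `pick_metric_auto` is an illustrative name, not a GraphGym function.

```python
def pick_metric_auto(stats_list):
    # Mirrors the original 'auto' rule: prefer 'auc', otherwise fall back to 'accuracy'.
    return 'auc' if 'auc' in stats_list[0] else 'accuracy'


# Hypothetical per-epoch stats for a regression run (values made up for illustration).
stats_list = [
    {'epoch': 0, 'loss': 1.20, 'mse': 1.20},
    {'epoch': 1, 'loss': 0.85, 'mse': 0.85},
]

metric = pick_metric_auto(stats_list)
try:
    # Same list comprehension as in agg_runs.py; fails because no dict has the key.
    values = [stats[metric] for stats in stats_list]
except KeyError as err:
    print('KeyError:', err)  # prints: KeyError: 'accuracy'
```

Because there is no `'auc'` key, the rule always falls through to `'accuracy'`, which a regression run never records.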
Here's a fix:

```python
if metric_best == 'auto':
    # Requires the global GraphGym config `cfg` to be in scope in agg_runs.py.
    if cfg.dataset.task_type == 'classification':
        metric = 'auc' if 'auc' in stats_list[0] else 'accuracy'
    elif cfg.dataset.task_type == 'regression':
        metric = 'mse'
```
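One possible hardening of the fix, sketched below, passes the task type in explicitly and raises a clear error for any other task type, so `metric` can never be left undefined. `pick_metric` is a hypothetical helper written for illustration; it is not part of GraphGym's API, and it assumes that `'mse'` is among the metrics logged for regression, as the fix itself does.

```python
def pick_metric(stats_list, metric_best, task_type):
    """Choose the metric used to rank epochs (hypothetical helper, not GraphGym API)."""
    if metric_best != 'auto':
        # An explicitly configured metric is used as-is.
        return metric_best
    if task_type == 'classification':
        return 'auc' if 'auc' in stats_list[0] else 'accuracy'
    if task_type == 'regression':
        return 'mse'
    # Fail loudly instead of leaving the metric undefined for other task types.
    raise ValueError(f'No automatic metric for task_type={task_type!r}')


# Hypothetical stats dicts, for illustration only.
reg_stats = [{'loss': 0.9, 'mse': 0.9}]
cls_stats = [{'loss': 0.5, 'accuracy': 0.8}]

print(pick_metric(reg_stats, 'auto', 'regression'))      # mse
print(pick_metric(cls_stats, 'auto', 'classification'))  # accuracy
print(pick_metric(reg_stats, 'mae', 'regression'))       # mae
```

Taking `task_type` as an argument, rather than reading it from a global config, also keeps the rule easy to test in isolation.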
For regression, it is also necessary to add `metric_agg: argmin` to the YAML config file. Otherwise the default aggregation, `argmax`, picks the epoch with the highest MSE instead of the lowest.
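Why `argmin` matters can be sketched with a small best-epoch selector. `best_epoch` is a hypothetical helper that only illustrates the idea of picking an epoch by `argmax` or `argmin` over a per-epoch metric; it is not how `agg_runs.py` is actually implemented, and the MSE values are made up.

```python
def best_epoch(values, metric_agg='argmax'):
    """Return the index of the best epoch under the given aggregation rule."""
    epochs = range(len(values))
    if metric_agg == 'argmax':
        return max(epochs, key=lambda i: values[i])
    if metric_agg == 'argmin':
        return min(epochs, key=lambda i: values[i])
    raise ValueError(f'Unknown metric_agg: {metric_agg!r}')


# Hypothetical validation MSE per epoch; lower is better.
val_mse = [1.20, 0.85, 0.62, 0.70]

print(best_epoch(val_mse, 'argmax'))  # 0 -> the worst epoch for an error metric
print(best_epoch(val_mse, 'argmin'))  # 2 -> the lowest-MSE epoch
```

For an error metric such as MSE, "best" means smallest, so maximizing it silently reports the worst epoch rather than crashing.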