automl / jahs_bench_201

The first collection of surrogate benchmarks for Joint Architecture and Hyperparameter Search.
https://automl.github.io/jahs_bench_201/
MIT License
15 stars 7 forks source link

Issue with config and trajectory #5

Open eddiebergman opened 1 year ago

eddiebergman commented 1 year ago

Hiyo,

I ran into an issue while wrapping jahs bench with a given config. Using the trajectory functionality (`full_trajectory=True`) fails, but manually iterating over the epochs succeeds. I've attached the config, a reproduction script, the stack trace and my full environment.

{
    'N': 3,
    'W': 4,
    'Op1': 1,
    'Op2': 3,
    'Op3': 4,
    'Op4': 1,
    'Op5': 2,
    'Op6': 3,
    'TrivialAugment': True,
    'Activation': 'Hardswish',
    'Optimizer': 'SGD',
    'Resolution': 0.25,
    'LearningRate': 0.10214993871440806,
    'WeightDecay': 0.00031212403229771485
}

Here's the reproducibility script:

from jahs_bench import Benchmark, BenchmarkTasks

config = {
    'N': 3,
    'W': 4,
    'Op1': 1,
    'Op2': 3,
    'Op3': 4,
    'Op4': 1,
    'Op5': 2,
    'Op6': 3,
    'TrivialAugment': True,
    'Activation': 'Hardswish',
    'Optimizer': 'SGD',
    'Resolution': 0.25,
    'LearningRate': 0.10214993871440806,
    'WeightDecay': 0.00031212403229771485
}

bench = Benchmark(
    task=BenchmarkTasks.FashionMNIST,
    save_dir="data/jahs-bench-data",
    download=False
)

# This will fail
traj = bench(config, nepochs=200, full_trajectory=True)

# This works
traj = {f: bench(config, nepochs=f)[f] for f in range(1, 201)}

Full trace:

TypeError                                 Traceback (most recent call last)
<ipython-input-7-2bdcd586c58f> in <module>
----> 1 traj = bench(config, nepochs=200, full_trajectory=True)

~/code/mf-prior-bench/.venv/lib/python3.7/site-packages/jahs_bench/api.py in __call__(self, config, nepochs, full_trajectory, **kwargs)
    138                  full_trajectory: bool = False, **kwargs):
    139         return self._call_fn(config=config, nepochs=nepochs,
--> 140                              full_trajectory=full_trajectory, **kwargs)
    141 
    142     def _benchmark_surrogate(self, config: dict, nepochs: Optional[int] = 200,

~/code/mf-prior-bench/.venv/lib/python3.7/site-packages/jahs_bench/api.py in _benchmark_surrogate(self, config, nepochs, full_trajectory, **kwargs)
    155         outputs = []
    156         for model in self._surrogates.values():
--> 157             outputs.append(model.predict(features))
    158 
    159         outputs: pd.DataFrame = pd.concat(outputs, axis=1)

~/code/mf-prior-bench/.venv/lib/python3.7/site-packages/jahs_bench/surrogate/model.py in predict(self, features)
    435 
    436         features = features.loc[:, self.feature_headers]
--> 437         ypredict = self.model.predict(features)
    438         ypredict = pd.DataFrame(ypredict, columns=self.label_headers)
    439         return ypredict

~/code/mf-prior-bench/.venv/lib/python3.7/site-packages/sklearn/utils/metaestimators.py in <lambda>(*args, **kwargs)
    111 
    112             # lambda, but not partial, allows help() to work with update_wrapper
--> 113             out = lambda *args, **kwargs: self.fn(obj, *args, **kwargs)  # noqa
    114         else:
    115 

~/code/mf-prior-bench/.venv/lib/python3.7/site-packages/sklearn/pipeline.py in predict(self, X, **predict_params)
    467         Xt = X
    468         for _, name, transform in self._iter(with_final=False):
--> 469             Xt = transform.transform(Xt)
    470         return self.steps[-1][1].predict(Xt, **predict_params)
    471 

~/code/mf-prior-bench/.venv/lib/python3.7/site-packages/sklearn/compose/_column_transformer.py in transform(self, X)
    751             _transform_one,
    752             fitted=True,
--> 753             column_as_strings=fit_dataframe_and_transform_dataframe,
    754         )
    755         self._validate_output(Xs)

~/code/mf-prior-bench/.venv/lib/python3.7/site-packages/sklearn/compose/_column_transformer.py in _fit_transform(self, X, y, func, fitted, column_as_strings)
    613                     message=self._log_message(name, idx, len(transformers)),
    614                 )
--> 615                 for idx, (name, trans, column, weight) in enumerate(transformers, 1)
    616             )
    617         except ValueError as e:

~/code/mf-prior-bench/.venv/lib/python3.7/site-packages/joblib/parallel.py in __call__(self, iterable)
   1041             # remaining jobs.
   1042             self._iterating = False
-> 1043             if self.dispatch_one_batch(iterator):
   1044                 self._iterating = self._original_iterator is not None
   1045 

~/code/mf-prior-bench/.venv/lib/python3.7/site-packages/joblib/parallel.py in dispatch_one_batch(self, iterator)
    859                 return False
    860             else:
--> 861                 self._dispatch(tasks)
    862                 return True
    863 

~/code/mf-prior-bench/.venv/lib/python3.7/site-packages/joblib/parallel.py in _dispatch(self, batch)
    777         with self._lock:
    778             job_idx = len(self._jobs)
--> 779             job = self._backend.apply_async(batch, callback=cb)
    780             # A job can complete so quickly than its callback is
    781             # called before we get here, causing self._jobs to

~/code/mf-prior-bench/.venv/lib/python3.7/site-packages/joblib/_parallel_backends.py in apply_async(self, func, callback)
    206     def apply_async(self, func, callback=None):
    207         """Schedule a func to be run"""
--> 208         result = ImmediateResult(func)
    209         if callback:
    210             callback(result)

~/code/mf-prior-bench/.venv/lib/python3.7/site-packages/joblib/_parallel_backends.py in __init__(self, batch)
    570         # Don't delay the application, to avoid keeping the input
    571         # arguments in memory
--> 572         self.results = batch()
    573 
    574     def get(self):

~/code/mf-prior-bench/.venv/lib/python3.7/site-packages/joblib/parallel.py in __call__(self)
    261         with parallel_backend(self._backend, n_jobs=self._n_jobs):
    262             return [func(*args, **kwargs)
--> 263                     for func, args, kwargs in self.items]
    264 
    265     def __reduce__(self):

~/code/mf-prior-bench/.venv/lib/python3.7/site-packages/joblib/parallel.py in <listcomp>(.0)
    261         with parallel_backend(self._backend, n_jobs=self._n_jobs):
    262             return [func(*args, **kwargs)
--> 263                     for func, args, kwargs in self.items]
    264 
    265     def __reduce__(self):

~/code/mf-prior-bench/.venv/lib/python3.7/site-packages/sklearn/utils/fixes.py in __call__(self, *args, **kwargs)
    214     def __call__(self, *args, **kwargs):
    215         with config_context(**self.config):
--> 216             return self.function(*args, **kwargs)
    217 
    218 

~/code/mf-prior-bench/.venv/lib/python3.7/site-packages/sklearn/pipeline.py in _transform_one(transformer, X, y, weight, **fit_params)
    874 
    875 def _transform_one(transformer, X, y, weight, **fit_params):
--> 876     res = transformer.transform(X)
    877     # if we have a weight for this transformer, multiply output
    878     if weight is None:

~/code/mf-prior-bench/.venv/lib/python3.7/site-packages/sklearn/preprocessing/_encoders.py in transform(self, X)
    511             handle_unknown=self.handle_unknown,
    512             force_all_finite="allow-nan",
--> 513             warn_on_unknown=warn_on_unknown,
    514         )
    515 

~/code/mf-prior-bench/.venv/lib/python3.7/site-packages/sklearn/preprocessing/_encoders.py in _transform(self, X, handle_unknown, force_all_finite, warn_on_unknown)
    132         for i in range(n_features):
    133             Xi = X_list[i]
--> 134             diff, valid_mask = _check_unknown(Xi, self.categories_[i], return_mask=True)
    135 
    136             if not np.all(valid_mask):

~/code/mf-prior-bench/.venv/lib/python3.7/site-packages/sklearn/utils/_encode.py in _check_unknown(values, known_values, return_mask)
    259 
    260         # check for nans in the known_values
--> 261         if np.isnan(known_values).any():
    262             diff_is_nan = np.isnan(diff)
    263             if diff_is_nan.any():

TypeError: ufunc 'isnan' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''

Env:

Python 3.7.12

Package                       Version
----------------------------- -----------
alabaster                     0.7.12
apeye                         1.2.0
argon2-cffi                   21.3.0
argon2-cffi-bindings          21.2.0
atomicwrites                  1.4.1
attrs                         22.1.0
autodocsumm                   0.2.9
automl-sphinx-theme           0.1.12
Babel                         2.10.3
backcall                      0.2.0
beautifulsoup4                4.11.1
black                         22.6.0
bleach                        5.0.1
CacheControl                  0.12.11
certifi                       2022.6.15
cffi                          1.15.1
cfgv                          3.3.1
charset-normalizer            2.1.0
click                         8.1.3
coloredlogs                   15.0.1
ConfigSpace                   0.4.21
coverage                      6.4.4
cssutils                      2.5.1
cycler                        0.11.0
Cython                        0.29.32
debugpy                       1.6.3
decopatch                     1.4.10
decorator                     5.1.1
defusedxml                    0.7.1
dict2css                      0.3.0
distlib                       0.3.5
docstring-to-markdown         0.10
docutils                      0.18.1
domdf-python-tools            3.3.0
entrypoints                   0.4
fastjsonschema                2.16.1
filelock                      3.8.0
flake8                        5.0.4
flatbuffers                   2.0
fonttools                     4.36.0
html5lib                      1.1
humanfriendly                 10.0
identify                      2.5.3
idna                          3.3
imagesize                     1.4.1
importlib-metadata            4.12.0
importlib-resources           5.9.0
ipykernel                     6.15.1
ipython                       7.34.0
ipython-genutils              0.2.0
ipywidgets                    8.0.1
isort                         5.10.1
jahs-bench                    1.0.0
jedi                          0.18.1
jedi-language-server          0.37.0
Jinja2                        3.1.2
joblib                        1.1.0
jsonschema                    4.13.0
jupyter                       1.0.0
jupyter-client                7.3.4
jupyter-console               6.4.4
jupyter-core                  4.11.1
jupyterlab-pygments           0.2.2
jupyterlab-widgets            3.0.2
kiwisolver                    1.4.4
lockfile                      0.12.2
lxml                          4.9.1
makefun                       1.14.0
markdown-it-py                2.1.0
MarkupSafe                    2.1.1
matplotlib                    3.5.3
matplotlib-inline             0.1.6
mccabe                        0.7.0
mdit-py-plugins               0.3.0
mdurl                         0.1.2
mf-prior-bench                0.1.0
mfp-bench                     0.0.1
mfpbench                      0.0.1
mistune                       0.8.4
more-itertools                8.14.0
mpmath                        1.2.1
msgpack                       1.0.4
mypy                          0.971
mypy-extensions               0.4.3
myst-parser                   0.18.0
natsort                       8.1.0
nbclient                      0.6.6
nbconvert                     6.5.3
nbformat                      5.4.0
nest-asyncio                  1.5.5
nodeenv                       1.7.0
notebook                      6.4.12
numpy                         1.21.6
numpydoc                      1.4.0
onnxruntime                   1.12.1
packaging                     21.3
pandas                        1.3.5
pandocfilters                 1.5.0
parso                         0.8.3
pathspec                      0.9.0
pexpect                       4.8.0
pickleshare                   0.7.5
Pillow                        9.2.0
pip                           22.2.2
pkgutil_resolve_name          1.3.10
platformdirs                  2.5.2
pluggy                        0.13.1
pre-commit                    2.20.0
prometheus-client             0.14.1
prompt-toolkit                3.0.30
protobuf                      4.21.5
psutil                        5.9.1
ptyprocess                    0.7.0
py                            1.11.0
pycodestyle                   2.9.1
pycparser                     2.21
pydantic                      1.9.2
pydocstyle                    6.1.1
pyflakes                      2.5.0
pygls                         0.12.1
Pygments                      2.13.0
pyparsing                     3.0.9
pyrsistent                    0.18.1
pytest                        4.6.0
pytest-cases                  3.6.13
pytest-cov                    3.0.0
python-dateutil               2.8.2
pytz                          2022.2.1
PyYAML                        6.0
pyzmq                         23.2.1
qtconsole                     5.3.1
QtPy                          2.2.0
requests                      2.28.1
ruamel.yaml                   0.17.21
ruamel.yaml.clib              0.2.6
scikit-learn                  1.0.2
scipy                         1.7.3
seaborn                       0.11.2
Send2Trash                    1.8.0
setuptools                    47.1.0
setuptools-scm                6.4.2
six                           1.16.0
snowballstemmer               2.2.0
soupsieve                     2.3.2.post1
Sphinx                        5.1.1
sphinx-autodoc-typehints      1.19.2
sphinx-gallery                0.11.0
sphinx-jinja2-compat          0.1.2
sphinx-prompt                 1.5.0
sphinx-tabs                   3.4.1
sphinx-toolbox                3.2.0
sphinxcontrib-applehelp       1.0.2
sphinxcontrib-devhelp         1.0.2
sphinxcontrib-htmlhelp        2.0.0
sphinxcontrib-jsmath          1.0.1
sphinxcontrib-qthelp          1.0.3
sphinxcontrib-serializinghtml 1.1.5
sympy                         1.10.1
tabulate                      0.8.10
terminado                     0.15.0
threadpoolctl                 3.1.0
tinycss2                      1.1.1
toml                          0.10.2
tomli                         2.0.1
tornado                       6.2
traitlets                     5.3.0
typed-ast                     1.5.4
typeguard                     2.13.3
typing_extensions             4.3.0
typing-inspect                0.8.0
urllib3                       1.26.11
virtualenv                    20.16.3
wcwidth                       0.2.5
webencodings                  0.5.1
wheel                         0.37.1
widgetsnbextension            4.0.2
xgboost                       1.5.2
yacs                          0.1.8
yahpo-gym                     1.0.1
zipp                          3.8.1
eddiebergman commented 1 year ago

Update: This seems to happen with almost any config on fashion-mnist