ray-project / tune-sklearn

A drop-in replacement for Scikit-Learn’s GridSearchCV / RandomizedSearchCV -- but with cutting edge hyperparameter tuning techniques.
https://docs.ray.io/en/master/tune/api_docs/sklearn.html
Apache License 2.0

Support `name` (Ray Tune kwarg) #205

Closed bohniti closed 3 years ago

bohniti commented 3 years ago

Dear all,

I get the following exception when I pass the keyword `name` in a call to TuneSearchCV:

(screenshot of the exception attached) I am currently running the latest versions of ray, sklearn, and optuna, and my call looks as follows:

(screenshot of the call attached)

I noticed the following:

I installed tune-sklearn twice via the following command:

pip install -U git+https://github.com/ray-project/tune-sklearn.git && pip install 'ray[tune]'

When I do so and execute a very similar script on macOS, it works. However, within my Docker container (Ubuntu), it does not.
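Since the identical install works on macOS but fails inside the container, one quick way to spot version drift between the two environments is a small stdlib check (a sketch; the package list below is just the ones involved here):

```python
# Run this in both environments (macOS and the container) and diff the output.
# importlib.metadata is in the standard library on Python 3.8+.
from importlib import metadata

for pkg in ("tune-sklearn", "ray", "scikit-learn", "optuna"):
    try:
        print(pkg, metadata.version(pkg))
    except metadata.PackageNotFoundError:
        print(pkg, "not installed")
```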

I will add the corresponding Dockerfile, a complete list of my installed versions, and the whole script below:

Dockerfile

FROM some_private_company_docker_repo//templates/docker_images/databricks-docker-ubuntu-18-04-dbr-7-x/image

# Install git/make & our python package

WORKDIR /usr/app
COPY . /usr/app

RUN apt-get update \
&& apt-get install -y vim \
&& apt-get install -y --no-install-recommends git \
&& apt-get install -y --no-install-recommends make \
&& apt-get purge -y --auto-remove \
&& rm -rf /var/lib/apt/lists/*

RUN /databricks/conda/envs/dcs-minimal/bin/pip install -r requirements-dev.txt \
&& /databricks/conda/envs/dcs-minimal/bin/pip install  -U git+https://github.com/ray-project/tune-sklearn.git \
&& /databricks/conda/envs/dcs-minimal/bin/pip install ray[tune] \
&& /databricks/conda/envs/dcs-minimal/bin/pip install -e . \
&& ln -s /databricks/conda/envs/dcs-minimal/bin/pip pip

Versions

Package | Installed version | Latest version
-- | -- | --
Babel | 2.9.1 | 2.9.1
Jinja2 | 2.11.3 | 2.11.3
Mako | 1.1.4 | 1.1.4
MarkupSafe | 1.1.1 | 1.1.1
Pillow | 8.2.0 | 8.2.0
PyYAML | 5.4.1 | 5.4.1
Pygments | 2.7.4 | 2.8.1
SQLAlchemy | 1.4.11 | 1.4.11
Sphinx | 3.5.4 | 3.5.4
Unidecode | 1.2.0 | 1.2.0
aiohttp | 3.7.4.post0 | 3.7.4.post0
aiohttp-cors | 0.7.0 |  
aioredis | 1.3.1 | 1.3.1
alabaster | 0.7.12 | 0.7.12
alembic | 1.5.8 | 1.5.8
astroid | 2.5.6 | 2.5.6
async-timeout | 3.0.1 | 3.0.1
athena | 0.1.dev58+g0361c6a | 0.8.0
attrs | 20.3.0 | 20.3.0
backcall | 0.2.0 | 0.2.0
blessings | 1.7 | 1.7
cached-property | 1.5.2 | 1.5.2
cachetools | 4.2.2 | 4.2.2
certifi | 2020.12.5 | 2020.12.5
chardet | 4.0.0 | 4.0.0
click | 7.1.2 | 7.1.2
cliff | 3.7.0 | 3.7.0
cloudpickle | 1.6.0 | 1.6.0
cmaes | 0.8.2 | 0.8.2
cmd2 | 1.5.0 | 1.5.0
colorama | 0.4.4 | 0.4.4
colorlog | 5.0.1 | 5.0.1
coverage | 5.5 | 5.5
cycler | 0.10.0 | 0.10.0
decorator | 4.4.2 | 5.0.7
docopt | 0.6.2 | 0.6.2
docutils | 0.16 | 0.17.1
filelock | 3.0.12 | 3.0.12
flake8 | 3.9.1 | 3.9.1
future | 0.18.2 | 0.18.2
google-api-core | 1.26.3 | 1.26.3
google-auth | 1.30.0 | 1.30.0
googleapis-common-protos | 1.53.0 | 1.53.0
gpustat | 0.6.0 | 0.6.0
greenlet | 1.0.0 | 1.0.0
grpcio | 1.37.0 | 1.37.0
h5py | 3.2.1 | 3.2.1
hiredis | 2.0.0 | 2.0.0
hyperopt | 0.2.5 | 0.2.5
idna | 2.10 | 3.1
imagesize | 1.2.0 | 1.2.0
importlib-metadata | 4.0.1 | 4.0.1
iniconfig | 1.1.1 | 1.1.1
ipython | 7.4.0 | 7.22.0
ipython-genutils | 0.2.0 |  
isort | 5.8.0 | 5.8.0
jedi | 0.17.0 | 0.18.0
joblib | 1.0.1 | 1.0.1
jsonschema | 3.2.0 | 3.2.0
kiwisolver | 1.3.1 | 1.3.1
lazy-object-proxy | 1.6.0 | 1.6.0
matplotlib | 3.4.1 | 3.4.1
mccabe | 0.6.1 | 0.6.1
msgpack | 1.0.2 | 1.0.2
multidict | 5.1.0 | 5.1.0
networkx | 2.5.1 | 2.5.1
numpy | 1.20.2 | 1.20.2
numpydoc | 1.1.0 | 1.1.0
nvidia-ml-py3 | 7.352.0 | 7.352.0
opencensus | 0.7.12 | 0.7.12
opencensus-context | 0.1.2 | 0.1.2
optuna | 2.7.0 | 2.7.0
packaging | 20.9 | 20.9
pandas | 0.24.2 | 1.2.4
parso | 0.8.1 | 0.8.2
pbr | 5.6.0 | 5.6.0
pexpect | 4.8.0 | 4.8.0
pickleshare | 0.7.5 | 0.7.5
pip | 20.3.3 | 21.1
pluggy | 0.13.1 | 0.13.1
pockets | 0.9.1 | 0.9.1
prettytable | 2.1.0 | 2.1.0
prometheus-client | 0.10.1 | 0.10.1
prompt-toolkit | 2.0.10 | 3.0.18
protobuf | 3.15.8 | 3.15.8
psutil | 5.8.0 | 5.8.0
ptyprocess | 0.7.0 | 0.7.0
py | 1.10.0 | 1.10.0
py-spy | 0.3.5 | 0.3.5
pyarrow | 0.13.0 | 4.0.0
pyasn1 | 0.4.8 | 0.4.8
pyasn1-modules | 0.2.8 | 0.2.8
pycodestyle | 2.7.0 | 2.7.0
pydocstyle | 6.0.0 | 6.0.0
pyflakes | 2.3.1 | 2.3.1
pylint | 2.8.2 | 2.8.2
pyparsing | 2.4.7 | 2.4.7
pyperclip | 1.8.2 | 1.8.2
pyrsistent | 0.17.3 | 0.17.3
pytest | 6.2.3 | 6.2.3
pytest-cov | 2.11.1 | 2.11.1
pytest-flake8 | 1.0.7 | 1.0.7
pytest-pylint | 0.18.0 | 0.18.0
pytest-watch | 4.2.0 | 4.2.0
python-dateutil | 2.8.1 | 2.8.1
python-editor | 1.0.4 | 1.0.4
pytz | 2021.1 | 2021.1
ray | 1.3.0 | 1.3.0
redis | 3.5.3 | 3.5.3
requests | 2.25.1 | 2.25.1
rsa | 4.7.2 | 4.7.2
scikit-learn | 0.24.2 | 0.24.2
scipy | 1.6.3 | 1.6.3
setuptools | 52.0.0.post20210125 | 56.0.0
setuptools-scm | 6.0.1 | 6.0.1
six | 1.15.0 | 1.15.0
snowballstemmer | 2.1.0 | 2.1.0
sphinx-autoapi | 1.8.1 | 1.8.1
sphinx-rtd-theme | 0.5.2 | 0.5.2
sphinxcontrib-applehelp | 1.0.2 | 1.0.2
sphinxcontrib-devhelp | 1.0.2 | 1.0.2
sphinxcontrib-htmlhelp | 1.0.3 | 1.0.3
sphinxcontrib-jsmath | 1.0.1 | 1.0.1
sphinxcontrib-napoleon | 0.7 | 0.7
sphinxcontrib-programoutput | 0.17 | 0.17
sphinxcontrib-qthelp | 1.0.3 | 1.0.3
sphinxcontrib-serializinghtml | 1.1.4 | 1.1.4
sphinxcontrib-websupport | 1.2.4 | 1.2.4
stevedore | 3.3.0 | 3.3.0
tabulate | 0.8.9 | 0.8.9
tensorboardX | 2.2 | 2.2
threadpoolctl | 2.1.0 | 2.1.0
toml | 0.10.2 | 0.10.2
tqdm | 4.60.0 | 4.60.0
traitlets | 4.3.3 | 5.0.5
tune-sklearn | 0.2.1 | 0.2.1
typed-ast | 1.4.3 | 1.4.3
typing-extensions | 3.7.4.3 | 3.7.4.3
urllib3 | 1.26.4 | 1.26.4
watchdog | 2.0.3 | 2.0.3
wcwidth | 0.2.5 | 0.2.5
wheel | 0.36.2 | 0.36.2
wrapt | 1.12.1 | 1.12.1
yapf | 0.31.0 |  
yarl | 1.6.3 |  
zipp | 3.4.1 |  

Script

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
TODO
"""

from sklearn.model_selection import KFold
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import SGDClassifier, ElasticNet
from tune_sklearn import TuneSearchCV

def fit_tune(info_blob, cfg):
    # Select model from config
    if cfg.skl_model == 'SGDClassifier':
        clf = SGDClassifier()
    elif cfg.skl_model == 'RandomForestClassifier':
        clf = RandomForestClassifier()
    elif cfg.skl_model == 'ElasticNet':
        clf = ElasticNet()
    else:
        raise ValueError(f'Sklearn model "{cfg.skl_model}" (skl_model) is not supported.')

    if cfg.skl_split_method == 'KFold':
        cv = KFold(n_splits=cfg.skl_n_folds, shuffle=True)
    else:
        raise ValueError(f'Sklearn cv split method "{cfg.skl_split_method}" (skl_split_method) is not supported.')

    # Set training and validation sets
    X, y = info_blob[cfg.key_x_values], info_blob[cfg.key_y_values]

    param_dists = cfg.skl_model_params

    tune_search = TuneSearchCV(clf,
                               param_distributions=param_dists,
                               n_trials=2,
                               max_iters=10,
                               local_dir="../../examples/full_example/ray_results",
                               early_stopping=True,
                               cv=cv,
                               search_optimization='optuna',
                               loggers=['mlflow'],
                               name='RandomForestClassifierExperiment')

    tune_search.fit(X, y)

def fit_cv(info_blob, cfg):
    """
    TODO
    Need to wrap with ray tune, if the user wants hyperparam opt
    Do with tune.run as in
    https://medium.com/optuna/scaling-up-optuna-with-ray-tune-88f6ca87b8c7 ?
    https://docs.ray.io/en/latest/tune/api_docs/suggestion.html ?

    Returns
    -------

    """

    print(cfg.skl_model_params)

    if cfg.skl_model == 'RandomForestClassifier':
        clf = RandomForestClassifier(**cfg.skl_model_params)
    else:
        raise ValueError(f'Sklearn model "{cfg.skl_model}" (skl_model) is not supported.')

    if cfg.skl_split_method == 'KFold':
        cv = KFold(n_splits=cfg.skl_n_folds, shuffle=True)
    else:
        raise ValueError(f'Sklearn cv split method "{cfg.skl_split_method}" (skl_split_method) is not supported.')

    xs, ys = info_blob[cfg.key_x_values], info_blob[cfg.key_y_values]
    for (ind_train, ind_val), i in zip(cv.split(xs), range(cfg.skl_n_splits)):
        clf.fit(xs[ind_train], ys[ind_train])

        # y_pred = clf.predict(xs[ind_val])
        # to be able to make automatic 1D prob (or score) distribution plots
        y_pred_prob = clf.predict_proba(xs[ind_val])  # TODO always predict proba and derive y_pred ourselves, it's faster

        # TODO
        # calculate (also custom user-defined) metrics
        # collect via log mechanism, also log via mlflow
        # plot metrics
        # save model to output folder structure (-> still need to implement output folder structure, cf. orcanet code)
        # implement hyperparams as input to func & support tuning via ray tune
bohniti commented 3 years ago

Sorry, wrong headline. That title was intended for Stack Overflow.

richardliaw commented 3 years ago

Hmm, why are you trying to specify name?

bohniti commented 3 years ago

I can think of many reasons. Right now, I want the results in a dashboard without having to specify the results folder every time. Later, in production, this won't be necessary anymore. So maybe it's not a high priority, but since it is a documented parameter, it should be usable, right? :)
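For context, Ray Tune places each run's results under `<local_dir>/<name>`, which is why a fixed `name` gives a stable path to point a dashboard at (a minimal sketch of that convention; the directory and experiment names below are just examples):

```python
import os

# Ray Tune writes an experiment's results under <local_dir>/<name>,
# so fixing the name yields a stable directory for e.g. TensorBoard.
local_dir = os.path.expanduser("~/ray_results")   # example local_dir
name = "RandomForestClassifierExperiment"         # example experiment name
experiment_dir = os.path.join(local_dir, name)
print(experiment_dir)
```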

richardliaw commented 3 years ago

This should actually be supported in the latest release.