facebook / Ax

Adaptive Experimentation Platform
https://ax.dev
MIT License

Tutorial fails #2516

Closed. KantiCodes closed this issue 3 months ago

KantiCodes commented 3 months ago


Dupe of Ray's issue: https://github.com/ray-project/ray/issues/45720. Thought it might be good for you to know about it; if not, feel free to close, but the tutorial will still be outdated.

Problem

https://ax.dev/versions/0.1.9/tutorials/raytune_pytorch_cnn.html fails: the tutorial's create_experiment call no longer matches the current AxClient signature. It seems that create_experiment used to take objective_name but now takes objectives.

I downloaded the source code from the link above (pasted below). It does not work out of the box since ray.tune.suggest no longer exists, so I changed it to ray.tune.search. I also deleted the ray.tune.track tracking code.
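
For reference, a minimal sketch of what the equivalent setup appears to look like with the versions in my environment (Ray 2.23 / ax-platform 0.4.0); I have not verified this end to end:

from ray.tune.search.ax import AxSearch  # moved here from ray.tune.suggest.ax in Ray 2.x
from ax.service.ax_client import AxClient, ObjectiveProperties

ax = AxClient(enforce_sequential_optimization=False)
ax.create_experiment(
    name="mnist_experiment",
    parameters=[
        {"name": "lr", "type": "range", "bounds": [1e-6, 0.4], "log_scale": True},
        {"name": "momentum", "type": "range", "bounds": [0.0, 1.0]},
    ],
    # objective_name="mean_accuracy" is gone; newer AxClient takes an objectives dict instead
    objectives={"mean_accuracy": ObjectiveProperties(minimize=False)},
)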

Traceback

Traceback (most recent call last):
  File "/home/bartek/Projects/GNN-planning-paper/zzzz.py", line 55, in <module>
    ax.create_experiment(
TypeError: AxClient.create_experiment() got an unexpected keyword argument 'objective_name'

Python

3.10.12

Pip Freeze

aiohttp==3.9.5
aiosignal==1.3.1
alembic==1.13.1
aniso8601==9.0.1
annotated-types==0.7.0
asttokens==2.4.1
async-timeout==4.0.3
attrs==23.2.0
ax-platform==0.4.0
blinker==1.8.2
botorch==0.11.0
build==1.2.1
cachetools==5.3.3
certifi==2024.6.2
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
comm==0.2.2
contourpy==1.2.1
cycler==0.12.1
decorator==5.1.1
Deprecated==1.2.14
docker==7.1.0
entrypoints==0.4
exceptiongroup==1.2.1
executing==2.0.1
filelock==3.14.0
Flask==3.0.3
fonttools==4.53.0
frozenlist==1.4.1
fsspec==2024.6.0
gitdb==4.0.11
GitPython==3.1.43
gpytorch==1.11
graphene==3.3
graphql-core==3.2.3
graphql-relay==3.2.0
greenlet==3.0.3
gunicorn==22.0.0
idna==3.7
importlib_metadata==7.1.0
ipython==8.25.0
ipywidgets==8.1.3
itsdangerous==2.2.0
jaxtyping==0.2.29
jedi==0.19.1
Jinja2==3.1.4
joblib==1.4.2
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
jupyterlab_widgets==3.0.11
kiwisolver==1.4.5
linear-operator==0.5.1
Mako==1.3.5
Markdown==3.6
MarkupSafe==2.1.5
matplotlib==3.9.0
matplotlib-inline==0.1.7
mlflow==2.13.1
mpmath==1.3.0
msgpack==1.0.8
multidict==6.0.5
multipledispatch==1.0.0
mypy-extensions==1.0.0
networkx==3.3
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.5.40
nvidia-nvtx-cu12==12.1.105
opentelemetry-api==1.25.0
opentelemetry-sdk==1.25.0
opentelemetry-semantic-conventions==0.46b0
opt-einsum==3.3.0
packaging==24.0
pandas==2.2.2
parso==0.8.4
pexpect==4.9.0
pillow==10.3.0
pip-tools==7.4.1
plotly==5.22.0
prompt_toolkit==3.0.47
protobuf==4.25.3
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
pyarrow==15.0.2
pydantic==2.7.3
pydantic_core==2.18.4
pyg-lib==0.4.0+pt23cu121
Pygments==2.18.0
pyparsing==3.1.2
pyproject_hooks==1.1.0
pyre-extensions==0.0.30
pyro-api==0.1.2
pyro-ppl==1.9.1
python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.1
querystring-parser==1.2.4
ray==2.23.0
referencing==0.35.1
requests==2.32.3
rpds-py==0.18.1
scikit-learn==1.5.0
scipy==1.13.1
seaborn==0.13.2
six==1.16.0
smmap==5.0.1
SQLAlchemy==2.0.30
sqlparse==0.5.0
stack-data==0.6.3
sympy==1.12.1
tenacity==8.3.0
tensorboardX==2.6.2.2
threadpoolctl==3.5.0
tomli==2.0.1
torch==2.3.1
torch_cluster==1.6.3+pt23cu121
torch_geometric==2.5.3
torch_scatter==2.1.2+pt23cu121
torch_sparse==0.6.18+pt23cu121
torch_spline_conv==1.2.2+pt23cu121
torchaudio==2.3.1
torcheval==0.0.7
torchvision==0.18.1
tqdm==4.66.4
traitlets==5.14.3
triton==2.3.1
typeguard==2.13.3
typing-inspect==0.9.0
typing_extensions==4.12.1
tzdata==2024.1
urllib3==2.2.1
wcwidth==0.2.13
Werkzeug==3.0.3
widgetsnbextension==4.0.11
wrapt==1.16.0
yarl==1.9.4
zipp==3.19.2

Code from tutorial

#!/usr/bin/env python
# coding: utf-8

# # Ax Service API with RayTune on PyTorch CNN
# 
# Ax integrates easily with different scheduling frameworks and distributed training frameworks. In this example, Ax-driven optimization is executed in a distributed fashion using [RayTune](https://ray.readthedocs.io/en/latest/tune.html). 
# 
# RayTune is a scalable framework for hyperparameter tuning that provides many state-of-the-art hyperparameter tuning algorithms and seamlessly scales from laptop to distributed cluster with fault tolerance. RayTune leverages [Ray](https://ray.readthedocs.io/)'s Actor API to provide asynchronous parallel and distributed execution.
# 
# Ray 'Actors' are a simple and clean abstraction for replicating your Python classes across multiple workers and nodes. Each hyperparameter evaluation is asynchronously executed on a separate Ray actor and reports intermediate training progress back to RayTune. Upon reporting, RayTune then uses this information to perform actions such as early termination, re-prioritization, or checkpointing.

# In[1]:

import logging
from ray import tune
from ray.tune.search.ax import AxSearch
logger = logging.getLogger(tune.__name__)  
logger.setLevel(level=logging.CRITICAL)  # Reduce the number of Ray warnings that are not relevant here.

# In[2]:

import torch
import numpy as np

from ax.plot.contour import plot_contour
from ax.plot.trace import optimization_trace_single_method
from ax.service.ax_client import AxClient
from ax.utils.notebook.plotting import render, init_notebook_plotting
from ax.utils.tutorials.cnn_utils import load_mnist, train, evaluate

init_notebook_plotting()

# ## 1. Initialize client
# We specify `enforce_sequential_optimization` as False, because Ray runs many trials in parallel. With the sequential optimization enforcement, `AxClient` would expect the first few trials to be completed with data before generating more trials.
# 
# When high parallelism is not required, it is best to enforce sequential optimization, as it allows for achieving optimal results in fewer (but sequential) trials. In cases where parallelism is important, such as with distributed training using Ray, we choose to forego minimizing resource utilization and run more trials in parallel.

# In[3]:

ax = AxClient(enforce_sequential_optimization=False)

# ## 2. Set up experiment
# Here we set up the search space and specify the objective; refer to the Ax API tutorials for more detail.

# In[4]:

ax.create_experiment(
    name="mnist_experiment",
    parameters=[
        {"name": "lr", "type": "range", "bounds": [1e-6, 0.4], "log_scale": True},
        {"name": "momentum", "type": "range", "bounds": [0.0, 1.0]},
    ],
    objective_name="mean_accuracy",
)
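# NOTE (not part of the original tutorial): with the Ax version installed here (ax-platform 0.4.0),
# the call above raises the TypeError from the traceback; objective_name appears to have been
# replaced by an objectives dict (see the sketch in the Problem section above).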

# ## 3. Define how to evaluate trials
# Since we use the Ax Service API here, we evaluate the parameterizations that Ax suggests, using RayTune. The evaluation function follows its usual pattern, taking in a parameterization and outputting an objective value. For detail on evaluation functions, see [Trial Evaluation](https://ax.dev/docs/runner.html). 

# In[5]:

def train_evaluate(parameterization):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_loader, valid_loader, test_loader = load_mnist(data_path='~/.data')
    net = train(train_loader=train_loader, parameters=parameterization, dtype=torch.float, device=device)
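    # NOTE (added, not in the original tutorial): the tutorial reported the validation
    # accuracy back to Tune via ray.tune.track.log, which no longer exists. With Ray 2.x,
    # something along these lines should do the reporting instead (unverified sketch):
    #     from ray import train
    #     train.report(
    #         {"mean_accuracy": evaluate(net=net, data_loader=valid_loader, dtype=torch.float, device=device)}
    #     )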

# ## 4. Run optimization
# Execute the Ax optimization and trial evaluation in RayTune using [AxSearch algorithm](https://ray.readthedocs.io/en/latest/tune-searchalg.html#ax-search):

# In[6]:

tune.run(
    train_evaluate, 
    num_samples=30, 
    search_alg=AxSearch(ax),  # Note that the argument here is the `AxClient`.
    verbose=0,  # Set this level to 1 to see status updates and to 2 to also see trial results.
    # To use GPU, specify: resources_per_trial={"gpu": 1}.
)
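# NOTE (added): with ray.tune.search.ax in Ray 2.x, the client is passed by keyword,
# i.e. search_alg=AxSearch(ax_client=ax); passing the AxClient positionally as above
# may no longer work, since the first positional argument is now the search space.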

# ## 5. Retrieve the optimization results

# In[7]:

best_parameters, values = ax.get_best_parameters()
best_parameters

# In[8]:

means, covariances = values
means

# ## 6. Plot the response surface and optimization trace

# In[9]:

render(
    plot_contour(
        model=ax.generation_strategy.model, param_x='lr', param_y='momentum', metric_name='mean_accuracy'
    )
)

# In[10]:

# `optimization_trace_single_method` expects a 2-d array of means, because it expects to average means from multiple 
# optimization runs, so we wrap our best objectives array in another array.
best_objectives = np.array([[trial.objective_mean * 100 for trial in ax.experiment.trials.values()]])
best_objective_plot = optimization_trace_single_method(
    y=np.maximum.accumulate(best_objectives, axis=1),
    title="Model performance vs. # of iterations",
    ylabel="Accuracy",
)
render(best_objective_plot)

mgrange1998 commented 3 months ago

Hi, thank you for opening the issue and letting us know about the tutorial. This tutorial was created for Ax version 0.1.9, so for the time being it is expected that it won't work with newer versions of Ax. I'll close this issue for now and also address Ray's issue.

KantiCodes commented 3 months ago

Thanks for getting back to me. Makes sense. I don't think Ray is compatible with the latest Ax, and it seems like they stopped maintaining the integration :(