Epistimio / orion

Asynchronous Distributed Hyperparameter Optimization.
https://orion.readthedocs.io

Orion multi-task freezes when jax.random.truncated_normal is used with a loguniform hyperparameter #1117

Closed NeilGirdhar closed 1 year ago

NeilGirdhar commented 1 year ago

Orion freezes when this code is run:

from typing import Any

import jax.numpy as jnp
from jax.random import PRNGKey, truncated_normal
from orion.client import build_experiment

a = jnp.zeros(10)

def experiment_f(x: float) -> list[dict[str, Any]]:
    key = PRNGKey(123)
    print("DIE")  # This is printed thrice
    truncated_normal(key, -2.0, 2.0, (85, 85))
    print("DEAD")   # This is never printed
    return [{'name': 'loss', 'type': 'objective', 'value': 1.0}]

workers = 3
space = {'x': "loguniform(1e-4, 1e2)"}
experiment = build_experiment(name='blah', space=space,
                              max_trials=10, max_broken=1,
                              algorithm={"tpe": {"n_initial_points": 5}})
with experiment.tmp_executor("joblib", n_workers=workers):
    experiment.workon(experiment_f, workers)

Environment:

Possible solution:

Could be related to JAX's desire to JIT-compile its random functions.

NeilGirdhar commented 1 year ago

See the linked Jax issue for more ideas about what may be going wrong.

NeilGirdhar commented 1 year ago

I solved my problem by putting an assertion in jax._src.core._initialize_jax_jit_thread_local_state. This allowed me to find the one point at which I was forcing premature initialization of Jax and remove it.
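
For anyone hitting the same freeze, the debugging idea above (instrument the initializer so it reveals who calls it first) can be sketched without editing the installed library. This is only an illustration under stated assumptions: `trace_first_call`, `fake_lib`, and `careless_import_time_code` are hypothetical names, and `fake_lib._initialize_state` is a stand-in for the private JAX function named in the comment, which may be named differently or absent in other JAX versions. The wrapper records the call stack on the first call instead of asserting, so the program keeps running.

```python
import functools
import traceback
import types


def trace_first_call(module, func_name, log):
    """Wrap module.func_name so its first call appends the caller's stack to log."""
    original = getattr(module, func_name)

    @functools.wraps(original)
    def wrapper(*args, **kwargs):
        if not log:
            # format_stack() ends with this wrapper's own frame; drop it so
            # the last entry is the code that triggered the initialization.
            log.append("".join(traceback.format_stack()[:-1]))
        return original(*args, **kwargs)

    setattr(module, func_name, wrapper)
    return original


# Hypothetical stand-in for a library with a lazy initializer.
fake_lib = types.SimpleNamespace(_initialize_state=lambda: "initialized")


def careless_import_time_code():
    # Plays the role of code that touches the library too early.
    return fake_lib._initialize_state()


calls = []
trace_first_call(fake_lib, "_initialize_state", calls)
result = careless_import_time_code()
print(calls[0])  # the stack trace names careless_import_time_code
```

The patch has to be installed before the code under suspicion runs, and it only catches callers that look the function up through the module attribute; a caller that bound the original function earlier with `from module import name` would bypass the wrapper.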