google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

`NameError: unbound axis name` when running example in documentation #436

Open paupereira opened 1 year ago

paupereira commented 1 year ago

I'm running the notebook `custom_loop_pmap_example.ipynb` from the example in the documentation.

With the option `use_pmap=True`, it produces the following error:

---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[39], line 9
      6 print("linesearch (ignored if `stepsize` > 0):", LINESEARCH)
      7 print()
----> 9 errors, step_times, compile_time = run()
     10 print('Average speed-up (ignoring compile):',
     11       round((step_times['without_pmap'] / step_times['with_pmap']).mean(), 2))

Cell In[38], line 15, in run()
     13 exp_name: str = f"{'with' if use_pmap else 'without'}_pmap"
     14 print(exp_name)
---> 15 _errors, _step_times, _compile_time = fit(data=data,
     16                                           init_params=init_params,
     17                                           stepsize=STEPSIZE,
     18                                           linesearch=LINESEARCH,
     19                                           use_pmap=use_pmap)
     20 errors[exp_name] = _errors
     21 step_times[exp_name] = _step_times

Cell In[37], line 49, in fit(data, init_params, stepsize, linesearch, use_pmap)
     46   update = jax.jit(solver.update)
     48 # Initialize solver state.
---> 49 state = solver.init_state(init_params, data=data)
     50 params = init_params
     51 # If using `pmap` for data-parallel training, model parameters are typically
     52 # replicated across devices.

File ~/jax/lib/python3.10/site-packages/jaxopt/_src/lbfgs.py:284, in LBFGS.init_state(self, init_params, *args, **kwargs)
    275   dtype = tree_single_dtype(init_params)
    276   state_kwargs = dict(
    277     s_history=init_history(init_params, self.history_size),
    278     y_history=init_history(init_params, self.history_size),
   (...)
    282     stepsize=jnp.asarray(self.max_stepsize, dtype=dtype),
    283   )
--> 284 (value, aux), grad = self._value_and_grad_with_aux(init_params, *args, **kwargs)
    285 return LbfgsState(value=value,
    286                   grad=grad,
    287                   error=jnp.asarray(jnp.inf),
    288                   **state_kwargs,
    289                   aux=aux,
    290                   failed_linesearch=jnp.asarray(False))

File ~/jax/lib/python3.10/site-packages/jaxopt/_src/base.py:62, in _add_aux_to_value_and_grad.<locals>.value_and_grad_with_aux(*a, **kw)
     61 def value_and_grad_with_aux(*a, **kw):
---> 62   v, g = value_and_grad(*a, **kw)
     63   return (v, None), g

Cell In[35], line 6, in pmean.<locals>.wrapper(*args, **kwargs)
      4 @functools.wraps(fun)
      5 def wrapper(*args, **kwargs):
----> 6   return jax.tree_map(maybe_pmean, fun(*args, **kwargs))

File ~/jax/lib/python3.10/site-packages/jax/_src/tree_util.py:210, in tree_map(f, tree, is_leaf, *rest)
...
-> 2530 raise NameError(
   2531     f'unbound axis name: {axis_name}. The following axis names (e.g. defined '
   2532     f'by pmap) are available to collective operations: {named_axes}')

NameError: unbound axis name: b. The following axis names (e.g. defined by pmap) are available to collective operations: []
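
The message suggests that a collective using the axis name `b` was evaluated outside any `pmap` that binds that name (the list of available axes is empty). The following is a minimal sketch of that mechanism only, independent of jaxopt and of the notebook; the helper `mean_over_devices` and the variable names are hypothetical and chosen for illustration.

```python
import jax
import jax.numpy as jnp

AXIS = "b"  # same axis name as in the error message above


def mean_over_devices(x):
    # Collective operation: only valid where an axis named AXIS is bound.
    return jax.lax.pmean(x, axis_name=AXIS)


# One row per local device, so the leading dimension matches what pmap expects.
x = jnp.ones((jax.local_device_count(), 3))

# Outside pmap: no axis named "b" exists, so JAX raises NameError.
try:
    mean_over_devices(x)
    outside_error = None
except NameError as e:
    outside_error = str(e)

# Inside pmap with axis_name="b": the axis is bound and the call succeeds.
inside = jax.pmap(mean_over_devices, axis_name=AXIS)(x)
```

If this is indeed what happens in the notebook, the call at line 49 of `fit` (`solver.init_state(init_params, data=data)`) would be reaching the `pmean` wrapper without an enclosing `pmap`; this is a reading of the traceback, not a confirmed diagnosis.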

These are the versions of jax and jaxopt on my system:

jax                          0.4.11
jaxlib                       0.4.7+cuda11.cudnn86
jaxopt                       0.7

egg5154 commented 1 year ago

Hello, I also encountered this bug when running the example. Have you found a way to handle it?

paupereira commented 1 year ago

@egg5154 I didn't solve it, but `pmap` is now on its deprecation path, so the point is moot. Try using sharding instead (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).
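
For reference, here is a rough sketch of what the sharding approach can look like for a data-parallel loss and gradient. It does not use jaxopt and is not a port of the notebook; the least-squares `loss`, the mesh axis name `batch`, and the array shapes are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D device mesh; the axis name "batch" is an arbitrary choice.
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))

# Toy data whose batch size is divisible by the number of devices.
n = 8 * len(jax.devices())
X = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)
y = jnp.ones((n,))
w = jnp.zeros((3,))

# Split the data along the batch dimension and replicate the parameters.
X = jax.device_put(X, NamedSharding(mesh, P("batch", None)))
y = jax.device_put(y, NamedSharding(mesh, P("batch")))
w = jax.device_put(w, NamedSharding(mesh, P()))


def loss(w, X, y):
    # Plain mean over the full batch; no explicit pmean or axis name needed.
    return jnp.mean((X @ w - y) ** 2)


# jit partitions the computation according to the input shardings.
value, grad = jax.jit(jax.value_and_grad(loss))(w, X, y)
```

Because the mean is written over the whole (sharded) batch, there is no named axis to bind, so the `unbound axis name` error cannot arise in this style. On a single device the same code runs unchanged.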