google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0
939 stars 66 forks

ptxas version issue #427

Closed: zohimchandani closed this issue 1 year ago

zohimchandani commented 1 year ago

Running the following code snippet raises an error:

import jax
from jaxopt import GradientDescent

jax.devices('cpu')

def f(x): 
    return x**2

opt = GradientDescent(fun=f, stepsize=0.1, maxiter = 100, verbose = True, value_and_grad=False)

opt.run([3])
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[1], line 12
      7     return x**2
     10 opt = GradientDescent(fun=f, stepsize=0.1, maxiter = 100, verbose = True, value_and_grad=False)
---> 12 opt.run([3])

File ~/.local/lib/python3.10/site-packages/jaxopt/_src/base.py:255, in IterativeSolver.run(self, init_params, *args, **kwargs)
    248   decorator = idf.custom_root(
    249       self.optimality_fun,
    250       has_aux=True,
    251       solve=self.implicit_diff_solve,
    252       reference_signature=reference_signature)
    253   run = decorator(run)
--> 255 return run(init_params, *args, **kwargs)

File ~/.local/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py:251, in _custom_root.<locals>.wrapped_solver_fun(*args, **kwargs)
    249 args, kwargs = _signature_bind(solver_fun_signature, *args, **kwargs)
    250 keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 251 return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)

    [... skipping hidden 5 frame]

File ~/.local/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py:207, in _custom_root.<locals>.make_custom_vjp_solver_fun.<locals>.solver_fun_flat(*flat_args)
    204 @jax.custom_vjp
...
    469 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    470 # to take in `host_callbacks`
--> 471 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: FAILED_PRECONDITION: Couldn't get ptxas/nvlink version string: INTERNAL: Couldn't invoke ptxas --version

nouiz commented 1 year ago

How did you install JAX? Which pip command did you use? What is your environment or container? XLA needs ptxas and can't find it, so either it isn't installed or it is installed somewhere XLA doesn't look.
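
One rough way to check the "isn't installed or isn't found" question is to look for `ptxas` on `PATH` and try the same `ptxas --version` call that appears in the error message. The sketch below uses only the Python standard library; the helper names `find_ptxas` and `ptxas_version` are made up for illustration. XLA may also search locations other than `PATH` (for example a CUDA install directory), so a miss here is a hint, not proof.

```python
import shutil
import subprocess


def find_ptxas(name="ptxas"):
    """Return the full path of `name` if it is on PATH, else None."""
    return shutil.which(name)


def ptxas_version(path):
    """Run `<path> --version`; return its output, or None if it cannot be invoked."""
    try:
        result = subprocess.run(
            [path, "--version"],
            capture_output=True,
            text=True,
            timeout=10,
            check=True,
        )
    except (OSError, subprocess.SubprocessError):
        # Missing binary, permission problem, non-zero exit, or timeout.
        return None
    return result.stdout.strip()


if __name__ == "__main__":
    path = find_ptxas()
    if path is None:
        print("ptxas not found on PATH")
    else:
        print(f"ptxas found at {path}")
        print(ptxas_version(path) or "ptxas found, but `ptxas --version` failed")
```

If this reports that `ptxas` is missing, the CUDA toolchain is probably not installed or not on `PATH` in the environment where JAX runs.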

zohimchandani commented 1 year ago

Fixed by upgrading JAX with the CUDA 11 pip wheels:

pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
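
A side note on the original snippet: `jax.devices('cpu')` only returns the list of CPU devices; it does not make JAX run on the CPU. If the intent was to avoid the GPU backend altogether, one option is to set the `JAX_PLATFORMS` environment variable before JAX is imported. This is a minimal sketch under that assumption, not something confirmed in this thread, and whether it would have avoided the ptxas error here is untested. The `try`/`except` is there only so the sketch still runs on a machine without JAX.

```python
import os

# Must be set before JAX initializes its backends, so do it before `import jax`.
os.environ["JAX_PLATFORMS"] = "cpu"

try:
    import jax
except ImportError:
    jax = None  # JAX is not installed in this environment

if jax is not None:
    # With the variable set above, the default backend should be the CPU.
    print(jax.default_backend())
    print(jax.devices())
```

The same effect can be had from the shell by exporting `JAX_PLATFORMS=cpu` before starting Python.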