YukunXia / Carla_iLQR_MPC

Implementation of the real-time MPC based on iLQR in Carla simulator
MIT License

ImportError from jaxlib when running ilqr_jax_MPC.py #5

Closed by lzhhh93 2 years ago

lzhhh93 commented 2 years ago

Hi @Tanman1234 @YukunXia, thanks for the discussion here, and thanks to @YukunXia for the code.

I also have a problem with jaxlib after installing jax 0.1.68, specifically an ImportError: cannot import name 'pytree' from 'jaxlib'. I am working on Ubuntu 18.04.

Here is what I did to install jax 0.1.68, following https://github.com/google/jax/tree/jaxlib-v0.1.68:

     pip install --upgrade pip
     sudo ln -s /path/to/cuda /usr/local/cuda-11.1
     pip install --upgrade jax==0.1.68 jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html

- Result of the installation:

Successfully installed jax-0.1.68 jaxlib-0.1.67+cuda111
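Reading the two versions back out of this line makes the pairing explicit: jax 0.1.68 next to jaxlib 0.1.67+cuda111. jax and jaxlib are versioned independently, and the `jaxlib-v0.1.68` tag in the linked URL names a jaxlib release rather than jax 0.1.68, so similar-looking numbers do not imply the two packages are compatible. A minimal sketch (the helper name `parse_installed` is hypothetical) that extracts the pair from pip's output:

```python
def parse_installed(line):
    """Map package name -> version from a pip 'Successfully installed ...' line."""
    prefix = "Successfully installed "
    if not line.startswith(prefix):
        raise ValueError("not a pip 'Successfully installed' line")
    result = {}
    for item in line[len(prefix):].split():
        # pip prints each package as <name>-<version>; assumes the version itself has no '-'
        name, version = item.rsplit("-", 1)
        result[name] = version
    return result


installed = parse_installed("Successfully installed jax-0.1.68 jaxlib-0.1.67+cuda111")
print(installed["jax"], installed["jaxlib"])  # 0.1.68 0.1.67+cuda111
```

The `+cuda111` part is a local version label marking the CUDA 11.1 build; it does not change which jax release the wheel pairs with.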

- Running ilqr_jax_MPC.py gives an ImportError:

Traceback (most recent call last):
  File "/home/control/Documents/Carla projects/Carla_iLQR_MPC/MPC/ilqr_jax_MPC.py", line 1, in <module>
    from jax import jit, jacfwd, jacrev, hessian, lax
  File "/home/control/.local/lib/python3.7/site-packages/jax/__init__.py", line 16, in <module>
    from .api import (
  File "/home/control/.local/lib/python3.7/site-packages/jax/api.py", line 38, in <module>
    from . import core
  File "/home/control/.local/lib/python3.7/site-packages/jax/core.py", line 30, in <module>
    from . import dtypes
  File "/home/control/.local/lib/python3.7/site-packages/jax/dtypes.py", line 31, in <module>
    from .lib import xla_client
  File "/home/control/.local/lib/python3.7/site-packages/jax/lib/__init__.py", line 51, in <module>
    from jaxlib import pytree
ImportError: cannot import name 'pytree' from 'jaxlib' (/home/control/.local/lib/python3.7/site-packages/jaxlib/__init__.py)
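The last frame shows the failing statement is `from jaxlib import pytree` inside jax/lib/__init__.py, i.e. the installed jax expects a `pytree` name that the installed jaxlib does not provide. A small probe can check this before launching the MPC script. It is a sketch that mirrors how `from package import name` resolves (attribute first, then submodule); the function name `can_import_name` is hypothetical:

```python
import importlib


def can_import_name(package, name):
    """True/False if `from package import name` would/wouldn't succeed,
    or None if `package` itself cannot be imported."""
    try:
        module = importlib.import_module(package)
    except ImportError:
        return None
    # `from X import Y` first looks for an attribute Y on X ...
    if hasattr(module, name):
        return True
    # ... and then tries to import X.Y as a submodule
    try:
        importlib.import_module(f"{package}.{name}")
        return True
    except ImportError:
        return False


# False with the jaxlib 0.1.67 install above; None if jaxlib is not installed at all
print(can_import_name("jaxlib", "pytree"))
```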

For your information, CUDA 11.5 is installed. I downloaded cudnn-linux-x86_64-8.3.1.22_cuda11.5 and copied the files from its include/ and lib/ folders into /usr/local/cuda/include and /usr/local/cuda/lib64 with:

  cd folder/extracted/cdnn_contents
  sudo cp include/cudnn.h /usr/local/cuda/include
  sudo cp lib/libcudnn* /usr/local/cuda/lib64
  sudo chmod a+r /usr/local/cuda/lib64/libcudnn*
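To confirm that the copy and `chmod a+r` steps took effect, one can list the cuDNN libraries in the target directory and check that they are readable. This is a minimal sketch: the default directory and the `libcudnn*` pattern come from the commands above, the helper names are hypothetical, and readability is checked only for the current user:

```python
import fnmatch
import os


def cudnn_files(names):
    """Keep only file names matching the libcudnn* pattern used by the cp command."""
    return sorted(n for n in names if fnmatch.fnmatch(n, "libcudnn*"))


def check_cudnn_install(lib_dir="/usr/local/cuda/lib64"):
    """Return (readable, unreadable) cuDNN library paths in lib_dir for the current user."""
    if not os.path.isdir(lib_dir):
        return [], []
    paths = [os.path.join(lib_dir, n) for n in cudnn_files(os.listdir(lib_dir))]
    readable = [p for p in paths if os.access(p, os.R_OK)]
    unreadable = [p for p in paths if not os.access(p, os.R_OK)]
    return readable, unreadable


readable, unreadable = check_cudnn_install()
print(len(readable), "readable,", len(unreadable), "unreadable")
```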

I also tried to build jaxlib from source but did not succeed; it failed with this problem: https://stackoverflow.com/questions/70324228/how-to-deal-with-error-infinite-symlink-expansion-detected-in-building-jax-from

I also tried installing jax==0.2.14 and jax==0.2.16 together with jaxlib via pip, but ilqr_jax_MPC.py still fails with different errors.

Do you know how to deal with these pitfalls?

Many thanks.

YukunXia commented 2 years ago

Hi @lzhhh93, thanks for trying this out! The code doesn't use CUDA, so you probably don't need the CUDA-related packages to run it. Also, I noticed that you submitted an issue in the JAX repo; just so you know, my code was run with jaxlib 0.1.47.
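If a CUDA build of jaxlib is already installed, JAX can still be kept on the CPU by setting an environment variable before `jax` is first imported. This is a hedged sketch: as far as we know, older JAX releases read `JAX_PLATFORM_NAME` and newer ones read `JAX_PLATFORMS`, so it sets both; the helper name `force_jax_cpu` is hypothetical:

```python
import os


def force_jax_cpu():
    """Ask JAX to use the CPU backend; this must run before the first `import jax`."""
    # Older JAX releases read JAX_PLATFORM_NAME, newer ones read JAX_PLATFORMS
    os.environ["JAX_PLATFORM_NAME"] = "cpu"
    os.environ["JAX_PLATFORMS"] = "cpu"
    return os.environ["JAX_PLATFORM_NAME"]


force_jax_cpu()
# from jax import jit, jacfwd, jacrev, hessian, lax  # import only after the variables are set
```

A CPU-only jaxlib wheel avoids the question entirely, which matches the point that the code does not rely on CUDA.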

lzhhh93 commented 2 years ago

Hi @YukunXia, thanks for the quick reply. I tried to reinstall jax and jaxlib without CUDA, but pip shows:

python -m pip install jax==0.1.68 jaxlib==0.1.47
Defaulting to user installation because normal site-packages is not writeable
Collecting jax==0.1.68
  Using cached jax-0.1.68-py3-none-any.whl
ERROR: Could not find a version that satisfies the requirement jaxlib==0.1.47 (from versions: 0.1.60, 0.1.61, 0.1.62, 0.1.63, 0.1.64, 0.1.65, 0.1.66, 0.1.67, 0.1.68, 0.1.69, 0.1.70, 0.1.71, 0.1.72, 0.1.73, 0.1.74, 0.1.75)
ERROR: No matching distribution found for jaxlib==0.1.47
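The "from versions: ..." part of pip's error lists every jaxlib version pip found on the index it searched for this environment; 0.1.47 is not among them, so the `==0.1.47` pin cannot be resolved there. A small parser makes that list easy to check programmatically. It is a sketch that assumes the message format shown above, which may differ between pip versions; `available_versions` is a hypothetical name:

```python
import re


def available_versions(pip_error):
    """Extract the version list from pip's '(from versions: ...)' message."""
    match = re.search(r"\(from versions: ([^)]*)\)", pip_error)
    if match is None or match.group(1).strip() == "none":
        return []
    return [v.strip() for v in match.group(1).split(",") if v.strip()]


msg = ("ERROR: Could not find a version that satisfies the requirement jaxlib==0.1.47 "
       "(from versions: 0.1.60, 0.1.61, 0.1.62, 0.1.63, 0.1.64, 0.1.65, 0.1.66, 0.1.67, "
       "0.1.68, 0.1.69, 0.1.70, 0.1.71, 0.1.72, 0.1.73, 0.1.74, 0.1.75)")
versions = available_versions(msg)
print(versions[0], versions[-1], "0.1.47" in versions)  # 0.1.60 0.1.75 False
```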

Maybe it is no longer available. That's fine: a JAX expert in the JAX repo kindly told me that jaxlib==0.1.50 should work (https://github.com/google/jax/discussions/8921#discussioncomment-1799617), and it does work for me.
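To guard against slipping back into a mismatched pair, the installed versions can be compared against the combination reported to work in this thread (jax 0.1.68 with jaxlib 0.1.50). A minimal sketch; `PINS` and `pin_mismatches` are hypothetical names, and `importlib.metadata` needs Python 3.8+ (on Python 3.7 the `importlib_metadata` backport offers the same `version` and `PackageNotFoundError`):

```python
from importlib.metadata import PackageNotFoundError, version

# Combination reported to work in this thread
PINS = {"jax": "0.1.68", "jaxlib": "0.1.50"}


def pin_mismatches(pins, get_version=version):
    """Return {package: installed_version_or_None} for every package not at its pin."""
    mismatches = {}
    for package, wanted in pins.items():
        try:
            installed = get_version(package)
        except PackageNotFoundError:
            installed = None  # package not installed at all
        # Exact string comparison, so a local label such as '+cuda111' also counts as a mismatch
        if installed != wanted:
            mismatches[package] = installed
    return mismatches


print(pin_mismatches(PINS))  # an empty dict means both pins are satisfied
```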

I hope this issue helps others. Thanks again for your kind response @YukunXia!

YukunXia commented 2 years ago

Hi @lzhhh93, great that the code now works for you! If you figure out how to properly adapt the code to a newer version of JAX, feel free to submit a PR later.

lzhhh93 commented 2 years ago

@YukunXia I would be very happy to do that!

Could you also share which versions of torch (with or without CUDA) and tensorboard you used for this project? The last script I still can't run is system_id.py; it shows:

Traceback (most recent call last):
  File "/home/control/Documents/Carla projects/Carla_iLQR_MPC/SystemID/model_to_jax.py", line 8, in <module>
    from system_id import Net_v4
  File "/home/control/Documents/Carla projects/Carla_iLQR_MPC/SystemID/system_id.py", line 10, in <module>
    from torch.utils.tensorboard import SummaryWriter
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/tensorboard/__init__.py", line 4, in <module>
    LooseVersion = distutils.version.LooseVersion
AttributeError: module 'distutils' has no attribute 'version'
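The traceback shows torch reading `distutils.version.LooseVersion` at a point where the `distutils.version` submodule has not been loaded; in Python a submodule only becomes an attribute of its package once something imports it. A workaround commonly suggested for this error (an assumption here, not something confirmed in this thread) is to import `distutils.version` explicitly before torch's tensorboard module, or to pin an older setuptools such as 59.5.0. A sketch of the first option, with a hypothetical helper name:

```python
import importlib


def preload_submodule(name):
    """Import a dotted submodule so it is bound as an attribute of its parent package."""
    try:
        importlib.import_module(name)
        return True
    except ImportError:
        # e.g. distutils is absent on Python 3.12+ without setuptools
        return False


# Hypothetical workaround: load distutils.version before torch's tensorboard module reads it
preload_submodule("distutils.version")
# from torch.utils.tensorboard import SummaryWriter  # import only after the preload
```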

Many thanks.

YukunXia commented 2 years ago

PyTorch is 1.5.0 with CUDA 10.2, and tensorboard is 2.2.1.


Besides, CUDA is unlikely to be a major source of problems, because of this hyperparameter: https://github.com/YukunXia/Carla_iLQR_MPC/blob/00378cf12979dfc4e828821e36a93e674142a0a6/SystemID/system_id.py#L23