JuliaPy / Conda.jl

Conda managing Julia binary dependencies

PyCall is not using the packages in Conda #234

Open karimn opened 1 year ago

karimn commented 1 year ago

Hi,

I'm new to Julia and PyCall. I'm trying to import the Python transformers package, but without success; the problem appears to be that PyCall is not using the correct packages. I'm using Julia 1.8.5.

I don't understand why pyimport("transformers") fails with an error about jaxlib being the wrong version. I confirmed that jaxlib v0.4.4 is actually installed by Conda.

Here are the steps I tried.

julia> import Pkg

julia> ENV["PYTHON"] = ""
""

julia> Pkg.build("Conda")
    Building Conda → `~/.julia/scratchspaces/44cfe95a-1eb2-52ea-b672-e2afdf69b78f/e32a90da027ca45d84678b826fffd3110bb3fc90/build.log`

julia> Pkg.build("PyCall")
    Building Conda ─→ `~/.julia/scratchspaces/44cfe95a-1eb2-52ea-b672-e2afdf69b78f/e32a90da027ca45d84678b826fffd3110bb3fc90/build.log`
    Building PyCall → `~/.julia/scratchspaces/44cfe95a-1eb2-52ea-b672-e2afdf69b78f/62f417f6ad727987c755549e9cd88c46578da562/build.log`

julia> exit()

julia> import Conda

julia> import PyCall

julia> Conda.ROOTENV
"/home/karim/.julia/conda/3"

julia> PyCall.conda
true

julia> PyCall.python
"/home/karim/.julia/conda/3/bin/python"

julia> PyCall.pyprogramname
"/home/karim/.julia/conda/3/bin/python"

julia> Conda.add("transformers")
[ Info: Running `conda install -y transformers` in root environment
Collecting package metadata (current_repodata.json): done
Solving environment: done

# All requested packages already installed.

julia> PyCall.pyimport("transformers")
ERROR: PyError (PyImport_ImportModule) <class 'RuntimeError'>
RuntimeError('jaxlib is version 0.1.75, but this version of jax requires version >= 0.4.2.')
  File "/home/karim/.local/lib/python3.10/site-packages/transformers/__init__.py", line 30, in <module>
    from . import dependency_versions_check
  File "/home/karim/.local/lib/python3.10/site-packages/transformers/dependency_versions_check.py", line 17, in <module>
    from .utils.versions import require_version, require_version_core
  File "/home/karim/.local/lib/python3.10/site-packages/transformers/utils/__init__.py", line 34, in <module>
    from .generic import (
  File "/home/karim/.local/lib/python3.10/site-packages/transformers/utils/generic.py", line 36, in <module>
    import jax.numpy as jnp
  File "/home/karim/.local/lib/python3.10/site-packages/jax/__init__.py", line 35, in <module>
    from jax import config as _config_module
  File "/home/karim/.local/lib/python3.10/site-packages/jax/config.py", line 17, in <module>
    from jax._src.config import config  # noqa: F401
  File "/home/karim/.local/lib/python3.10/site-packages/jax/_src/config.py", line 28, in <module>
    from jax._src import lib
  File "/home/karim/.local/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 74, in <module>
    version = check_jaxlib_version(
  File "/home/karim/.local/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 63, in check_jaxlib_version
    raise RuntimeError(msg)

Stacktrace:
 [1] pyimport(name::String)
   @ PyCall ~/.julia/packages/PyCall/twYvK/src/PyCall.jl:558
 [2] top-level scope
   @ REPL[6]:1
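
Every frame in the Python traceback above points at /home/karim/.local/lib/python3.10/site-packages rather than at the Conda environment under /home/karim/.julia/conda/3. That suggests packages from the per-user site directory are being picked up instead of the Conda ones. Below is a minimal Python sketch for checking this, assuming it is run with the interpreter that PyCall reports (PyCall.python). The helper names user_site_dir, module_origin and comes_from_user_site are made up for illustration and are not part of PyCall or Conda.jl.

```python
import importlib.util
import os
import site
import sys


def user_site_dir():
    """Return the per-user site-packages directory as an absolute path."""
    return os.path.realpath(site.getusersitepackages())


def module_origin(name):
    """Return the file a top-level module would be loaded from, or None.

    find_spec only locates a top-level module; it does not execute it,
    so this is safe to call even when the real import raises.
    """
    spec = importlib.util.find_spec(name)
    if spec is None:
        return None
    return spec.origin


def comes_from_user_site(name):
    """True if the module resolves to a file under the user site directory."""
    origin = module_origin(name)
    if origin is None:
        return False
    prefix = user_site_dir() + os.sep
    return os.path.realpath(origin).startswith(prefix)


if __name__ == "__main__":
    print("interpreter:      ", sys.executable)
    print("user site dir:    ", user_site_dir())
    print("user site enabled:", site.ENABLE_USER_SITE)
    # Module names taken from the traceback and the follow-up comment.
    for name in ("transformers", "jax", "jaxlib", "numpy"):
        print(name, "->", module_origin(name),
              "(user site)" if comes_from_user_site(name) else "")
```

If the modules do resolve to the user site directory, one option that may help is setting the PYTHONNOUSERSITE environment variable before Python starts, which tells Python not to add that directory to sys.path. Whether that is enough in this particular setup is an assumption that would need to be verified.
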

deszoeke commented 1 year ago

I have a similar problem with a similar Conda installation, except that in my case the final step is importing numpy:

PyCall.pyimport("numpy")

and it hangs.