dattalab / keypoint-moseq

https://keypoint-moseq.readthedocs.io

Fix installation for GPU on Windows #110

Closed: KonradDanielewski closed this pull request 7 months ago

KonradDanielewski commented 7 months ago

cudatoolkit and cudnn are now installed from the proper channels. These are newer versions, but since they are compatible, having them doesn't hurt. pytables is also added, so the extra installation step can be skipped. Installation is now complete just by creating the environment from the file.
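A minimal post-install sanity check, sketched in Python under the assumption that it is run with the new environment's interpreter. It only confirms that the modules discussed here can be located (PyTables is imported as tables). The script, missing_modules and EXPECTED_MODULES are illustrative names, not part of keypoint-moseq.

```python
"""Hypothetical post-install check: confirm the key modules are present in the new env."""
import importlib.util
from typing import Iterable, List

# Assumption: top-level import names for the packages discussed in this PR.
# PyTables is imported as "tables"; keypoint-moseq as "keypoint_moseq".
EXPECTED_MODULES = ["jax", "tables", "keypoint_moseq"]


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the names that cannot be found by the current interpreter."""
    # find_spec locates a top-level module without importing it.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(EXPECTED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All expected modules can be found.")
```

Note that find_spec only checks that a module can be located; it does not import it, so GPU errors that only surface on the first JAX compile will not show up with this check.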

KonradDanielewski commented 7 months ago

I suspected from the beginning that it's not related to Windows 11 in any way, since Windows 11 is basically just a front-end refresh plus some minor scheduler features. This works on both Windows 10 and 11; I tested both today.

calebweinreb commented 7 months ago

Amazing! Thanks for working on this. I'll test it on some of our machines and then merge.

calebweinreb commented 7 months ago

Hmm, I tried running

conda env create -f conda_envs\environment.win64_gpu.yml

using your fork on Windows 10 and got

XlaRuntimeError: UNKNOWN: no kernel image is available for execution on the device
in external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_asm_compiler.cc(64): 'status'

when importing keypoint_moseq.

@wingillis does this happen on your Windows 11 machine?

KonradDanielewski commented 7 months ago

I think this is the same issue Lochlan had during the course (I had the same one on Windows 11 at the lab workstation). Updating the GPU driver should fix it.
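A hedged sketch for checking the driver before retrying: it reads the version reported by nvidia-smi (via its --query-gpu=driver_version --format=csv,noheader options) and compares it with a minimum. MIN_DRIVER is an assumption based on the Windows driver requirement NVIDIA lists for CUDA 11.8 (the cudatoolkit version in this environment) and should be checked against NVIDIA's release notes; the helper names are hypothetical.

```python
"""Hypothetical helper: read the NVIDIA driver version and compare it to a minimum."""
import subprocess
from typing import Optional, Tuple

# Assumption: minimum Windows driver for CUDA 11.8; verify against
# NVIDIA's CUDA Toolkit release notes before relying on it.
MIN_DRIVER: Tuple[int, ...] = (520, 6)


def parse_version(text: str) -> Tuple[int, ...]:
    """Turn a dotted driver string such as '522.06' into (522, 6)."""
    return tuple(int(part) for part in text.strip().split("."))


def query_driver_version() -> Optional[str]:
    """Ask nvidia-smi for the driver version; None if it is missing or fails."""
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
            capture_output=True, text=True, check=True,
        )
    except (OSError, subprocess.CalledProcessError):
        return None
    lines = result.stdout.strip().splitlines()
    return lines[0].strip() if lines else None  # first GPU only


def driver_is_recent_enough(version: str, minimum: Tuple[int, ...] = MIN_DRIVER) -> bool:
    """Compare version tuples element by element."""
    return parse_version(version) >= minimum


if __name__ == "__main__":
    version = query_driver_version()
    if version is None:
        print("nvidia-smi not found; is the NVIDIA driver installed?")
    else:
        status = "OK" if driver_is_recent_enough(version) else "consider updating"
        print(f"Driver {version}: {status}")
```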

calebweinreb commented 7 months ago

Woohoo, that worked!! Now we just need to test on our own Windows 11 machine.

tsievert commented 7 months ago

I chatted with @calebweinreb on the moseq Slack about trouble installing kpms on Win11, and he pointed me to your fix as a potential solution. I tried installing with your environment file, but the installation fails with the following error:

(base) PS C:\WINDOWS\system32\keypoint-moseq> conda env create -f conda_envs\environment.win64_gpu.yml
Collecting package metadata (repodata.json): done
Solving environment: done

==> WARNING: A newer version of conda exists. <==
  current version: 23.5.0
  latest version: 23.10.0

Please update conda by running

    $ conda update -n base -c defaults conda

Or to minimize the number of packages updated during conda update use

     conda install conda=23.10.0

Downloading and Extracting Packages

Preparing transaction: done
Verifying transaction: done
Executing transaction: \ "By downloading and using the CUDA Toolkit conda packages, you accept the terms and conditions of the CUDA End User License Agreement (EULA): https://docs.nvidia.com/cuda/eula/index.html"

| "By downloading and using the cuDNN conda packages, you accept the terms and conditions of the NVIDIA cuDNN EULA - https://docs.nvidia.com/deeplearning/cudnn/sla/index.html"

done
ERROR conda.core.link:_execute(952): An error occurred while installing package 'conda-forge::cudatoolkit-11.8.0-h09e9e62_12'.
Rolling back transaction: done

LinkError: post-link script failed for package conda-forge::cudatoolkit-11.8.0-h09e9e62_12
location of failed script: C:\Users\ULiege\anaconda3\envs\keypoint_moseq\Scripts\.cudatoolkit-post-link.bat
==> script messages <==
"By downloading and using the CUDA Toolkit conda packages, you accept the terms and conditions of the CUDA End User License Agreement (EULA): https://docs.nvidia.com/cuda/eula/index.html"

==> script output <==
stdout:
stderr: 'chcp' is not recognized as an internal or external command,
operable program or batch file.
'chcp' is not recognized as an internal or external command,
operable program or batch file.
'chcp' is not recognized as an internal or external command,
operable program or batch file.

return code: 1

()

Please let me know if you need any additional info or want me to try other things!

KonradDanielewski commented 7 months ago

Hi @tsievert. Definitely run this from the Prompt, not PowerShell. Also, move the repository out of system32 before installing, and of course run the Prompt as administrator.

> stderr: 'chcp' is not recognized as an internal or external command,
> operable program or batch file.

etc.

It's caused either by PowerShell or (more likely) by a lack of privileges, on top of running things from system32.
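The conditions mentioned here can be checked before running conda env create. A minimal Python sketch, assuming Windows is installed in a folder named Windows; it checks whether the working directory is under System32, whether chcp can be found on PATH, and whether the process is elevated. It cannot tell which shell launched it, and all function names are hypothetical.

```python
"""Hypothetical pre-flight check before running `conda env create` on Windows."""
import os
import shutil
import sys
from typing import List, Optional


def is_under_system32(path: str) -> bool:
    """True if the path contains a Windows\\System32 segment (case-insensitive).

    Assumption: Windows is installed in a folder named 'Windows'.
    """
    parts = [p for p in path.replace("/", "\\").lower().split("\\") if p]
    return any(a == "windows" and b == "system32" for a, b in zip(parts, parts[1:]))


def is_admin() -> Optional[bool]:
    """Whether this process is elevated on Windows; None on other platforms."""
    if sys.platform != "win32":
        return None
    import ctypes
    return bool(ctypes.windll.shell32.IsUserAnAdmin())


def preflight(cwd: str) -> List[str]:
    """Collect warnings that match the advice given in this thread."""
    warnings = []
    if is_under_system32(cwd):
        warnings.append("Move the repository out of system32 before installing.")
    if shutil.which("chcp") is None:
        warnings.append("'chcp' is not on PATH; the cudatoolkit post-link step failed without it.")
    if is_admin() is False:
        warnings.append("Run the prompt as administrator.")
    return warnings


if __name__ == "__main__":
    for warning in preflight(os.getcwd()):
        print("WARNING:", warning)
```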

wingillis commented 7 months ago

> Hmm I tried running
>
> conda env create -f conda_envs\environment.win64_gpu.yml
>
> using your fork on Windows 10 and got
>
> XlaRuntimeError: UNKNOWN: no kernel image is available for execution on the device
> in external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_asm_compiler.cc(64): 'status'
>
> when importing keypoint_moseq.
>
> @wingillis does this happen on your Windows 11 machine?

I didn't get this error when I ran import keypoint_moseq on my machine; importing worked just fine. I have the latest NVIDIA drivers and ran this from the Anaconda-supplied command prompt.
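Whether import keypoint_moseq works depends on JAX compiling its first kernel; in the traceback reported in this thread, that happens inside jax.random.PRNGKey(0), which dynamax calls at import time. A minimal sketch that reproduces just that step and reports which backend JAX selected; jax_smoke_test is a hypothetical name, not part of keypoint-moseq or JAX.

```python
"""Hypothetical smoke test: can JAX compile and run a tiny computation here?"""
from typing import Dict, Optional


def jax_smoke_test() -> Optional[Dict[str, Optional[str]]]:
    """Return the active JAX backend and any error raised by a first compile.

    Returns None if JAX is not installed in the current environment.
    """
    try:
        import jax
        import jax.random as jr
    except ImportError:
        return None
    try:
        # PRNGKey(0) triggers an XLA compile, the same call that failed
        # in the traceback reported in this thread.
        jr.PRNGKey(0)
        return {"backend": jax.default_backend(), "error": None}
    except Exception as exc:  # e.g. XlaRuntimeError: no kernel image is available
        return {"backend": None, "error": str(exc)}


if __name__ == "__main__":
    print(jax_smoke_test())
```

On a working GPU install the backend should be "gpu"; "cpu" means JAX did not pick up CUDA, and an error containing "no kernel image is available" matches the driver problem described in this thread.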

tsievert commented 7 months ago

@KonradDanielewski I tried three different things now:

  1. Anaconda Prompt and new file location
     1.1 Moved the files to C:\Program Files\keypoint-moseq
     1.2 Ran Anaconda Prompt as admin (because adding Anaconda to PATH is discouraged in the latest versions, according to a Stack Overflow conversation)
     1.3 Same error as above

  2. cmd.exe and new file location
     2.1 Added Anaconda to PATH
     2.2 Files in C:\Program Files\keypoint-moseq
     2.3 Ran cmd.exe as admin
     2.4 Exactly the same error again

  3. Followed some random Stack Overflow advice
     3.1 Added ;%SystemRoot%\system32;%SystemRoot%;%SystemRoot%\System32\Wbem; to PATH
     3.2 Repeated what I did in 2.
     3.3 Same error again

tsievert commented 7 months ago

I think I figured it out! I restarted my machine and the install went through, but I couldn't activate the environment. Since the error mentioned system32 and seemed to list my PATH, I removed ;%SystemRoot%\system32;%SystemRoot%;%SystemRoot%\System32\Wbem; from PATH again, and now I can activate the environment. The next step is to check whether moseq actually runs.

EDIT: Unfortunately, it does not run:

import keypoint_moseq as kpms

project_dir = 'demo_project'
config = lambda: kpms.load_config(project_dir)

---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[1], line 1
----> 1 import keypoint_moseq as kpms
      3 project_dir = 'demo_project'
      4 config = lambda: kpms.load_config(project_dir)

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\keypoint_moseq\__init__.py:11
      7 import warnings
      9 warnings.formatwarning = lambda msg, *a: str(msg)
---> 11 from .io import *
     12 from .viz import *
     13 from .util import *

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\keypoint_moseq\io.py:17
     14 from ndx_pose import PoseEstimation
     15 from itertools import islice
---> 17 from keypoint_moseq.util import list_files_with_exts, check_nan_proportions
     18 from jax_moseq.utils import get_frequencies, unbatch
     21 def _build_yaml(sections, comments):

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\keypoint_moseq\util.py:14
     12 from sklearn.neighbors import NearestNeighbors
     13 from scipy.spatial.distance import pdist, squareform
---> 14 from jax_moseq.models.keypoint_slds import inverse_rigid_transform
     15 from jax_moseq.utils import get_frequencies, batch
     17 na = jnp.newaxis

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax_moseq\__init__.py:2
      1 import jax
----> 2 from . import models
      3 from . import utils
      4 from . import _version

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\__init__.py:1
----> 1 from . import arhmm
      2 from . import slds
      3 from . import keypoint_slds

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\arhmm\__init__.py:1
----> 1 from jax_moseq.models.arhmm.initialize import *
      2 from jax_moseq.models.arhmm.gibbs import *
      3 from jax_moseq.models.arhmm.log_prob import *

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\arhmm\initialize.py:7
      5 from jax_moseq.utils import device_put_as_scalar, check_precision
      6 from jax_moseq.utils.transitions import init_hdp_transitions
----> 7 from jax_moseq.utils.distributions import sample_mniw
      9 from jax_moseq.models.arhmm.gibbs import resample_discrete_stateseqs
     11 na = jnp.newaxis

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax_moseq\utils\distributions.py:3
      1 import jax, jax.numpy as jnp, jax.random as jr
      2 import tensorflow_probability.substrates.jax.distributions as tfd
----> 3 from dynamax.hidden_markov_model.inference import hmm_posterior_sample
      4 from jax_moseq.utils import convert_data_precision
      6 na = jnp.newaxis

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\dynamax\hidden_markov_model\__init__.py:1
----> 1 from dynamax.hidden_markov_model.models.abstractions import HMM, HMMEmissions, HMMInitialState, HMMTransitions, HMMParameterSet, HMMPropertySet
      2 from dynamax.hidden_markov_model.models.arhmm import LinearAutoregressiveHMM
      3 from dynamax.hidden_markov_model.models.bernoulli_hmm import BernoulliHMM

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\dynamax\hidden_markov_model\models\abstractions.py:2
      1 from abc import abstractmethod, ABC
----> 2 from dynamax.ssm import SSM
      3 from dynamax.types import Scalar
      4 from dynamax.parameters import to_unconstrained, from_unconstrained

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\dynamax\ssm.py:18
     16 from dynamax.parameters import ParameterSet, PropertySet
     17 from dynamax.types import PRNGKey, Scalar
---> 18 from dynamax.utils.optimize import run_sgd
     19 from dynamax.utils.utils import ensure_array_has_batch_dim
     22 class Posterior(Protocol):

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\dynamax\utils\optimize.py:28
     17     for idx in range(0, n_data, batch_size):
     18         yield tree_map(lambda x: x[perm[idx:min(idx + batch_size, n_data)]], dataset)
     21 def run_sgd(loss_fn,
     22             params,
     23             dataset,
     24             optimizer=optax.adam(1e-3),
     25             batch_size=1,
     26             num_epochs=50,
     27             shuffle=False,
---> 28             key=jr.PRNGKey(0)):
     29     """
     30     Note that batch_emissions is initially of shape (N,T)
     31     where N is the number of independent sequences and
   (...)
     48         losses: Output of loss_fn stored at each step.
     49     """
     50     opt_state = optimizer.init(params)

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\_src\random.py:132, in PRNGKey(seed)
    129 if np.ndim(seed):
    130   raise TypeError("PRNGKey accepts a scalar seed, but was given an array of"
    131                   f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
--> 132 key = prng.seed_with_impl(impl, seed)
    133 return _return_prng_keys(True, key)

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\_src\prng.py:267, in seed_with_impl(impl, seed)
    266 def seed_with_impl(impl: PRNGImpl, seed: int) -> PRNGKeyArray:
--> 267   return random_seed(seed, impl=impl)

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\_src\prng.py:593, in random_seed(seeds, impl)
    591 else:
    592   seeds_arr = jnp.asarray(seeds)
--> 593 return random_seed_p.bind(seeds_arr, impl=impl)

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\core.py:328, in Primitive.bind(self, *args, **params)
    325 def bind(self, *args, **params):
    326   assert (not config.jax_enable_checks or
    327           all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 328   return self.bind_with_trace(find_top_trace(args), args, params)

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\core.py:331, in Primitive.bind_with_trace(self, trace, args, params)
    330 def bind_with_trace(self, trace, args, params):
--> 331   out = trace.process_primitive(self, map(trace.full_raise, args), params)
    332   return map(full_lower, out) if self.multiple_results else full_lower(out)

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\core.py:698, in EvalTrace.process_primitive(self, primitive, tracers, params)
    697 def process_primitive(self, primitive, tracers, params):
--> 698   return primitive.impl(*tracers, **params)

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\_src\prng.py:605, in random_seed_impl(seeds, impl)
    603 @random_seed_p.def_impl
    604 def random_seed_impl(seeds, *, impl):
--> 605   base_arr = random_seed_impl_base(seeds, impl=impl)
    606   return PRNGKeyArray(impl, base_arr)

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\_src\prng.py:610, in random_seed_impl_base(seeds, impl)
    608 def random_seed_impl_base(seeds, *, impl):
    609   seed = iterated_vmap_unary(seeds.ndim, impl.seed)
--> 610   return seed(seeds)

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\_src\prng.py:845, in threefry_seed(seed)
    842   raise TypeError(f"PRNG key seed must be an integer; got {seed!r}")
    843 convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
    844 k1 = convert(
--> 845     lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
    846 with jax.numpy_dtype_promotion('standard'):
    847   # TODO(jakevdp): in X64 mode, this can generate 64-bit computations for 32-bit
    848   # inputs. We should avoid this.
    849   k2 = convert(jnp.bitwise_and(seed, np.uint32(0xFFFFFFFF)))

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\_src\lax\lax.py:515, in shift_right_logical(x, y)
    513 def shift_right_logical(x: ArrayLike, y: ArrayLike) -> Array:
    514   r"""Elementwise logical right shift: :math:`x \gg y`."""
--> 515   return shift_right_logical_p.bind(x, y)

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\core.py:328, in Primitive.bind(self, *args, **params)
    325 def bind(self, *args, **params):
    326   assert (not config.jax_enable_checks or
    327           all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 328   return self.bind_with_trace(find_top_trace(args), args, params)

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\core.py:331, in Primitive.bind_with_trace(self, trace, args, params)
    330 def bind_with_trace(self, trace, args, params):
--> 331   out = trace.process_primitive(self, map(trace.full_raise, args), params)
    332   return map(full_lower, out) if self.multiple_results else full_lower(out)

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\core.py:698, in EvalTrace.process_primitive(self, primitive, tracers, params)
    697 def process_primitive(self, primitive, tracers, params):
--> 698   return primitive.impl(*tracers, **params)

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\_src\dispatch.py:112, in apply_primitive(prim, *args, **params)
    110 def apply_primitive(prim, *args, **params):
    111   """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
--> 112   compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
    113                                         **params)
    114   return compiled_fun(*args)

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\_src\util.py:222, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
    220   return f(*args, **kwargs)
    221 else:
--> 222   return cached(config._trace_context(), *args, **kwargs)

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\_src\util.py:215, in cache.<locals>.wrap.<locals>.cached(_, *args, **kwargs)
    213 @functools.lru_cache(max_size)
    214 def cached(_, *args, **kwargs):
--> 215   return f(*args, **kwargs)

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\_src\dispatch.py:196, in xla_primitive_callable(prim, *arg_specs, **params)
    194   else:
    195     return out,
--> 196 compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
    197                                   prim.name, donated_invars, False, *arg_specs)
    198 if not prim.multiple_results:
    199   return lambda *args, **kw: compiled(*args, **kw)[0]

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\_src\dispatch.py:342, in _xla_callable_uncached(fun, device, backend, name, donated_invars, keep_unused, *arg_specs)
    340   return computation.compile(_allow_propagation_to_outputs=True).unsafe_call
    341 else:
--> 342   return lower_xla_callable(fun, device, backend, name, donated_invars, False,
    343                             keep_unused, *arg_specs).compile().unsafe_call

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\_src\dispatch.py:978, in XlaComputation.compile(self)
    975     self._executable = XlaCompiledComputation.from_trivial_jaxpr(
    976         **self.compile_args)
    977   else:
--> 978     self._executable = XlaCompiledComputation.from_xla_computation(
    979         self.name, self._hlo, self._in_type, self._out_type,
    980         **self.compile_args)
    982 return self._executable

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\_src\dispatch.py:1136, in XlaCompiledComputation.from_xla_computation(name, xla_computation, in_type, out_type, nreps, device, backend, tuple_args, in_avals, out_avals, has_unordered_effects, ordered_effects, kept_var_idx, keepalive, host_callbacks)
   1133 options.parameter_is_tupled_arguments = tuple_args
   1134 with log_elapsed_time(f"Finished XLA compilation of {name} "
   1135                       "in {elapsed_time} sec"):
-> 1136   compiled = compile_or_get_cached(backend, xla_computation, options,
   1137                                    host_callbacks)
   1138 buffer_counts = get_buffer_counts(out_avals, ordered_effects,
   1139                                   has_unordered_effects)
   1140 execute = _execute_compiled if nreps == 1 else _execute_replicated

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\_src\dispatch.py:1054, in compile_or_get_cached(backend, computation, compile_options, host_callbacks)
   1050     _cache_write(serialized_computation, module_name,  compile_options,
   1051                  backend, compiled)
   1052     return compiled
-> 1054 return backend_compile(backend, serialized_computation, compile_options,
   1055                        host_callbacks)

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\_src\profiler.py:313, in annotate_function.<locals>.wrapper(*args, **kwargs)
    310 @wraps(func)
    311 def wrapper(*args, **kwargs):
    312   with TraceAnnotation(name, **decorator_kwargs):
--> 313     return func(*args, **kwargs)
    314   return wrapper

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax\_src\dispatch.py:994, in backend_compile(backend, built_c, options, host_callbacks)
    989   return backend.compile(built_c, compile_options=options,
    990                          host_callbacks=host_callbacks)
    991 # Some backends don't have `host_callbacks` option yet
    992 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    993 # to take in `host_callbacks`
--> 994 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: UNKNOWN: no kernel image is available for execution on the device
in external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_asm_compiler.cc(64): 'status'
KonradDanielewski commented 7 months ago

> EDIT: It does not run unfortunately

As mentioned above, this error means you need to update your GPU driver.
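To summarize the two failures in this thread, a small lookup from error text to the fix that resolved it. The table is hand-written from this discussion and is illustrative, not exhaustive; KNOWN_FIXES and suggest_fix are hypothetical names.

```python
"""Hypothetical lookup from error text to the fixes reported in this thread."""
from typing import Optional

# Each key is a substring of an error seen in this thread; each value is the fix that resolved it.
KNOWN_FIXES = {
    "no kernel image is available for execution on the device":
        "Update the NVIDIA GPU driver, then retry importing keypoint_moseq.",
    "'chcp' is not recognized":
        "Run conda from the Prompt (not PowerShell) as administrator, outside system32; "
        "restarting Windows also helped in one report.",
}


def suggest_fix(error_text: str) -> Optional[str]:
    """Return the first matching fix, or None if the error is not recognized."""
    for pattern, fix in KNOWN_FIXES.items():
        if pattern in error_text:
            return fix
    return None
```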

tsievert commented 7 months ago

Indeed, that was the last missing step. Thanks a lot!