google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0
2.14k stars 234 forks source link

XlaRuntimeError when creating the barkour_v0_joystick env #431

Closed Leon-LXA closed 6 months ago

Leon-LXA commented 7 months ago

I have cloned the repo and am trying to train the quadruped env locally. However, when I run `barkour_env = envs.create('barkour_v0_joystick', backend='generalized', kick_vel=0.05)`, I get the following error:

```
XlaRuntimeError                           Traceback (most recent call last)
Cell In[3], line 1
----> 1 barkour_env = envs.create('barkour_v0_joystick',
      2                           backend='generalized', kick_vel=0.05
      3                           )
      4 state = jax.jit(barkour_env.reset)(rng=jax.random.PRNGKey(seed=0))
      6 HTML(html.render(barkour_env.sys, [state.pipeline_state]))
```

```
File ~/brax/brax/envs/__init__.py:97, in create(env_name, episode_length, action_repeat, auto_reset, batch_size, **kwargs)
     76 def create(
     77     env_name: str,
     78     episode_length: int = 1000,
   (...)
     82     **kwargs,
     83 ) -> Env:
     84   """Creates an environment from the registry.
     85 
     86   Args:
   (...)
     95     env: an environment
     96   """
---> 97   env = _envs[env_name](**kwargs)
     99   if episode_length is not None:
    100     env = training.EpisodeWrapper(env, episode_length, action_repeat)
```

```
File ~/brax/brax/experimental/barkour_v0/barkour_joystick.py:124, in Barkourv0.__init__(self, obs_noise, kick_vel, action_scale, backend, debug, **kwargs)
    122 self._obs_noise = obs_noise
    123 self._kick_vel = kick_vel
--> 124 self._default_ap_pose = sys.init_q[7:19]
    125 self.reward_config = get_config()
    126 self.torso_idx = self.sys.link_names.index('chassis')
```

```
File ~/brax/env/lib/python3.8/site-packages/jax/_src/array.py:317, in ArrayImpl.__getitem__(self, idx)
    315   return lax_numpy._rewriting_take(self, idx)
    316 else:
--> 317   return lax_numpy._rewriting_take(self, idx)
```

```
File ~/brax/env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4128, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   4119 def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
   4120                     mode=None, fill_value=None):
   4121   # Computes arr[idx].
   (...)
   4125   # For simplicity of generated primitives, we call lax.dynamic_slice in the
   4126   # simplest cases: i.e. non-dynamic arrays indexed with integers and slices.
-> 4128   if (result := _attempt_rewriting_take_via_slice(arr, idx, mode)) is not None:
   4129     return result
   4131   # TODO(mattjj,dougalm): expand dynamic shape indexing support
```

```
File ~/brax/env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4113, in _attempt_rewriting_take_via_slice(arr, idx, mode)
   4111 if len(start_indices) > 1:
   4112   start_indices = util.promote_dtypes(*start_indices)
-> 4113 arr = lax.dynamic_slice(arr, start_indices=start_indices, slice_sizes=slice_sizes)
   4114 if int_indices:
   4115   arr = lax.squeeze(arr, tuple(int_indices))
```

```
File ~/brax/env/lib/python3.8/site-packages/jax/_src/lax/slicing.py:111, in dynamic_slice(operand, start_indices, slice_sizes)
    109   dynamic_sizes = []
    110   static_sizes = core.canonicalize_shape(slice_sizes)  # type: ignore
--> 111 return dynamic_slice_p.bind(operand, *start_indices, *dynamic_sizes,
    112                             slice_sizes=tuple(static_sizes))
```

```
File ~/brax/env/lib/python3.8/site-packages/jax/_src/core.py:380, in Primitive.bind(self, *args, **params)
    377 def bind(self, *args, **params):
    378   assert (not config.jax_enable_checks or
    379           all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 380   return self.bind_with_trace(find_top_trace(args), args, params)
```

```
File ~/brax/env/lib/python3.8/site-packages/jax/_src/core.py:383, in Primitive.bind_with_trace(self, trace, args, params)
    382 def bind_with_trace(self, trace, args, params):
--> 383   out = trace.process_primitive(self, map(trace.full_raise, args), params)
    384   return map(full_lower, out) if self.multiple_results else full_lower(out)
```

```
File ~/brax/env/lib/python3.8/site-packages/jax/_src/core.py:815, in EvalTrace.process_primitive(self, primitive, tracers, params)
    814 def process_primitive(self, primitive, tracers, params):
--> 815   return primitive.impl(*tracers, **params)
```

```
File ~/brax/env/lib/python3.8/site-packages/jax/_src/dispatch.py:144, in apply_primitive(prim, *args, **params)
    140   msg = pjit._device_assignment_mismatch_error(
    141       prim.name, fails, args, 'jit', arg_names)
    142   raise ValueError(msg) from None
--> 144 return compiled_fun(*args)
```

```
File ~/brax/env/lib/python3.8/site-packages/jax/_src/dispatch.py:227, in xla_primitive_callable.<locals>.<lambda>(*args, **kw)
    223 compiled = _xla_callable_uncached(
    224     lu.wrap_init(prim_fun), prim.name, donated_invars, False, in_avals,
    225     orig_in_shardings)
    226 if not prim.multiple_results:
--> 227   return lambda *args, **kw: compiled(*args, **kw)[0]
    228 else:
    229   return compiled
```

```
File ~/brax/env/lib/python3.8/site-packages/jax/_src/profiler.py:314, in annotate_function.<locals>.wrapper(*args, **kwargs)
    311 @wraps(func)
    312 def wrapper(*args, **kwargs):
    313   with TraceAnnotation(name, **decorator_kwargs):
--> 314     return func(*args, **kwargs)
    315   return wrapper
```

```
File ~/brax/env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py:1349, in ExecuteReplicated.__call__(self, *args)
   1344   self._handle_token_bufs(
   1345       results.disassemble_prefix_into_single_device_arrays(
   1346           len(self.ordered_effects)),
   1347       results.consume_token())
   1348 else:
-> 1349   results = self.xla_executable.execute_sharded(input_bufs)
   1350 if dispatch.needs_check_special():
   1351   out_arrays = results.disassemble_into_single_device_arrays()
```

XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: Failed to load PTX text as a module: CUDA_ERROR_INVALID_IMAGE: device kernel image is invalid; current tracing scope: dynamic-slice; current profiling annotation: XlaModule:#hlo_module=jit_dynamic_slice,program_id=2#.

btaba commented 6 months ago

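Hi @Leon-LXA, from the traceback it looks like you don't have a clean JAX installation on your device. The failure occurs inside a trivial `dynamic_slice` kernel, with `CUDA_ERROR_INVALID_IMAGE` / "Failed to load PTX text as a module". That points to an environment problem rather than a Brax bug, most likely a mismatch between the CUDA build of jaxlib and the installed CUDA driver. This link might help: https://jax.readthedocs.io/en/latest/installation.html#installing-jax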