danijar / dreamerv3

Mastering Diverse Domains through World Models
https://danijar.com/dreamerv3
MIT License
1.38k stars 233 forks source link

Obtain World Model Predictions during Inference. #110

Open defrag-bambino opened 8 months ago

defrag-bambino commented 8 months ago

Hi,

How can I obtain the world model's predictions during inference? I tried the following call inside a simple inference loop, but it throws an error: `agent.agent.wm.imagine(agent.policy, obs, 10)`

Error & Stacktrace

```
/home/fabian/Desktop/fpv/py/dreamerv3/inference.py:60 in main

   57 │ act = {'action': act['action'][0], 'reset': obs['is_last'][0]}
   58 │
   59 │ if i > 100:
 ❱ 60 │ agent.agent.wm.imagine(agent.policy, obs, 10)
   61
   62
   63

/home/fabian/Desktop/fpv/py/dreamerv3/dreamerv3/ninjax.py:380 in wrapper

   377 def wrapper(self, *args, **kwargs):
   378 │ with scope(self._path, absolute=True):
   379 │ with jax.named_scope(self._path.split('/')[-1]):
 ❱ 380 │ │ return method(self, *args, **kwargs)
   381 return wrapper
   382
   383

/home/fabian/Desktop/fpv/py/dreamerv3/dreamerv3/agent.py:183 in imagine

   180
   181 def imagine(self, policy, start, horizon):
   182 │ first_cont = (1.0 - start['is_terminal']).astype(jnp.float32)
 ❱ 183 │ keys = list(self.rssm.initial(1).keys())
   184 │ start = {k: v for k, v in start.items() if k in keys}
   185 │ start['action'] = policy(start)
   186 │ def step(prev, _):

/home/fabian/Desktop/fpv/py/dreamerv3/dreamerv3/ninjax.py:380 in wrapper

   377 def wrapper(self, *args, **kwargs):
   378 │ with scope(self._path, absolute=True):
   379 │ with jax.named_scope(self._path.split('/')[-1]):
 ❱ 380 │ │ return method(self, *args, **kwargs)
   381 return wrapper
   382
   383

/home/fabian/Desktop/fpv/py/dreamerv3/dreamerv3/nets.py:34 in initial

   31 def initial(self, bs):
   32 │ if self._classes:
   33 │ state = dict(
 ❱ 34 │ │ deter=jnp.zeros([bs, self._deter], f32),
   35 │ │ logit=jnp.zeros([bs, self._stoch, self._classes], f32),
   36 │ │ stoch=jnp.zeros([bs, self._stoch, self._classes], f32))
   37 │ else:

/home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:2317 in zeros

   2314 if (m := _check_forgot_shape_tuple("zeros", shape, dtype)): raise
   2315 dtypes.check_user_dtype_supported(dtype, "zeros")
   2316 shape = canonicalize_shape(shape)
 ❱ 2317 return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_t
   2318
   2319 @util.implements(np.ones)
   2320 def ones(shape: Any, dtype: DTypeLike | None = None, *,

/home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/lax/lax.py:1226 in full

   1223 │ return dtype._rules.full(shape, fill_value, dtype) # type: igno
   1224 weak_type = dtype is None and dtypes.is_weakly_typed(fill_value)
   1225 dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))
 ❱ 1226 fill_value = _convert_element_type(fill_value, dtype, weak_type)
   1227 out = broadcast(fill_value, shape)
   1228 if sharding is not None:
   1229 │ return array.make_array_from_callback(shape, sharding, lambda id

/home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/lax/lax.py:560 in _convert_element_type

   557 │ │ isinstance(core.get_aval(operand), core.ConcreteArray))):
   558 │ return type_cast(Array, operand)
   559 else:
 ❱ 560 │ return convert_element_type_p.bind(operand, new_dtype=new_dtype,
   561 │ │ │ │ │ │ │ │ │ weak_type=bool(weak_type))
   562
   563 def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -

/home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/core.py:444 in bind

   441 def bind(self, *args, **params):
   442 │ assert (not config.enable_checks.value or
   443 │ │ │ all(isinstance(arg, Tracer) or valid_jaxtype(arg) for ar
 ❱ 444 │ return self.bind_with_trace(find_top_trace(args), args, params)
   445
   446 def bind_with_trace(self, trace, args, params):
   447 │ out = trace.process_primitive(self, map(trace.full_raise, args),
   448 │ return map(full_lower, out) if self.multiple_results else full_l
   449
   450 def def_impl(self, impl):

/home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/core.py:447 in bind_with_trace

   444 │ return self.bind_with_trace(find_top_trace(args), args, params)
   445
   446 def bind_with_trace(self, trace, args, params):
 ❱ 447 │ out = trace.process_primitive(self, map(trace.full_raise, args),
   448 │ return map(full_lower, out) if self.multiple_results else full_l
   449
   450 def def_impl(self, impl):

/home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/core.py:935 in process_primitive

   932 lift = sublift = pure
   933
   934 def process_primitive(self, primitive, tracers, params):
 ❱ 935 │ return primitive.impl(*tracers, **params)
   936
   937 def process_call(self, primitive, f, tracers, params):
   938 │ return primitive.impl(f, *tracers, **params)

/home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/dispatch.py:87 in apply_primitive

   84 if xla_extension_version >= 218:
   85 │ prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
   86 │ try:
 ❱ 87 │ outs = fun(*args)
   88 │ finally:
   89 │ lib.jax_jit.swap_thread_local_state_disable_jit(prev)
   90 else:

XlaRuntimeError: INVALID_ARGUMENT: Disallowed host-to-device transfer: aval=ShapedArray(float32[]), dst_sharding=GSPMDSharding({replicated})
```

wnnng commented 2 months ago

Have you made any progress with it? Has anyone tried doing it with the refactored code?