danijar / dreamerv3

Mastering Diverse Domains through World Models
https://danijar.com/dreamerv3
MIT License

Shape error with custom Unity env #16

Closed · defrag-bambino closed this 1 year ago

defrag-bambino commented 1 year ago

Hello, first of all I would also like to thank you for publicly sharing your research code.

I am currently trying to run DreamerV3 on my custom environment, which was built using Unity3D's ML-Agents and wrapped as a Gym environment. After some initial issues with the shapes of my action and observation spaces, which I think I have fixed, I am still running into a dimension error. It occurs right at the beginning of training, when the agent prefills its train dataset and the first checkpoint is saved. The full output is here:

```shell
python3 example.py
[UnityMemory] Configuration Parameters - Can be set up in boot.config
  "memorysetup-bucket-allocator-granularity=16"
  "memorysetup-bucket-allocator-bucket-count=8"
  ... (remaining memorysetup parameters elided)
[WARNING] The environment contains multiple observations. You must define
  allow_multiple_obs=True to receive them all. Otherwise, only the first
  visual observation (or vector observation if there are no visual
  observations) will be provided in the observation.
/home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/gym/spaces/box.py:73:
  UserWarning: WARN: Box bound precision lowered by casting to float32
Encoder CNN shapes: {}
Encoder MLP shapes: {'image': (16,)}
Decoder CNN shapes: {}
Decoder MLP shapes: {'image': (16,)}
JAX devices (1): [CpuDevice(id=0)]
Policy devices: TFRT_CPU_0
Train devices: TFRT_CPU_0
Tracing train function.
Optimizer model_opt has 16,451,344 variables.
Optimizer actor_opt has 1,052,676 variables.
Optimizer critic_opt has 1,181,439 variables.
Logdir logdir/run1
Observation space:
  image        Space(dtype=float32, shape=(16,), low=-inf, high=inf)
  reward       Space(dtype=float32, shape=(), low=-inf, high=inf)
  is_first     Space(dtype=bool, shape=(), low=False, high=True)
  is_last      Space(dtype=bool, shape=(), low=False, high=True)
  is_terminal  Space(dtype=bool, shape=(), low=False, high=True)
Action space:
  action       Space(dtype=float32, shape=(2,), low=-1.0, high=1.0)
  reset        Space(dtype=bool, shape=(), low=False, high=True)
Prefill train dataset.
Episode has 61 steps and return 0.1.
Episode has 55 steps and return 0.1.
Episode has 73 steps and return 0.1.
Episode has 55 steps and return 0.1.
Episode has 74 steps and return 0.1.
Episode has 96 steps and return 0.4.
Episode has 56 steps and return 0.1.
Episode has 50 steps and return 0.1.
Episode has 50 steps and return 0.0.
Episode has 45 steps and return 0.0.
Episode has 62 steps and return 0.1.
Episode has 76 steps and return 0.2.
Episode has 53 steps and return 0.1.
Episode has 57 steps and return 0.1.
Episode has 84 steps and return 0.2.
Saved chunk: 20230224T123638F338048-7fz2YQGpaWRhCvMc8sKBIS-4NtbeiuY5nlHsbS33AqTMd-1024.npz
Episode has 69 steps and return 0.1.
──────────────────────────── Step 1100 ────────────────────────────
episode/length 69 / episode/score 0.13 / episode/sum_abs_reward 0.13 / episode/reward_rate 0
Creating new TensorBoard event file writer.
Did not find any checkpoint.
Writing checkpoint: logdir/run1/checkpoint.ckpt
Start training loop.
Saved chunk: 20230224T123818F856929-4NtbeiuY5nlHsbS33AqTMd-0000000000000000000000-76.npz
Wrote checkpoint: logdir/run1/checkpoint.ckpt
Error writing summary: stats/policy_image
Tracing policy function.

UnfilteredStackTrace: TypeError: Cannot concatenate arrays with different
numbers of dimensions: got (1, 512), (1, 1, 1024).

The stack trace below excludes JAX-internal frames. The preceding is the
original exception that occurred, unmodified.
--------------------

Traceback (most recent call last):
  File "/home/fabian/Desktop/dreamerv3/example.py", line 53, in <module>
    main()
  File "/home/fabian/Desktop/dreamerv3/example.py", line 48, in main
    embodied.run.train(agent, env, replay, logger, args)
  File "/home/fabian/Desktop/dreamerv3/dreamerv3/embodied/run/train.py", line 108, in train
    driver(policy, steps=100)
  File "/home/fabian/Desktop/dreamerv3/dreamerv3/embodied/core/driver.py", line 42, in __call__
    step, episode = self._step(policy, step, episode)
  File "/home/fabian/Desktop/dreamerv3/dreamerv3/embodied/core/driver.py", line 50, in _step
    acts, self._state = policy(obs, self._state, **self._kwargs)
  File "/home/fabian/Desktop/dreamerv3/dreamerv3/embodied/run/train.py", line 105, in <lambda>
    policy = lambda *args: agent.policy(
  File "/home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/fabian/Desktop/dreamerv3/dreamerv3/jaxagent.py", line 62, in policy
    (outs, state), _ = self._policy(varibs, rng, obs, state, mode=mode)
  File "/home/fabian/Desktop/dreamerv3/dreamerv3/ninjax.py", line 199, in wrapper
    created = init(statics, rng, *args, **kw)
  File "/home/fabian/Desktop/dreamerv3/dreamerv3/ninjax.py", line 184, in init
    s = fun({}, rng, *args, ignore=True, **dict(statics), **kw)[1]
  File "/home/fabian/Desktop/dreamerv3/dreamerv3/ninjax.py", line 95, in purified
    out = fun(*args, **kwargs)
  File "/home/fabian/Desktop/dreamerv3/dreamerv3/ninjax.py", line 380, in wrapper
    return method(self, *args, **kwargs)
  File "/home/fabian/Desktop/dreamerv3/dreamerv3/agent.py", line 56, in policy
    latent, _ = self.wm.rssm.obs_step(
        prev_latent, prev_action, embed, obs['is_first'])
  File "/home/fabian/Desktop/dreamerv3/dreamerv3/ninjax.py", line 380, in wrapper
    return method(self, *args, **kwargs)
  File "/home/fabian/Desktop/dreamerv3/dreamerv3/nets.py", line 105, in obs_step
    x = jnp.concatenate([prior['deter'], embed], -1)
  File "/home/fabian/miniconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 1845, in concatenate
    arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
TypeError: Cannot concatenate arrays with different numbers of dimensions: got (1, 512), (1, 1, 1024).
```

Thanks

danijar commented 1 year ago

Maybe you're pointing to an old --logdir so it's trying to load a checkpoint that isn't compatible with the new version of your environment anymore?

defrag-bambino commented 1 year ago

Nope, I deleted them beforehand. I even tried with a freshly cloned repo.

Here's an excerpt of the last part of the error:

```shell
Creating new TensorBoard event file writer.
Did not find any checkpoint.
Writing checkpoint: logdir/run1/checkpoint.ckpt
Start training loop.
Saved chunk: 20230227T234456F327624-5imMN603dmNcTxUoqgfBmi-0000000000000000000000-76.npz
Tracing policy function.
Wrote checkpoint: logdir/run1/checkpoint.ckpt
Error writing summary: stats/policy_image
---prior Traced<ShapedArray(float16[1,1024])>with<DynamicJaxprTrace(level=1/0)>
---embed Traced<ShapedArray(float16[1,1,1024])>with<DynamicJaxprTrace(level=1/0)>
```

I got the last two lines by printing the shapes of the prior['deter'] and embed arrays in nets.py at line 94 (right before the concatenate call in line 95 fails). I am not quite sure what these two arrays are meant to contain; can you tell from this whether these shapes seem off?
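
For reference, this is roughly the debug code I added (my local nets.py, so the line numbers differ slightly from the traceback above):

```python
# Inside RSSM.obs_step in nets.py, just before the failing concatenate:
print('---prior', prior['deter'])   # Traced<ShapedArray(float16[1,1024])>
print('---embed', embed)            # Traced<ShapedArray(float16[1,1,1024])>
x = jnp.concatenate([prior['deter'], embed], -1)  # ndim 2 vs. ndim 3 -> TypeError
```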

danijar commented 1 year ago

Ahh, I think I know the issue. When you use `FromGym` on an environment that returns a single observation array, rather than a dictionary of observations, it uses `image` as the observation key by default. Later, the code tries to write a video summary for this observation key, which fails because the observation isn't shaped like an image. Can you try `FromGym(env, obs_key='vector')` instead of `FromGym(env)`?
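
For example, roughly following example.py (here `MyUnityGymEnv` is a placeholder for however you construct your wrapped ML-Agents env):

```python
import dreamerv3
from dreamerv3 import embodied
from embodied.envs import from_gym

env = MyUnityGymEnv()  # placeholder for your Unity ML-Agents Gym env
# 'vector' tells DreamerV3 to treat the flat observation as an MLP input
# rather than as an image, so no video summary is attempted for it:
env = from_gym.FromGym(env, obs_key='vector')
env = dreamerv3.wrap_env(env, config)
env = embodied.BatchEnv([env], parallel=False)
```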

defrag-bambino commented 1 year ago

Hi, thanks for the reply. This is indeed a good point, so I changed it. Unfortunately, it did not resolve my issue. I think it is the following function call, whose arguments do not have the same number of dimensions (see the full error log above):

```python
# dreamerv3/agent.py, lines 53-59 (policy):
obs = self.preprocess(obs)
(prev_latent, prev_action), task_state, expl_state = state
embed = self.wm.encoder(obs)
latent, _ = self.wm.rssm.obs_step(          # <- line 56, where it fails
    prev_latent, prev_action, embed, obs['is_first'])
self.expl_behavior.policy(latent, expl_state)
task_outs, task_state = self.task_behavior.policy(latent, task_state)
```

danijar commented 1 year ago

Can you share your full script (or ideally a simplified version), please?

defrag-bambino commented 1 year ago

Sure. I've created a fork and committed my changes to it. It also includes the env, which in this case is Unity's PushBlock example (using no visual observations). I had to minimally adapt your code to get this env working; check the diff to see the changes. https://github.com/defrag-bambino/dreamerv3-fork

danijar commented 1 year ago

Thanks, but that's a bit too much custom code for me to have time to look into. Two ideas come to mind: the environment may not actually return the shapes it claims in its observation space, or you may again be trying to load an incompatible checkpoint.
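
A quick way to check the first idea on the raw Gym env, before any DreamerV3 wrappers (a sketch, adjust to your setup):

```python
import numpy as np

obs = env.reset()
print(env.observation_space)        # e.g. Box(-inf, inf, (16,), float32)
print(np.asarray(obs).shape)        # should be (16,), not (1, 16)

obs, reward, done, info = env.step(env.action_space.sample())
print(np.asarray(obs).shape, np.asarray(reward).shape)  # (16,) and ()
```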

defrag-bambino commented 1 year ago

OK, thanks. I will try to debug it further and report back if there is any news.

defrag-bambino commented 1 year ago

Hi Danijar, I have found the issue, and it was indeed in my own code. I had written a wrapper that allows using multiple envs. Although I had set it to only one env for trying DreamerV3, the returned rewards and observations were still one-element lists instead of the actual values, so basically [obs] instead of obs.
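
In simplified form, the bug looked roughly like this (not the exact code from my fork):

```python
class MultiEnvWrapper:
    """Steps several Gym envs at once; I was running it with a single env."""

    def __init__(self, envs):
        self.envs = envs

    def step(self, actions):
        obs, rewards, dones, infos = [], [], [], []
        for env, act in zip(self.envs, actions):
            o, r, d, i = env.step(act)
            obs.append(o)
            rewards.append(r)
            dones.append(d)
            infos.append(i)
        # BUG: with a single env this still returns [obs] instead of obs,
        # which adds a leading axis to every observation and later shows up
        # as the (1, 1, 1024) vs. (1, 512) mismatch in RSSM.obs_step.
        return obs, rewards, dones, infos
```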

Thanks again for creating this awesome algorithm.