danijar / dreamerv3

Mastering Diverse Domains through World Models
https://danijar.com/dreamerv3
MIT License
1.36k stars · 229 forks · source link

UnfilteredStackTrace: AssertionError #77

Closed: 4ku closed this issue 1 year ago.

4ku commented 1 year ago
Traceback (most recent call last) ────────────────────────────────╮
│ /home/ivan/Desktop/drone/drone_landing/dreamer.py:52 in <module>                                 │
│                                                                                                  │
│   49 env = dreamerv3.wrap_env(env, config)                                                       │
│   50 env = embodied.BatchEnv([env], parallel=False)                                              │
│   51                                                                                             │
│ ❱ 52 agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)                         │
│   53 replay = embodied.replay.Uniform(                                                           │
│   54 │   config.batch_length, config.replay_size, logdir / 'replay')                             │
│   55 args = embodied.Config(                                                                     │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/jaxagent.py:20 in __init__                          │
│                                                                                                  │
│    17 │   configs = agent_cls.configs                                                            │
│    18 │   inner = agent_cls                                                                      │
│    19 │   def __init__(self, *args, **kwargs):                                                   │
│ ❱  20 │     super().__init__(agent_cls, *args, **kwargs)                                         │
│    21   return Agent                                                                             │
│    22                                                                                            │
│    23                                                                                            │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/jaxagent.py:48 in __init__                          │
│                                                                                                  │
│    45 │   self._updates = embodied.Counter()                                                     │
│    46 │   self._should_metrics = embodied.when.Every(self.config.metrics_every)                  │
│    47 │   self._transform()                                                                      │
│ ❱  48 │   self.varibs = self._init_varibs(obs_space, act_space)                                  │
│    49 │   self.sync()                                                                            │
│    50                                                                                            │
│    51   def policy(self, obs, state=None, mode='train'):                                         │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/jaxagent.py:228 in _init_varibs                     │
│                                                                                                  │
│   225 │   data = self._dummy_batch({**obs_space, **act_space}, dims)                             │
│   226 │   data = self._convert_inps(data, self.train_devices)                                    │
│   227 │   state, varibs = self._init_train(varibs, rng, data['is_first'])                        │
│ ❱ 228 │   varibs = self._train(varibs, rng, data, state, init_only=True)                         │
│   229 │   # obs = self._dummy_batch(obs_space, (1,))                                             │
│   230 │   # state, varibs = self._init_policy(varibs, rng, obs['is_first'])                      │
│   231 │   # varibs = self._policy(                                                               │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:199 in wrapper                            │
│                                                                                                  │
│   196 │   statics = tuple(sorted([(k, v) for k, v in kw.items() if k in static]))                │
│   197 │   kw = {k: v for k, v in kw.items() if k not in static}                                  │
│   198 │   if not hasattr(wrapper, 'keys'):                                                       │
│ ❱ 199 │     created = init(statics, rng, *args, **kw)                                            │
│   200 │     wrapper.keys = set(created.keys())                                                   │
│   201 │     for key, value in created.items():                                                   │
│   202 │   │   if key not in state:                                                               │
│                                                                                                  │
│ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/traceback_util.py:166 in      │
│ reraise_with_filtered_traceback                                                                  │
│                                                                                                  │
│   163   def reraise_with_filtered_traceback(*args, **kwargs):                                    │
│   164 │   __tracebackhide__ = True                                                               │
│   165 │   try:                                                                                   │
│ ❱ 166 │     return fun(*args, **kwargs)                                                          │
│   167 │   except Exception as e:                                                                 │
│   168 │     mode = _filtering_mode()                                                             │
│   169 │     if _is_under_reraiser(e) or mode == "off":                                           │
│                                                                                                  │
│ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/pjit.py:253 in cache_miss     │
│                                                                                                  │
│    250                                                                                           │
│    251   @api_boundary                                                                           │
│    252   def cache_miss(*args, **kwargs):                                                        │
│ ❱  253 │   outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(                     │
│    254 │   │   fun, infer_params_fn, *args, **kwargs)                                            │
│    255 │   executable = _read_most_recent_pjit_call_executable(jaxpr)                            │
│    256 │   fastpath_data = _get_fastpath_data(executable, out_tree, args_flat, out_flat)         │
│                                                                                                  │
│ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/pjit.py:161 in                │
│ _python_pjit_helper                                                                              │
│                                                                                                  │
│    158                                                                                           │
│    159                                                                                           │
│    160 def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):                           │
│ ❱  161   args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(                           │
│    162 │     *args, **kwargs)                                                                    │
│    163   for arg in args_flat:                                                                   │
│    164 │   dispatch.check_arg(arg)                                                               │
│                                                                                                  │
│ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/api.py:324 in infer_params    │
│                                                                                                  │
│    321 │   │   donate_argnames=donate_argnames, device=device, backend=backend,                  │
│    322 │   │   keep_unused=keep_unused, inline=inline, resource_env=None,                        │
│    323 │   │   abstracted_axes=abstracted_axes)                                                  │
│ ❱  324 │   return pjit.common_infer_params(pjit_info_args, *args, **kwargs)                      │
│    325                                                                                           │
│    326   has_explicit_sharding = pjit._pjit_explicit_sharding(                                   │
│    327 │     in_shardings, out_shardings, device, backend)                                       │
│                                                                                                  │
│ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/pjit.py:491 in                │
│ common_infer_params                                                                              │
│                                                                                                  │
│    488 │     hashable_pytree(in_shardings), in_avals, in_tree, resource_env, dbg,                │
│    489 │     device_or_backend_set)                                                              │
│    490                                                                                           │
│ ❱  491   jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(                          │
│    492 │     flat_fun, hashable_pytree(out_shardings), in_type, dbg,                             │
│    493 │     device_or_backend_set, HashableFunction(out_tree, closure=()),                      │
│    494 │     HashableFunction(res_paths, closure=()))                                            │
│                                                                                                  │
│ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/pjit.py:969 in _pjit_jaxpr    │
│                                                                                                  │
│    966                                                                                           │
│    967 def _pjit_jaxpr(fun, out_shardings_thunk, in_type, debug_info,                            │
│    968 │   │   │   │   device_or_backend_set, out_tree, result_paths):                           │
│ ❱  969   jaxpr, final_consts, out_type = _create_pjit_jaxpr(                                     │
│    970 │     fun, in_type, debug_info, result_paths)                                             │
│    971   canonicalized_out_shardings_flat = _check_and_canonicalize_out_shardings(               │
│    972 │     out_shardings_thunk, out_tree, tuple(out_type), jaxpr.jaxpr.debug_info,             │
│                                                                                                  │
│ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/linear_util.py:345 in         │
│ memoized_fun                                                                                     │
│                                                                                                  │
│   342 │     ans, stores = result                                                                 │
│   343 │     fun.populate_stores(stores)                                                          │
│   344 │   else:                                                                                  │
│ ❱ 345 │     ans = call(fun, *args)                                                               │
│   346 │     cache[key] = (ans, fun.stores)                                                       │
│   347 │                                                                                          │
│   348 │   return ans                                                                             │
│                                                                                                  │
│ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/pjit.py:922 in                │
│ _create_pjit_jaxpr                                                                               │
│                                                                                                  │
│    919 │     jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(                       │
│    920 │   │     lu.annotate(fun, in_type), debug_info=pe_debug)                                 │
│    921 │   else:                                                                                 │
│ ❱  922 │     jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(                        │
│    923 │   │     fun, in_type, debug_info=pe_debug)                                              │
│    924                                                                                           │
│    925   if not config.jax_dynamic_shapes:                                                       │
│                                                                                                  │
│ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/profiler.py:314 in wrapper    │
│                                                                                                  │
│   311   @wraps(func)                                                                             │
│   312   def wrapper(*args, **kwargs):                                                            │
│   313 │   with TraceAnnotation(name, **decorator_kwargs):                                        │
│ ❱ 314 │     return func(*args, **kwargs)                                                         │
│   315 │   return wrapper                                                                         │
│   316   return wrapper                                                                           │
│   317                                                                                            │
│                                                                                                  │
│ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py: │
│ 2155 in trace_to_jaxpr_dynamic                                                                   │
│                                                                                                  │
│   2152 ) -> tuple[Jaxpr, list[AbstractValue], list[Any]]:                                        │
│   2153   with core.new_main(DynamicJaxprTrace, dynamic=True) as main:  # type: ignore            │
│   2154 │   main.jaxpr_stack = ()  # type: ignore                                                 │
│ ❱ 2155 │   jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(                                 │
│   2156 │     fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info)                │
│   2157 │   del main, fun                                                                         │
│   2158   return jaxpr, out_avals, consts                                                         │
│                                                                                                  │
│ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py: │
│ 2177 in trace_to_subjaxpr_dynamic                                                                │
│                                                                                                  │
│   2174 │   trace = DynamicJaxprTrace(main, core.cur_sublevel())                                  │
│   2175 │   in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)                          │
│   2176 │   in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]                 │
│ ❱ 2177 │   ans = fun.call_wrapped(*in_tracers_)                                                  │
│   2178 │   out_tracers = map(trace.full_raise, ans)                                              │
│   2179 │   jaxpr, consts = frame.to_jaxpr(out_tracers)                                           │
│   2180 │   del fun, main, trace, frame, in_tracers, out_tracers, ans                             │
│                                                                                                  │
│ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/linear_util.py:188 in         │
│ call_wrapped                                                                                     │
│                                                                                                  │
│   185 │   gen = gen_static_args = out_store = None                                               │
│   186 │                                                                                          │
│   187 │   try:                                                                                   │
│ ❱ 188 │     ans = self.f(*args, **dict(self.params, **kwargs))                                   │
│   189 │   except:                                                                                │
│   190 │     # Some transformations yield from inside context managers, so we have to             │
│   191 │     # interrupt them before reraising the exception. Otherwise they will only            │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:184 in init                               │
│                                                                                                  │
│   181   @bind(jax.jit, static_argnums=[0], **kwargs)                                             │
│   182   def init(statics, rng, *args, **kw):                                                     │
│   183 │   # Return only state so JIT can remove dead code for fast initialization.               │
│ ❱ 184 │   s = fun({}, rng, *args, ignore=True, **dict(statics), **kw)[1]                         │
│   185 │   return s                                                                               │
│   186                                                                                            │
│   187   @bind(jax.jit, static_argnums=[0], **kwargs)                                             │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:95 in purified                            │
│                                                                                                  │
│    92 │   │   rng = jax.random.PRNGKey(rng)                                                      │
│    93 │     context = Context(state.copy(), rng, create, modify, ignore, [], name)               │
│    94 │     CONTEXT[threading.get_ident()] = context                                             │
│ ❱  95 │     out = fun(*args, **kwargs)                                                           │
│    96 │     state = dict(context)                                                                │
│    97 │     return out, state                                                                    │
│    98 │   finally:                                                                               │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:380 in wrapper                            │
│                                                                                                  │
│   377   def wrapper(self, *args, **kwargs):                                                      │
│   378 │   with scope(self._path, absolute=True):                                                 │
│   379 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 380 │   │   return method(self, *args, **kwargs)                                               │
│   381   return wrapper                                                                           │
│   382                                                                                            │
│   383                                                                                            │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/agent.py:80 in train                                │
│                                                                                                  │
│    77 │   self.config.jax.jit and print('Tracing train function.')                               │
│    78 │   metrics = {}                                                                           │
│    79 │   data = self.preprocess(data)                                                           │
│ ❱  80 │   state, wm_outs, mets = self.wm.train(data, state)                                      │
│    81 │   metrics.update(mets)                                                                   │
│    82 │   context = {**data, **wm_outs['post']}                                                  │
│    83 │   start = tree_map(lambda x: x.reshape([-1] + list(x.shape[2:])), context)               │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:380 in wrapper                            │
│                                                                                                  │
│   377   def wrapper(self, *args, **kwargs):                                                      │
│   378 │   with scope(self._path, absolute=True):                                                 │
│   379 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 380 │   │   return method(self, *args, **kwargs)                                               │
│   381   return wrapper                                                                           │
│   382                                                                                            │
│   383                                                                                            │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/agent.py:146 in train                               │
│                                                                                                  │
│   143                                                                                            │
│   144   def train(self, data, state):                                                            │
│   145 │   modules = [self.encoder, self.rssm, *self.heads.values()]                              │
│ ❱ 146 │   mets, (state, outs, metrics) = self.opt(                                               │
│   147 │   │   modules, self.loss, data, state, has_aux=True)                                     │
│   148 │   metrics.update(mets)                                                                   │
│   149 │   return state, outs, metrics                                                            │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:380 in wrapper                            │
│                                                                                                  │
│   377   def wrapper(self, *args, **kwargs):                                                      │
│   378 │   with scope(self._path, absolute=True):                                                 │
│   379 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 380 │   │   return method(self, *args, **kwargs)                                               │
│   381   return wrapper                                                                           │
│   382                                                                                            │
│   383                                                                                            │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/jaxutils.py:390 in __call__                         │
│                                                                                                  │
│   387 │   │   loss *= sg(self.grad_scale.read())                                                 │
│   388 │     return loss, aux                                                                     │
│   389 │   metrics = {}                                                                           │
│ ❱ 390 │   loss, params, grads, aux = nj.grad(                                                    │
│   391 │   │   wrapped, modules, has_aux=True)(*args, **kwargs)                                   │
│   392 │   if not self.PARAM_COUNTS[self.path]:                                                   │
│   393 │     count = sum([np.prod(x.shape) for x in params.values()])                             │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:160 in wrapper                            │
│                                                                                                  │
│   157   backward = jax.value_and_grad(forward, has_aux=True)                                     │
│   158   @functools.wraps(backward)                                                               │
│   159   def wrapper(*args, **kwargs):                                                            │
│ ❱ 160 │   _prerun(fun, *args, **kwargs)                                                          │
│   161 │   assert all(isinstance(x, (str, Module)) for x in keys)                                 │
│   162 │   strs = [x for x in keys if isinstance(x, str)]                                         │
│   163 │   mods = [x for x in keys if isinstance(x, Module)]                                      │
│                                                                                                  │
│ /usr/lib/python3.10/contextlib.py:79 in inner                                                    │
│                                                                                                  │
│    76 │   │   @wraps(func)                                                                       │
│    77 │   │   def inner(*args, **kwds):                                                          │
│    78 │   │   │   with self._recreate_cm():                                                      │
│ ❱  79 │   │   │   │   return func(*args, **kwds)                                                 │
│    80 │   │   return inner                                                                       │
│    81                                                                                            │
│    82                                                                                            │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:297 in _prerun                            │
│                                                                                                  │
│   294 def _prerun(fun, *args, **kwargs):                                                         │
│   295   if not context().create:                                                                 │
│   296 │   return                                                                                 │
│ ❱ 297   discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs)             │
│   298   # jax.tree_util.tree_map(                                                                │
│   299   #     lambda x: hasattr(x, 'delete') and x.delete(), discarded)                          │
│   300   context().update(state)                                                                  │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:95 in purified                            │
│                                                                                                  │
│    92 │   │   rng = jax.random.PRNGKey(rng)                                                      │
│    93 │     context = Context(state.copy(), rng, create, modify, ignore, [], name)               │
│    94 │     CONTEXT[threading.get_ident()] = context                                             │
│ ❱  95 │     out = fun(*args, **kwargs)                                                           │
│    96 │     state = dict(context)                                                                │
│    97 │     return out, state                                                                    │
│    98 │   finally:                                                                               │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/jaxutils.py:382 in wrapped                          │
│                                                                                                  │
│   379                                                                                            │
│   380   def __call__(self, modules, lossfn, *args, has_aux=False, **kwargs):                     │
│   381 │   def wrapped(*args, **kwargs):                                                          │
│ ❱ 382 │     outs = lossfn(*args, **kwargs)                                                       │
│   383 │     loss, aux = outs if has_aux else (outs, None)                                        │
│   384 │     assert loss.dtype == jnp.float32, (self.name, loss.dtype)                            │
│   385 │     assert loss.shape == (), (self.name, loss.shape)                                     │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:380 in wrapper                            │
│                                                                                                  │
│   377   def wrapper(self, *args, **kwargs):                                                      │
│   378 │   with scope(self._path, absolute=True):                                                 │
│   379 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 380 │   │   return method(self, *args, **kwargs)                                               │
│   381   return wrapper                                                                           │
│   382                                                                                            │
│   383                                                                                            │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/agent.py:161 in loss                                │
│                                                                                                  │
│   158 │   dists = {}                                                                             │
│   159 │   feats = {**post, 'embed': embed}                                                       │
│   160 │   for name, head in self.heads.items():                                                  │
│ ❱ 161 │     out = head(feats if name in self.config.grad_heads else sg(feats))                   │
│   162 │     out = out if isinstance(out, dict) else {name: out}                                  │
│   163 │     dists.update(out)                                                                    │
│   164 │   losses = {}                                                                            │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:380 in wrapper                            │
│                                                                                                  │
│   377   def wrapper(self, *args, **kwargs):                                                      │
│   378 │   with scope(self._path, absolute=True):                                                 │
│   379 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 380 │   │   return method(self, *args, **kwargs)                                               │
│   381   return wrapper                                                                           │
│   382                                                                                            │
│   383                                                                                            │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/nets.py:283 in __call__                             │
│                                                                                                  │
│   280 │     if drop_loss_indices is not None:                                                    │
│   281 │   │   feat = feat[:, drop_loss_indices]                                                  │
│   282 │     flat = feat.reshape([-1, feat.shape[-1]])                                            │
│ ❱ 283 │     output = self._cnn(flat)                                                             │
│   284 │     output = output.reshape(feat.shape[:-1] + output.shape[1:])                          │
│   285 │     split_indices = np.cumsum([v[-1] for v in self.cnn_shapes.values()][:-1])            │
│   286 │     means = jnp.split(output, split_indices, -1)                                         │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:380 in wrapper                            │
│                                                                                                  │
│   377   def wrapper(self, *args, **kwargs):                                                      │
│   378 │   with scope(self._path, absolute=True):                                                 │
│   379 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 380 │   │   return method(self, *args, **kwargs)                                               │
│   381   return wrapper                                                                           │
│   382                                                                                            │
│   383                                                                                            │
│                                                                                                  │
│ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/nets.py:396 in __call__                             │
│                                                                                                  │
│   393 │     x = x[:, int(np.ceil(padh)): -int(padh), :]                                          │
│   394 │     x = x[:, :, int(np.ceil(padw)): -int(padw)]                                          │
│   395 │   # print(x.shape)                                                                       │
│ ❱ 396 │   assert x.shape[-3:] == self._shape, (x.shape, self._shape)                             │
│   397 │   if self._sigmoid:                                                                      │
│   398 │     x = jax.nn.sigmoid(x)                                                                │
│   399 │   else:                                                                                  │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
UnfilteredStackTrace: AssertionError: ((1024, 64, 64, 4), (48, 88, 4))
4ku commented 1 year ago

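Solved — see the comment in issue #12. (The assertion compares the CNN decoder's output shape with the expected image shape. Here the decoder produced 64×64×4 images, but the observation shape is 48×88×4, so the two must be made to match.)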