Status: Closed — 4ku closed this issue 1 year ago.
Traceback (most recent call last) ────────────────────────────────╮ │ /home/ivan/Desktop/drone/drone_landing/dreamer.py:52 in <module> │ │ │ │ 49 env = dreamerv3.wrap_env(env, config) │ │ 50 env = embodied.BatchEnv([env], parallel=False) │ │ 51 │ │ ❱ 52 agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config) │ │ 53 replay = embodied.replay.Uniform( │ │ 54 │ config.batch_length, config.replay_size, logdir / 'replay') │ │ 55 args = embodied.Config( │ │ │ │ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/jaxagent.py:20 in __init__ │ │ │ │ 17 │ configs = agent_cls.configs │ │ 18 │ inner = agent_cls │ │ 19 │ def __init__(self, *args, **kwargs): │ │ ❱ 20 │ super().__init__(agent_cls, *args, **kwargs) │ │ 21 return Agent │ │ 22 │ │ 23 │ │ │ │ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/jaxagent.py:48 in __init__ │ │ │ │ 45 │ self._updates = embodied.Counter() │ │ 46 │ self._should_metrics = embodied.when.Every(self.config.metrics_every) │ │ 47 │ self._transform() │ │ ❱ 48 │ self.varibs = self._init_varibs(obs_space, act_space) │ │ 49 │ self.sync() │ │ 50 │ │ 51 def policy(self, obs, state=None, mode='train'): │ │ │ │ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/jaxagent.py:228 in _init_varibs │ │ │ │ 225 │ data = self._dummy_batch({**obs_space, **act_space}, dims) │ │ 226 │ data = self._convert_inps(data, self.train_devices) │ │ 227 │ state, varibs = self._init_train(varibs, rng, data['is_first']) │ │ ❱ 228 │ varibs = self._train(varibs, rng, data, state, init_only=True) │ │ 229 │ # obs = self._dummy_batch(obs_space, (1,)) │ │ 230 │ # state, varibs = self._init_policy(varibs, rng, obs['is_first']) │ │ 231 │ # varibs = self._policy( │ │ │ │ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:199 in wrapper │ │ │ │ 196 │ statics = tuple(sorted([(k, v) for k, v in kw.items() if k in static])) │ │ 197 │ kw = {k: v for k, v in kw.items() if k not in static} │ │ 198 │ if not hasattr(wrapper, 'keys'): │ │ ❱ 199 │ created = init(statics, rng, *args, **kw) │ │ 200 │ 
wrapper.keys = set(created.keys()) │ │ 201 │ for key, value in created.items(): │ │ 202 │ │ if key not in state: │ │ │ │ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/traceback_util.py:166 in │ │ reraise_with_filtered_traceback │ │ │ │ 163 def reraise_with_filtered_traceback(*args, **kwargs): │ │ 164 │ __tracebackhide__ = True │ │ 165 │ try: │ │ ❱ 166 │ return fun(*args, **kwargs) │ │ 167 │ except Exception as e: │ │ 168 │ mode = _filtering_mode() │ │ 169 │ if _is_under_reraiser(e) or mode == "off": │ │ │ │ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/pjit.py:253 in cache_miss │ │ │ │ 250 │ │ 251 @api_boundary │ │ 252 def cache_miss(*args, **kwargs): │ │ ❱ 253 │ outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper( │ │ 254 │ │ fun, infer_params_fn, *args, **kwargs) │ │ 255 │ executable = _read_most_recent_pjit_call_executable(jaxpr) │ │ 256 │ fastpath_data = _get_fastpath_data(executable, out_tree, args_flat, out_flat) │ │ │ │ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/pjit.py:161 in │ │ _python_pjit_helper │ │ │ │ 158 │ │ 159 │ │ 160 def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs): │ │ ❱ 161 args_flat, _, params, in_tree, out_tree, _ = infer_params_fn( │ │ 162 │ *args, **kwargs) │ │ 163 for arg in args_flat: │ │ 164 │ dispatch.check_arg(arg) │ │ │ │ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/api.py:324 in infer_params │ │ │ │ 321 │ │ donate_argnames=donate_argnames, device=device, backend=backend, │ │ 322 │ │ keep_unused=keep_unused, inline=inline, resource_env=None, │ │ 323 │ │ abstracted_axes=abstracted_axes) │ │ ❱ 324 │ return pjit.common_infer_params(pjit_info_args, *args, **kwargs) │ │ 325 │ │ 326 has_explicit_sharding = pjit._pjit_explicit_sharding( │ │ 327 │ in_shardings, out_shardings, device, backend) │ │ │ │ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/pjit.py:491 in │ │ common_infer_params │ │ │ │ 488 │ 
hashable_pytree(in_shardings), in_avals, in_tree, resource_env, dbg, │ │ 489 │ device_or_backend_set) │ │ 490 │ │ ❱ 491 jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr( │ │ 492 │ flat_fun, hashable_pytree(out_shardings), in_type, dbg, │ │ 493 │ device_or_backend_set, HashableFunction(out_tree, closure=()), │ │ 494 │ HashableFunction(res_paths, closure=())) │ │ │ │ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/pjit.py:969 in _pjit_jaxpr │ │ │ │ 966 │ │ 967 def _pjit_jaxpr(fun, out_shardings_thunk, in_type, debug_info, │ │ 968 │ │ │ │ device_or_backend_set, out_tree, result_paths): │ │ ❱ 969 jaxpr, final_consts, out_type = _create_pjit_jaxpr( │ │ 970 │ fun, in_type, debug_info, result_paths) │ │ 971 canonicalized_out_shardings_flat = _check_and_canonicalize_out_shardings( │ │ 972 │ out_shardings_thunk, out_tree, tuple(out_type), jaxpr.jaxpr.debug_info, │ │ │ │ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/linear_util.py:345 in │ │ memoized_fun │ │ │ │ 342 │ ans, stores = result │ │ 343 │ fun.populate_stores(stores) │ │ 344 │ else: │ │ ❱ 345 │ ans = call(fun, *args) │ │ 346 │ cache[key] = (ans, fun.stores) │ │ 347 │ │ │ 348 │ return ans │ │ │ │ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/pjit.py:922 in │ │ _create_pjit_jaxpr │ │ │ │ 919 │ jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2( │ │ 920 │ │ lu.annotate(fun, in_type), debug_info=pe_debug) │ │ 921 │ else: │ │ ❱ 922 │ jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic( │ │ 923 │ │ fun, in_type, debug_info=pe_debug) │ │ 924 │ │ 925 if not config.jax_dynamic_shapes: │ │ │ │ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/profiler.py:314 in wrapper │ │ │ │ 311 @wraps(func) │ │ 312 def wrapper(*args, **kwargs): │ │ 313 │ with TraceAnnotation(name, **decorator_kwargs): │ │ ❱ 314 │ return func(*args, **kwargs) │ │ 315 │ return wrapper │ │ 316 return wrapper │ │ 317 │ │ │ │ 
/home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py: │ │ 2155 in trace_to_jaxpr_dynamic │ │ │ │ 2152 ) -> tuple[Jaxpr, list[AbstractValue], list[Any]]: │ │ 2153 with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore │ │ 2154 │ main.jaxpr_stack = () # type: ignore │ │ ❱ 2155 │ jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic( │ │ 2156 │ fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info) │ │ 2157 │ del main, fun │ │ 2158 return jaxpr, out_avals, consts │ │ │ │ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py: │ │ 2177 in trace_to_subjaxpr_dynamic │ │ │ │ 2174 │ trace = DynamicJaxprTrace(main, core.cur_sublevel()) │ │ 2175 │ in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) │ │ 2176 │ in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep] │ │ ❱ 2177 │ ans = fun.call_wrapped(*in_tracers_) │ │ 2178 │ out_tracers = map(trace.full_raise, ans) │ │ 2179 │ jaxpr, consts = frame.to_jaxpr(out_tracers) │ │ 2180 │ del fun, main, trace, frame, in_tracers, out_tracers, ans │ │ │ │ /home/ivan/Desktop/drone/env/lib/python3.10/site-packages/jax/_src/linear_util.py:188 in │ │ call_wrapped │ │ │ │ 185 │ gen = gen_static_args = out_store = None │ │ 186 │ │ │ 187 │ try: │ │ ❱ 188 │ ans = self.f(*args, **dict(self.params, **kwargs)) │ │ 189 │ except: │ │ 190 │ # Some transformations yield from inside context managers, so we have to │ │ 191 │ # interrupt them before reraising the exception. Otherwise they will only │ │ │ │ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:184 in init │ │ │ │ 181 @bind(jax.jit, static_argnums=[0], **kwargs) │ │ 182 def init(statics, rng, *args, **kw): │ │ 183 │ # Return only state so JIT can remove dead code for fast initialization. 
│ │ ❱ 184 │ s = fun({}, rng, *args, ignore=True, **dict(statics), **kw)[1] │ │ 185 │ return s │ │ 186 │ │ 187 @bind(jax.jit, static_argnums=[0], **kwargs) │ │ │ │ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:95 in purified │ │ │ │ 92 │ │ rng = jax.random.PRNGKey(rng) │ │ 93 │ context = Context(state.copy(), rng, create, modify, ignore, [], name) │ │ 94 │ CONTEXT[threading.get_ident()] = context │ │ ❱ 95 │ out = fun(*args, **kwargs) │ │ 96 │ state = dict(context) │ │ 97 │ return out, state │ │ 98 │ finally: │ │ │ │ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:380 in wrapper │ │ │ │ 377 def wrapper(self, *args, **kwargs): │ │ 378 │ with scope(self._path, absolute=True): │ │ 379 │ with jax.named_scope(self._path.split('/')[-1]): │ │ ❱ 380 │ │ return method(self, *args, **kwargs) │ │ 381 return wrapper │ │ 382 │ │ 383 │ │ │ │ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/agent.py:80 in train │ │ │ │ 77 │ self.config.jax.jit and print('Tracing train function.') │ │ 78 │ metrics = {} │ │ 79 │ data = self.preprocess(data) │ │ ❱ 80 │ state, wm_outs, mets = self.wm.train(data, state) │ │ 81 │ metrics.update(mets) │ │ 82 │ context = {**data, **wm_outs['post']} │ │ 83 │ start = tree_map(lambda x: x.reshape([-1] + list(x.shape[2:])), context) │ │ │ │ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:380 in wrapper │ │ │ │ 377 def wrapper(self, *args, **kwargs): │ │ 378 │ with scope(self._path, absolute=True): │ │ 379 │ with jax.named_scope(self._path.split('/')[-1]): │ │ ❱ 380 │ │ return method(self, *args, **kwargs) │ │ 381 return wrapper │ │ 382 │ │ 383 │ │ │ │ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/agent.py:146 in train │ │ │ │ 143 │ │ 144 def train(self, data, state): │ │ 145 │ modules = [self.encoder, self.rssm, *self.heads.values()] │ │ ❱ 146 │ mets, (state, outs, metrics) = self.opt( │ │ 147 │ │ modules, self.loss, data, state, has_aux=True) │ │ 148 │ metrics.update(mets) │ │ 149 │ return state, outs, metrics │ │ │ │ 
/home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:380 in wrapper │ │ │ │ 377 def wrapper(self, *args, **kwargs): │ │ 378 │ with scope(self._path, absolute=True): │ │ 379 │ with jax.named_scope(self._path.split('/')[-1]): │ │ ❱ 380 │ │ return method(self, *args, **kwargs) │ │ 381 return wrapper │ │ 382 │ │ 383 │ │ │ │ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/jaxutils.py:390 in __call__ │ │ │ │ 387 │ │ loss *= sg(self.grad_scale.read()) │ │ 388 │ return loss, aux │ │ 389 │ metrics = {} │ │ ❱ 390 │ loss, params, grads, aux = nj.grad( │ │ 391 │ │ wrapped, modules, has_aux=True)(*args, **kwargs) │ │ 392 │ if not self.PARAM_COUNTS[self.path]: │ │ 393 │ count = sum([np.prod(x.shape) for x in params.values()]) │ │ │ │ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:160 in wrapper │ │ │ │ 157 backward = jax.value_and_grad(forward, has_aux=True) │ │ 158 @functools.wraps(backward) │ │ 159 def wrapper(*args, **kwargs): │ │ ❱ 160 │ _prerun(fun, *args, **kwargs) │ │ 161 │ assert all(isinstance(x, (str, Module)) for x in keys) │ │ 162 │ strs = [x for x in keys if isinstance(x, str)] │ │ 163 │ mods = [x for x in keys if isinstance(x, Module)] │ │ │ │ /usr/lib/python3.10/contextlib.py:79 in inner │ │ │ │ 76 │ │ @wraps(func) │ │ 77 │ │ def inner(*args, **kwds): │ │ 78 │ │ │ with self._recreate_cm(): │ │ ❱ 79 │ │ │ │ return func(*args, **kwds) │ │ 80 │ │ return inner │ │ 81 │ │ 82 │ │ │ │ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:297 in _prerun │ │ │ │ 294 def _prerun(fun, *args, **kwargs): │ │ 295 if not context().create: │ │ 296 │ return │ │ ❱ 297 discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs) │ │ 298 # jax.tree_util.tree_map( │ │ 299 # lambda x: hasattr(x, 'delete') and x.delete(), discarded) │ │ 300 context().update(state) │ │ │ │ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:95 in purified │ │ │ │ 92 │ │ rng = jax.random.PRNGKey(rng) │ │ 93 │ context = Context(state.copy(), rng, create, modify, ignore, [], name) 
│ │ 94 │ CONTEXT[threading.get_ident()] = context │ │ ❱ 95 │ out = fun(*args, **kwargs) │ │ 96 │ state = dict(context) │ │ 97 │ return out, state │ │ 98 │ finally: │ │ │ │ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/jaxutils.py:382 in wrapped │ │ │ │ 379 │ │ 380 def __call__(self, modules, lossfn, *args, has_aux=False, **kwargs): │ │ 381 │ def wrapped(*args, **kwargs): │ │ ❱ 382 │ outs = lossfn(*args, **kwargs) │ │ 383 │ loss, aux = outs if has_aux else (outs, None) │ │ 384 │ assert loss.dtype == jnp.float32, (self.name, loss.dtype) │ │ 385 │ assert loss.shape == (), (self.name, loss.shape) │ │ │ │ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:380 in wrapper │ │ │ │ 377 def wrapper(self, *args, **kwargs): │ │ 378 │ with scope(self._path, absolute=True): │ │ 379 │ with jax.named_scope(self._path.split('/')[-1]): │ │ ❱ 380 │ │ return method(self, *args, **kwargs) │ │ 381 return wrapper │ │ 382 │ │ 383 │ │ │ │ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/agent.py:161 in loss │ │ │ │ 158 │ dists = {} │ │ 159 │ feats = {**post, 'embed': embed} │ │ 160 │ for name, head in self.heads.items(): │ │ ❱ 161 │ out = head(feats if name in self.config.grad_heads else sg(feats)) │ │ 162 │ out = out if isinstance(out, dict) else {name: out} │ │ 163 │ dists.update(out) │ │ 164 │ losses = {} │ │ │ │ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:380 in wrapper │ │ │ │ 377 def wrapper(self, *args, **kwargs): │ │ 378 │ with scope(self._path, absolute=True): │ │ 379 │ with jax.named_scope(self._path.split('/')[-1]): │ │ ❱ 380 │ │ return method(self, *args, **kwargs) │ │ 381 return wrapper │ │ 382 │ │ 383 │ │ │ │ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/nets.py:283 in __call__ │ │ │ │ 280 │ if drop_loss_indices is not None: │ │ 281 │ │ feat = feat[:, drop_loss_indices] │ │ 282 │ flat = feat.reshape([-1, feat.shape[-1]]) │ │ ❱ 283 │ output = self._cnn(flat) │ │ 284 │ output = output.reshape(feat.shape[:-1] + output.shape[1:]) │ │ 285 │ split_indices = 
np.cumsum([v[-1] for v in self.cnn_shapes.values()][:-1]) │ │ 286 │ means = jnp.split(output, split_indices, -1) │ │ │ │ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/ninjax.py:380 in wrapper │ │ │ │ 377 def wrapper(self, *args, **kwargs): │ │ 378 │ with scope(self._path, absolute=True): │ │ 379 │ with jax.named_scope(self._path.split('/')[-1]): │ │ ❱ 380 │ │ return method(self, *args, **kwargs) │ │ 381 return wrapper │ │ 382 │ │ 383 │ │ │ │ /home/ivan/Desktop/drone/dreamerv3/dreamerv3/nets.py:396 in __call__ │ │ │ │ 393 │ x = x[:, int(np.ceil(padh)): -int(padh), :] │ │ 394 │ x = x[:, :, int(np.ceil(padw)): -int(padw)] │ │ 395 │ # print(x.shape) │ │ ❱ 396 │ assert x.shape[-3:] == self._shape, (x.shape, self._shape) │ │ 397 │ if self._sigmoid: │ │ 398 │ x = jax.nn.sigmoid(x) │ │ 399 │ else: │ ╰──────────────────────────────────────────────────────────────────────────────────────────────────╯ UnfilteredStackTrace: AssertionError: ((1024, 64, 64, 4), (48, 88, 4))
Solved — see the comment in issue #12 for the fix.