Closed: vaggelisT closed this issue 5 months ago.
Hello,
I have changed the model size in example.py from medium to large:
config = config.update(dreamerv3.configs['large'])
but I get this error:
...... /envs/TestingEnv/lib/python3.8/site-packages/jax/_src/numpy/ufuncs.py:97 in fn

    94      lax_doc: bool = False) -> BinOp:
    95    def fn(x1, x2, /):
    96      x1, x2 = promote_args(numpy_fn.__name__, x1, x2)
  ❱ 97      return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)
    98    fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
    99    fn = jit(fn, inline=True)
   100    if lax_doc:

TypeError: add got incompatible shapes for broadcasting: (1, 1024), (1, 2048).
Any suggestions?
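Some context on the error message. When two arrays are added elementwise, their shapes are compared starting from the last dimension. Each pair of sizes must either be equal or include a 1, and any missing leading dimensions count as 1. This is the NumPy broadcasting rule, which `jax.numpy` also follows. The plain-Python sketch below implements that rule so the failure can be reproduced without JAX. The function name `broadcast_shape` is hypothetical and is not a JAX API. The message text only imitates the one in the traceback. For (1, 1024) and (1, 2048), the trailing sizes are 1024 and 2048. They differ and neither is 1, so the add is rejected.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Hypothetical helper: broadcast two shapes or raise TypeError.

    NumPy-style rule (also used by jax.numpy): align shapes from the right;
    each pair of sizes must be equal or one of them must be 1. Missing
    leading dimensions are treated as 1.
    """
    result = []
    # Walk both shapes from the trailing dimension towards the leading one.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == y or y == 1:
            result.append(x)
        elif x == 1:
            result.append(y)
        else:
            # Message wording imitates the JAX error from the traceback.
            raise TypeError(
                f"add got incompatible shapes for broadcasting: {tuple(a)}, {tuple(b)}"
            )
    return tuple(reversed(result))


print(broadcast_shape((1, 1024), (1, 1024)))  # (1, 1024)
print(broadcast_shape((1, 1), (1, 2048)))     # (1, 2048)
try:
    broadcast_shape((1, 1024), (1, 2048))
except TypeError as err:
    print(err)
```

In practice, a mismatch like this inside a model usually means that two tensors expected to have the same width came from differently sized configurations.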
You most likely didn't set a different logging directory for the new run, so the model automatically tries to load the incompatible checkpoint left there by the earlier (medium) run. I also spent some time on this before realizing it...
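Below is a minimal sketch of the fix. It assumes the run directory in example.py is set through a `logdir` entry in the `config.update({...})` call. The idea is to give each model size its own directory so a "large" run never resumes from a "medium" checkpoint. A second helper compares parameter shapes from a saved checkpoint against a freshly built model, which shows the mismatch more clearly than the traceback does. Both helpers (`size_logdir`, `find_shape_mismatches`) are hypothetical, and so are the `{name: shape}` dictionary format and the parameter name used below. None of them come from DreamerV3 or embodied.

```python
from pathlib import Path


def size_logdir(base, size):
    """Hypothetical helper: one log directory per model size.

    A run only resumes from checkpoints in its own directory, so keeping
    'medium' and 'large' apart avoids loading weights of the wrong shape.
    """
    return Path(base).expanduser() / f"run_{size}"


def find_shape_mismatches(saved, current):
    """Compare {param_name: shape} maps from a checkpoint and a fresh model.

    Returns (name, saved_shape, current_shape) for every shared parameter
    whose shape differs; names present in only one map are ignored.
    """
    return [
        (name, tuple(saved[name]), tuple(current[name]))
        for name in sorted(saved.keys() & current.keys())
        if tuple(saved[name]) != tuple(current[name])
    ]


# Illustrative, made-up shapes (not read from a real checkpoint).
saved = {"some_layer/bias": (1, 1024)}
current = {"some_layer/bias": (1, 2048)}

print(size_logdir("~/logdir", "large"))
print(find_shape_mismatches(saved, current))
```

Under the same assumption about example.py, you could set `logdir` to something like `str(size_logdir('~/logdir', 'large'))`. Pointing it at any new, empty directory, or deleting the old one, works just as well.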