danijar / dreamerv3

Mastering Diverse Domains through World Models
https://danijar.com/dreamerv3
MIT License
1.28k stars · 218 forks

Bug: jaxutils.Optimizer.PARAM_COUNTS parameter count is None #111

Closed · miniwa closed this 5 months ago

miniwa commented 6 months ago

I'm trying to run the provided example.py with the default config and the crafter environment. Runtime: Google Colab.

Running into this error:

Encoder CNN shapes: {'image': (64, 64, 3)}
Encoder MLP shapes: {}
Decoder CNN shapes: {'image': (64, 64, 3)}
Decoder MLP shapes: {}
JAX devices (1): [cuda(id=0)]
Policy devices: cuda:0
Train devices:  cuda:0
Tracing train function.
Optimizer model_opt has 32,468,675 variables.
Optimizer actor_opt has 2,144,657 variables.
Optimizer critic_opt has 2,297,215 variables.
Logdir /root/logdir/run3
Observation space:
  image            Space(dtype=uint8, shape=(64, 64, 3), low=0, high=255)
  reward           Space(dtype=float32, shape=(), low=-inf, high=inf)
  is_first         Space(dtype=bool, shape=(), low=False, high=True)
  is_last          Space(dtype=bool, shape=(), low=False, high=True)
  is_terminal      Space(dtype=bool, shape=(), low=False, high=True)
Action space:
  action           Space(dtype=float32, shape=(17,), low=0, high=1)
  reset            Space(dtype=bool, shape=(), low=False, high=True)
Prefill train dataset.
Episode has 153 steps and return 3.1.
Episode has 154 steps and return 1.1.
Episode has 169 steps and return 1.1.
Episode has 112 steps and return 0.1.
Episode has 163 steps and return 1.1.
Episode has 151 steps and return 0.1.
Saved chunk: 20240305T100157F260248-1d7rWmtZlxbtWabRys3dP5-67BsmQvl9pxqVtLv89IlWL-1024.npz
Episode has 183 steps and return 2.1.
──────────────────────────────────────────────────── Step 1100 ────────────────────────────────────────────────────
episode/length 183 / episode/score 2.1 / episode/sum_abs_reward 4.1 / episode/reward_rate 0.02

Creating new TensorBoard event file writer.
Did not find any checkpoint.
Writing checkpoint: /root/logdir/run3/checkpoint.ckpt
Start training loop.
Tracing policy function.
Saved chunk: 20240305T100206F497031-67BsmQvl9pxqVtLv89IlWL-0000000000000000000000-76.npz
Wrote checkpoint: /root/logdir/run3/checkpoint.ckpt
Tracing policy function.
Tracing train function.
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <cell line: 54>:54                                                                            │
│ in main:51                                                                                       │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/dreamerv3/embodied/run/train.py:108 in train             │
│                                                                                                  │
│   105   policy = lambda *args: agent.policy(                                                     │
│   106 │     *args, mode='explore' if should_expl(step) else 'train')                             │
│   107   while step < args.steps:                                                                 │
│ ❱ 108 │   driver(policy, steps=100)                                                              │
│   109 │   if should_save(step):                                                                  │
│   110 │     checkpoint.save()                                                                    │
│   111   logger.write()                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/dreamerv3/embodied/core/driver.py:42 in __call__         │
│                                                                                                  │
│   39   def __call__(self, policy, steps=0, episodes=0):                                          │
│   40 │   step, episode = 0, 0                                                                    │
│   41 │   while step < steps or episode < episodes:                                               │
│ ❱ 42 │     step, episode = self._step(policy, step, episode)                                     │
│   43                                                                                             │
│   44   def _step(self, policy, step, episode):                                                   │
│   45 │   assert all(len(x) == len(self._env) for x in self._acts.values())                       │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/dreamerv3/embodied/core/driver.py:65 in _step            │
│                                                                                                  │
│   62 │   for i in range(len(self._env)):                                                         │
│   63 │     trn = {k: v[i] for k, v in trns.items()}                                              │
│   64 │     [self._eps[i][k].append(v) for k, v in trn.items()]                                   │
│ ❱ 65 │     [fn(trn, i, **self._kwargs) for fn in self._on_steps]                                 │
│   66 │     step += 1                                                                             │
│   67 │   if obs['is_last'].any():                                                                │
│   68 │     for i, done in enumerate(obs['is_last']):                                             │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/dreamerv3/embodied/core/driver.py:65 in <listcomp>       │
│                                                                                                  │
│   62 │   for i in range(len(self._env)):                                                         │
│   63 │     trn = {k: v[i] for k, v in trns.items()}                                              │
│   64 │     [self._eps[i][k].append(v) for k, v in trn.items()]                                   │
│ ❱ 65 │     [fn(trn, i, **self._kwargs) for fn in self._on_steps]                                 │
│   66 │     step += 1                                                                             │
│   67 │   if obs['is_last'].any():                                                                │
│   68 │     for i, done in enumerate(obs['is_last']):                                             │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/dreamerv3/embodied/run/train.py:76 in train_step         │
│                                                                                                  │
│    73 │   for _ in range(should_train(step)):                                                    │
│    74 │     with timer.scope('dataset'):                                                         │
│    75 │   │   batch[0] = next(dataset)                                                           │
│ ❱  76 │     outs, state[0], mets = agent.train(batch[0], state[0])                               │
│    77 │     metrics.add(mets, prefix='train')                                                    │
│    78 │     if 'priority' in outs:                                                               │
│    79 │   │   replay.prioritize(outs['key'], outs['priority'])                                   │
│                                                                                                  │
│ /usr/lib/python3.10/contextlib.py:79 in inner                                                    │
│                                                                                                  │
│    76 │   │   @wraps(func)                                                                       │
│    77 │   │   def inner(*args, **kwds):                                                          │
│    78 │   │   │   with self._recreate_cm():                                                      │
│ ❱  79 │   │   │   │   return func(*args, **kwds)                                                 │
│    80 │   │   return inner                                                                       │
│    81                                                                                            │
│    82                                                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/dreamerv3/jaxagent.py:84 in train                        │
│                                                                                                  │
│    81 │     self._once = False                                                                   │
│    82 │     assert jaxutils.Optimizer.PARAM_COUNTS                                               │
│    83 │     for name, count in jaxutils.Optimizer.PARAM_COUNTS.items():                          │
│ ❱  84 │   │   mets[f'params_{name}'] = float(count)                                              │
│    85 │   return outs, state, mets                                                               │
│    86                                                                                            │
│    87   def report(self, data):                                                                  │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: float() argument must be a string or a real number, not 'NoneType'
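The assertion at jaxagent.py:84 passes because the PARAM_COUNTS dict itself is non-empty, but one of its values is still None when train runs, so `float(count)` raises. A minimal sketch of the failure mode and a defensive skip (the dict contents here are hypothetical, not the actual dreamerv3 internals):

```python
# Hypothetical stand-in for jaxutils.Optimizer.PARAM_COUNTS: it maps
# optimizer names to parameter counts, but an entry can still be None
# if that optimizer has not been stepped/traced yet.
PARAM_COUNTS = {'model_opt': 32_468_675, 'actor_opt': None}

mets = {}
for name, count in PARAM_COUNTS.items():
    if count is None:
        # float(None) is exactly what raises
        # "TypeError: float() argument must be a string or a real number,
        #  not 'NoneType'" in the traceback above, so skip unset counts.
        continue
    mets[f'params_{name}'] = float(count)

print(mets)  # only the populated counts survive
```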

This smells like the kind of error that would go away if a dependency or two were downgraded. If someone could post the pip freeze output from their working dev environment, I would be very grateful.

LYK-love commented 6 months ago

This is my config. OS: Ubuntu 22.04, ISA: x86_64, Python 3.9. You also need to install ffmpeg.

absl-py==2.1.0
appdirs==1.4.4
astunparse==1.6.3
certifi==2024.2.2
charset-normalizer==3.3.2
chex==0.1.85
click==8.1.7
cloudpickle==1.6.0
crafter==1.8.3
decorator==5.1.1
dm-tree==0.1.8
docker-pycreds==0.4.0
flatbuffers==24.3.7
gast==0.5.4
gitdb==4.0.11
GitPython==3.1.42
google-pasta==0.2.0
grpcio==1.62.1
gym==0.19.0
h5py==3.10.0
idna==3.6
imageio==2.34.0
importlib_metadata==7.0.2
jax==0.4.25
jaxlib==0.4.25+cuda12.cudnn89
keras==3.0.5
libclang==16.0.6
Markdown==3.5.2
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mdurl==0.1.2
ml-dtypes==0.3.2
namex==0.0.7
numpy==1.26.4
nvidia-cublas-cu12==12.4.2.65
nvidia-cuda-cupti-cu12==12.4.99
nvidia-cuda-nvcc-cu12==12.4.99
nvidia-cuda-nvrtc-cu12==12.4.99
nvidia-cuda-runtime-cu12==12.4.99
nvidia-cudnn-cu12==8.9.7.29
nvidia-cufft-cu12==11.2.0.44
nvidia-cusolver-cu12==11.6.0.99
nvidia-cusparse-cu12==12.3.0.142
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.4.99
opensimplex==0.4.5
opt-einsum==3.3.0
optax==0.2.1
packaging==24.0
pillow==10.2.0
protobuf==4.25.3
psutil==5.9.8
Pygments==2.17.2
PyYAML==6.0.1
requests==2.31.0
rich==13.7.1
ruamel.yaml==0.17.21
ruamel.yaml.clib==0.2.8
scipy==1.12.0
sentry-sdk==1.42.0
setproctitle==1.3.3
six==1.16.0
smmap==5.0.1
tensorboard==2.16.2
tensorboard-data-server==0.7.2
tensorflow-cpu==2.16.1
tensorflow-io-gcs-filesystem==0.36.0
tensorflow-probability==0.24.0
termcolor==2.4.0
toolz==0.12.1
typing_extensions==4.10.0
urllib3==2.2.1
wandb==0.16.4
Werkzeug==3.0.1
wrapt==1.16.0
zipp==3.18.0
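If you just want to match the JAX-related parts of this environment rather than the full freeze, pinning the packages on the failing code path should be enough. This is an assumed minimal subset (versions copied from the list above; the find-links URL is the standard JAX CUDA wheel index):

```shell
# Pin the JAX stack to the versions from the freeze above.
# jaxlib CUDA wheels are served from Google's release index, not PyPI.
pip install "jax==0.4.25" "jaxlib==0.4.25+cuda12.cudnn89" \
    "optax==0.2.1" "chex==0.1.85" "tensorflow-probability==0.24.0" \
    -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```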
danijar commented 5 months ago

Just updated the code. Is this still an issue?