danijar / dreamerv3

Mastering Diverse Domains through World Models
https://danijar.com/dreamerv3
MIT License

Calling JAXAgent train gets stuck if using larger image sizes (inside Ninjax) #85

Open schneimo opened 1 year ago

schneimo commented 1 year ago

Hi Danijar,

I am currently trying to use higher image resolutions such as 256x256 with Dreamer. If I simply change the resolution, e.g. for the DM Control Suite, JAX is no longer able to trace/compile the training function:

python dreamerv3/train.py --logdir logs/test --configs dmc_vision --task dmc_cartpole_swingup --env.dmc.size 256 256

However, instead of raising an error, the program seems to hang at (or right after) the point where JAX traces the training function:

Config:
seed:                                          0                                           (int)
method:                                        name                                        (str)
task:                                          dmc_cartpole_swingup                        (str)
logdir:                                        logs/test                                   (str)
replay:                                        reverb                                      (str)
replay_size:                                   1000000.0                                   (float)
replay_online:                                 False                                       (bool)
replay_save:                                   False                                       (bool)
eval_dir:                                                                                  (str)
filter:                                        .*                                          (str)
jax.platform:                                  gpu                                         (str)
jax.jit:                                       True                                        (bool)
jax.precision:                                 float16                                     (str)
jax.prealloc:                                  True                                        (bool)
jax.debug_nans:                                False                                       (bool)
jax.logical_cpus:                              0                                           (int)
jax.debug:                                     True                                        (bool)
jax.policy_devices:                            [0]                                         (ints)
jax.train_devices:                             [0]                                         (ints)
jax.metrics_every:                             10                                          (int)
run.script:                                    train_eval                                  (str)
run.steps:                                     1250000.0                                   (float)
run.expl_until:                                0                                           (int)
run.log_every:                                 300                                         (int)
run.save_every:                                900                                         (int)
run.eval_every:                                50000.0                                     (float)
run.eval_initial:                              True                                        (bool)
run.eval_eps:                                  10                                          (int)
run.eval_samples:                              1                                           (int)
run.train_ratio:                               512.0                                       (float)
run.train_fill:                                0                                           (int)
run.eval_fill:                                 0                                           (int)
run.log_zeros:                                 False                                       (bool)
run.log_keys_video:                            [image]                                     (strs)
run.log_keys_sum:                              ^$                                          (str)
run.log_keys_mean:                             (log_entropy)                               (str)
run.log_keys_max:                              ^$                                          (str)
run.from_checkpoint:                                                                       (str)
run.sync_every:                                10                                          (int)
run.actor_addr:                                ipc:///tmp/5551                             (str)
run.actor_batch:                               32                                          (int)
envs.amount:                                   8                                           (int)
envs.parallel:                                 process                                     (str)
envs.length:                                   0                                           (int)
envs.reset:                                    True                                        (bool)
envs.restart:                                  True                                        (bool)
envs.discretize:                               0                                           (int)
envs.checks:                                   False                                       (bool)
envs.is_vec:                                   False                                       (bool)
wrapper.length:                                0                                           (int)
wrapper.reset:                                 True                                        (bool)
wrapper.discretize:                            0                                           (int)
wrapper.checks:                                False                                       (bool)
env.atari.size:                                [64, 64]                                    (ints)
env.atari.repeat:                              4                                           (int)
env.atari.sticky:                              True                                        (bool)
env.atari.gray:                                False                                       (bool)
env.atari.actions:                             all                                         (str)
env.atari.lives:                               unused                                      (str)
env.atari.noops:                               0                                           (int)
env.atari.resize:                              opencv                                      (str)
env.dmlab.size:                                [64, 64]                                    (ints)
env.dmlab.repeat:                              4                                           (int)
env.dmlab.episodic:                            True                                        (bool)
env.minecraft.size:                            [64, 64]                                    (ints)
env.minecraft.break_speed:                     100.0                                       (float)
env.dmc.size:                                  [256, 256]                                  (ints)
env.dmc.repeat:                                2                                           (int)
env.dmc.camera:                                -1                                          (int)
env.loconav.size:                              [64, 64]                                    (ints)
env.loconav.repeat:                            2                                           (int)
env.loconav.camera:                            -1                                          (int)
task_behavior:                                 Greedy                                      (str)
expl_behavior:                                 None                                        (str)
batch_size:                                    16                                          (int)
batch_length:                                  64                                          (int)
data_loaders:                                  8                                           (int)
grad_heads:                                    [decoder, reward, cont]                     (strs)
rssm.deter:                                    512                                         (int)
rssm.units:                                    512                                         (int)
rssm.stoch:                                    32                                          (int)
rssm.classes:                                  32                                          (int)
rssm.act:                                      silu                                        (str)
rssm.norm:                                     layer                                       (str)
rssm.initial:                                  learned                                     (str)
rssm.unimix:                                   0.01                                        (float)
rssm.unroll:                                   False                                       (bool)
rssm.action_clip:                              1.0                                         (float)
rssm.winit:                                    normal                                      (str)
rssm.fan:                                      avg                                         (str)
encoder.mlp_keys:                              $^                                          (str)
encoder.cnn_keys:                              image                                       (str)
encoder.act:                                   silu                                        (str)
encoder.norm:                                  layer                                       (str)
encoder.mlp_layers:                            5                                           (int)
encoder.mlp_units:                             1024                                        (int)
encoder.cnn:                                   resnet                                      (str)
encoder.cnn_depth:                             32                                          (int)
encoder.cnn_blocks:                            0                                           (int)
encoder.resize:                                stride                                      (str)
encoder.winit:                                 normal                                      (str)
encoder.fan:                                   avg                                         (str)
encoder.symlog_inputs:                         True                                        (bool)
encoder.minres:                                4                                           (int)
decoder.mlp_keys:                              $^                                          (str)
decoder.cnn_keys:                              image                                       (str)
decoder.act:                                   silu                                        (str)
decoder.norm:                                  layer                                       (str)
decoder.mlp_layers:                            5                                           (int)
decoder.mlp_units:                             1024                                        (int)
decoder.cnn:                                   resnet                                      (str)
decoder.cnn_depth:                             32                                          (int)
decoder.cnn_blocks:                            0                                           (int)
decoder.image_dist:                            mse                                         (str)
decoder.vector_dist:                           symlog_mse                                  (str)
decoder.inputs:                                [deter, stoch]                              (strs)
decoder.resize:                                stride                                      (str)
decoder.winit:                                 normal                                      (str)
decoder.fan:                                   avg                                         (str)
decoder.outscale:                              1.0                                         (float)
decoder.minres:                                4                                           (int)
decoder.cnn_sigmoid:                           False                                       (bool)
reward_head.layers:                            2                                           (int)
reward_head.units:                             512                                         (int)
reward_head.act:                               silu                                        (str)
reward_head.norm:                              layer                                       (str)
reward_head.dist:                              symlog_disc                                 (str)
reward_head.outscale:                          0.0                                         (float)
reward_head.outnorm:                           False                                       (bool)
reward_head.inputs:                            [deter, stoch]                              (strs)
reward_head.winit:                             normal                                      (str)
reward_head.fan:                               avg                                         (str)
reward_head.bins:                              255                                         (int)
cont_head.layers:                              2                                           (int)
cont_head.units:                               512                                         (int)
cont_head.act:                                 silu                                        (str)
cont_head.norm:                                layer                                       (str)
cont_head.dist:                                binary                                      (str)
cont_head.outscale:                            1.0                                         (float)
cont_head.outnorm:                             False                                       (bool)
cont_head.inputs:                              [deter, stoch]                              (strs)
cont_head.winit:                               normal                                      (str)
cont_head.fan:                                 avg                                         (str)
loss_scales.image:                             1.0                                         (float)
loss_scales.vector:                            1.0                                         (float)
loss_scales.reward:                            1.0                                         (float)
loss_scales.cont:                              1.0                                         (float)
loss_scales.dyn:                               0.5                                         (float)
loss_scales.rep:                               0.1                                         (float)
loss_scales.actor:                             1.0                                         (float)
loss_scales.critic:                            1.0                                         (float)
loss_scales.slowreg:                           1.0                                         (float)
dyn_loss.impl:                                 kl                                          (str)
dyn_loss.free:                                 1.0                                         (float)
rep_loss.impl:                                 kl                                          (str)
rep_loss.free:                                 1.0                                         (float)
model_opt.opt:                                 adam                                        (str)
model_opt.lr:                                  0.0001                                      (float)
model_opt.eps:                                 1e-08                                       (float)
model_opt.clip:                                1000.0                                      (float)
model_opt.wd:                                  0.0                                         (float)
model_opt.warmup:                              0                                           (int)
model_opt.lateclip:                            0.0                                         (float)
actor.layers:                                  2                                           (int)
actor.units:                                   512                                         (int)
actor.act:                                     silu                                        (str)
actor.norm:                                    layer                                       (str)
actor.minstd:                                  0.1                                         (float)
actor.maxstd:                                  1.0                                         (float)
actor.outscale:                                1.0                                         (float)
actor.outnorm:                                 False                                       (bool)
actor.unimix:                                  0.01                                        (float)
actor.inputs:                                  [deter, stoch]                              (strs)
actor.winit:                                   normal                                      (str)
actor.fan:                                     avg                                         (str)
actor.symlog_inputs:                           False                                       (bool)
critic.layers:                                 2                                           (int)
critic.units:                                  512                                         (int)
critic.act:                                    silu                                        (str)
critic.norm:                                   layer                                       (str)
critic.dist:                                   symlog_disc                                 (str)
critic.outscale:                               0.0                                         (float)
critic.outnorm:                                False                                       (bool)
critic.inputs:                                 [deter, stoch]                              (strs)
critic.winit:                                  normal                                      (str)
critic.fan:                                    avg                                         (str)
critic.bins:                                   255                                         (int)
critic.symlog_inputs:                          False                                       (bool)
actor_opt.opt:                                 adam                                        (str)
actor_opt.lr:                                  3e-05                                       (float)
actor_opt.eps:                                 1e-05                                       (float)
actor_opt.clip:                                100.0                                       (float)
actor_opt.wd:                                  0.0                                         (float)
actor_opt.warmup:                              0                                           (int)
actor_opt.lateclip:                            0.0                                         (float)
critic_opt.opt:                                adam                                        (str)
critic_opt.lr:                                 3e-05                                       (float)
critic_opt.eps:                                1e-05                                       (float)
critic_opt.clip:                               100.0                                       (float)
critic_opt.wd:                                 0.0                                         (float)
critic_opt.warmup:                             0                                           (int)
critic_opt.lateclip:                           0.0                                         (float)
actor_dist_disc:                               onehot                                      (str)
actor_dist_cont:                               normal                                      (str)
actor_grad_disc:                               reinforce                                   (str)
actor_grad_cont:                               backprop                                    (str)
critic_type:                                   vfunction                                   (str)
imag_horizon:                                  15                                          (int)
imag_unroll:                                   False                                       (bool)
horizon:                                       333                                         (int)
return_lambda:                                 0.95                                        (float)
critic_slowreg:                                logprob                                     (str)
slow_critic_update:                            1                                           (int)
slow_critic_fraction:                          0.02                                        (float)
retnorm.impl:                                  perc_ema                                    (str)
retnorm.decay:                                 0.99                                        (float)
retnorm.max:                                   1.0                                         (float)
retnorm.perclo:                                5.0                                         (float)
retnorm.perchi:                                95.0                                        (float)
actent:                                        0.0003                                      (float)
expl_rewards.extr:                             1.0                                         (float)
expl_rewards.disag:                            0.1                                         (float)
expl_opt.opt:                                  adam                                        (str)
expl_opt.lr:                                   0.0001                                      (float)
expl_opt.eps:                                  1e-05                                       (float)
expl_opt.clip:                                 100.0                                       (float)
expl_opt.wd:                                   0.0                                         (float)
expl_opt.warmup:                               0                                           (int)
disag_head.layers:                             2                                           (int)
disag_head.units:                              512                                         (int)
disag_head.act:                                silu                                        (str)
disag_head.norm:                               layer                                       (str)
disag_head.dist:                               mse                                         (str)
disag_head.outscale:                           1.0                                         (float)
disag_head.inputs:                             [deter, stoch, action]                      (strs)
disag_head.winit:                              normal                                      (str)
disag_head.fan:                                avg                                         (str)
disag_target:                                  [stoch]                                     (strs)
disag_models:                                  8                                           (int)
Encoder CNN shapes: {'image': (256, 256, 3)}
Encoder MLP shapes: {}
Decoder CNN shapes: {'image': (256, 256, 3)}
Decoder MLP shapes: {}
JAX devices (1): [gpu(id=0)]
Policy devices: gpu:0
Train devices:  gpu:0
Tracing train function.
Optimizer model_opt has 61,839,491 variables.
Optimizer actor_opt has 1,051,650 variables.
Optimizer critic_opt has 1,181,439 variables.
Logdir logs/test
Observation space:
  reward           Space(dtype=float32, shape=(), low=-inf, high=inf)
  is_first         Space(dtype=bool, shape=(), low=False, high=True)
  is_last          Space(dtype=bool, shape=(), low=False, high=True)
  is_terminal      Space(dtype=bool, shape=(), low=False, high=True)
  image            Space(dtype=uint8, shape=(256, 256, 3), low=0, high=255)
Action space:
  reset            Space(dtype=bool, shape=(), low=False, high=True)
  action           Space(dtype=float32, shape=(1,), low=-1.0, high=1.0)
Prefill train dataset.
[reverb/cc/platform/tfrecord_checkpointer.cc:162]  Initializing TFRecordCheckpointer in /tmp/tmpfkaj__gy.
[reverb/cc/platform/tfrecord_checkpointer.cc:567] Loading latest checkpoint from /tmp/tmpfkaj__gy
[reverb/cc/platform/default/server.cc:71] Started replay server on port 15055
Prefill eval dataset.
Found existing checkpoint.
Loading checkpoint: logs/test/checkpoint.ckpt
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (2594171) so Table table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (2594171) so Table table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (2594171) so Table table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (2594171) so Table table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (2594171) so Table table is accessed directly without gRPC.
Loaded checkpoint from 967 seconds ago.
Start training loop.
Starting evaluation at step 1560
Tracing policy function.
Tracing policy function.
Episode has 500 steps with return 161.2.
Episode has 500 steps with return 95.8.
Episode has 500 steps with return 66.3.
Episode has 500 steps with return 98.4.
Episode has 500 steps with return 117.0.
Episode has 500 steps with return 111.9.
Episode has 500 steps with return 102.0.
Episode has 500 steps with return 56.4.
Episode has 500 steps with return 88.9.
Episode has 500 steps with return 91.3.
Episode has 500 steps with return 117.6.
Episode has 500 steps with return 89.3.
Episode has 500 steps with return 158.3.
Episode has 500 steps with return 81.8.
Episode has 500 steps with return 43.6.
Episode has 500 steps with return 86.3.
Tracing policy function.
Tracing train function.

I have tested this on a V100 and an A100, with the same result on both. With smaller resolutions (e.g. 128x128 or 64x64), everything works as expected.
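For scale, going from 64x64 to 256x256 multiplies every image tensor by (256/64)^2 = 16. The sketch below only counts the raw image batch implied by the config above (batch_size 16, batch_length 64, RGB uint8); encoder and decoder activations grow in a similar way. The helper name image_batch_bytes is hypothetical, and this arithmetic does not by itself explain why tracing stalls.

```python
def image_batch_bytes(height, width, batch_size=16, batch_length=64,
                      channels=3, bytes_per_value=1):
    """Bytes for one training batch of images (hypothetical helper).

    Defaults follow the config dump: batch_size 16, batch_length 64,
    RGB uint8. Use bytes_per_value=2 for a float16 copy of the batch.
    """
    return batch_size * batch_length * height * width * channels * bytes_per_value


for side in (64, 128, 256):
    # 64x64: 12.0 MiB, 128x128: 48.0 MiB, 256x256: 192.0 MiB
    print(f"{side}x{side}: {image_batch_bytes(side, side) / 2**20:.1f} MiB")
```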

I tried to debug this, but I was not able to track the problem down inside Ninjax or JAX.
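One way to see where a silent hang sits is Python's standard faulthandler module, which can periodically print every thread's Python stack to stderr. The helper names below are hypothetical and not part of DreamerV3 or Ninjax; a sketch, assuming you can call them from train.py right before the first train step:

```python
import faulthandler
import signal
import sys


def arm_hang_dump(seconds=600.0):
    """Print all threads' Python stacks to stderr every `seconds` seconds.

    If the process is still tracing or compiling when the timer fires, the
    dump shows the Python frame it is sitting in. A thread blocked inside
    C++ code (e.g. XLA) shows up as the Python call that entered it.
    """
    faulthandler.dump_traceback_later(seconds, repeat=True, file=sys.stderr)


def disarm_hang_dump():
    """Cancel the periodic dump, e.g. once the first train step returns."""
    faulthandler.cancel_dump_traceback_later()


def dump_on_sigusr1():
    """Unix only: afterwards, `kill -USR1 <pid>` prints all stacks on demand."""
    faulthandler.register(signal.SIGUSR1, file=sys.stderr, all_threads=True)
```

Comparing two dumps taken some minutes apart shows whether the process is stuck in one frame or still making progress through the trace.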

Thanks a lot for your help!

edwhu commented 1 year ago

Tracing can sometimes take a while on older GPUs; I once waited around 10 minutes on a Titan X workstation.

You can try making the CNN smaller to see whether that shortens compilation. You can also increase the resolution step by step and check how the trace time grows.
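Both suggestions can be tried in isolation with a toy model before touching the agent. The sketch below times lowering (Python tracing plus conversion to XLA input) separately from XLA compilation, using abstract jax.ShapeDtypeStruct inputs so nothing is allocated or executed. The encoder is hypothetical: it assumes each stage halves the image and doubles the channels until a 4x4 minimum, in the spirit of encoder.minres and encoder.cnn_depth above, and is not DreamerV3's actual network. The names encoder_param_shapes, encoder and time_compile are illustrative.

```python
import time

import jax
import jax.numpy as jnp


def encoder_param_shapes(side, depth, minres=4, channels=3):
    """Abstract kernels for a hypothetical stride-2 conv encoder.

    Assumption: every stage halves the spatial size and doubles the
    channel count until `minres` is reached.
    """
    shapes = []
    size, cin, cout = side, channels, depth
    while size > minres:
        # 4x4 kernel in HWIO layout.
        shapes.append(jax.ShapeDtypeStruct((4, 4, cin, cout), jnp.float32))
        size //= 2
        cin, cout = cout, cout * 2
    return shapes


def encoder(params, images):
    x = images
    for kernel in params:
        x = jax.lax.conv_general_dilated(
            x, kernel, window_strides=(2, 2), padding="SAME",
            dimension_numbers=("NHWC", "HWIO", "NHWC"))
        x = jax.nn.silu(x)
    return x.reshape(x.shape[0], -1)


def time_compile(side, depth, batch=16):
    """Return (lowering seconds, compile seconds) without running the model."""
    params = encoder_param_shapes(side, depth)
    images = jax.ShapeDtypeStruct((batch, side, side, 3), jnp.float32)
    start = time.perf_counter()
    lowered = jax.jit(encoder).lower(params, images)  # trace + lower
    traced = time.perf_counter()
    lowered.compile()  # XLA compilation for the default backend
    done = time.perf_counter()
    return traced - start, done - traced
```

Calling time_compile(side, depth=32) for side in 64, 128 and 256, and again with a smaller depth, separates "tracing is slow" from "compilation is slow". Under the stated assumption, a 256x256 input gets six stages ending at 1024 channels, versus four stages and 256 channels at 64x64.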

schneimo commented 1 year ago

Thanks.

I am not sure that time and compute power are really the problem: even after 24 hours, tracing had not finished on an A100. Still, I will measure how tracing time grows with image resolution and report my findings here.

schneimo commented 1 year ago

I looked into this a bit more and found that the train function of the Agent class runs to completion: when I decorate it with an additional timer, the timer gets executed.
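A minimal version of such a timer, assuming a plain decorator was used (the name timed is hypothetical). Under jax.jit, the wrapped Python body runs while tracing, so the message marks the end of tracing that body; whatever the surrounding wrapper does after it returns, plus lowering and XLA compilation, only happens afterwards.

```python
import functools
import time


def timed(fn):
    """Hypothetical helper: report how long each call to `fn` takes.

    The message is printed only after `fn` returns, so seeing it proves
    the call itself completed. The last duration is kept on the wrapper.
    """
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        elapsed = time.perf_counter() - start
        wrapper.last_seconds = elapsed
        print(f"{fn.__qualname__} returned after {elapsed:.2f}s")
        return result

    wrapper.last_seconds = None
    return wrapper
```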

I also narrowed the problem down further, and it seems to arise in the try block of pure inside the Ninjax module: https://github.com/danijar/dreamerv3/blob/8fa35f83eee1ce7e10f3dee0b766587d0a713a60/dreamerv3/ninjax.py#L60-L101