NM512 / dreamerv3-torch

Implementation of Dreamer v3 in pytorch.
MIT License

Error when training with onehot agents #49

Closed: artiom-gesp closed this issue 10 months ago

artiom-gesp commented 10 months ago

Hello, thank you for this great repo. When I run

python3 dreamer.py --configs atari100k --task atari_breakout --logdir ~/logdir/atari_breakout_v3

I get the following error:

Logger: (10000 steps).
Simulate agent.
Encoder CNN shapes: {'image': (64, 64, 3)}
Encoder MLP shapes: {}
Decoder CNN shapes: {'image': (64, 64, 3)}
Decoder MLP shapes: {}
Optimizer model_opt has 15686787 variables.
Traceback (most recent call last):
  File "/local/home/argesp/dreamerv3-torch/dreamer.py", line 365, in <module>
    main(parser.parse_args(remaining))
  File "/local/home/argesp/dreamerv3-torch/dreamer.py", line 287, in main
    agent = Dreamer(
  File "/local/home/argesp/dreamerv3-torch/dreamer.py", line 45, in __init__
    self._task_behavior = models.ImagBehavior(config, self._wm)
  File "/local/home/argesp/dreamerv3-torch/models.py", line 223, in __init__
    self.actor = networks.MLP(
  File "/local/home/argesp/dreamerv3-torch/networks.py", line 654, in __init__
    assert dist in ("tanh_normal", "normal", "trunc_normal", "huber"), dist
AssertionError: onehot

After replacing the hardcoded "learned" argument with 1.0, the code seems to run, but I do not know whether training proceeds as intended.

self.actor = networks.MLP(
    feat_size,
    (config.num_actions,),
    config.actor["layers"],
    config.units,
    config.act,
    config.norm,
    config.actor["dist"],
    # "learned",  (original hardcoded value, replaced by the 1.0 below)
    1.0,
    config.actor["min_std"],
    config.actor["max_std"],
    absmax=1.0,
    temp=config.actor["temp"],
    unimix_ratio=config.actor["unimix_ratio"],
    outscale=config.actor["outscale"],
    name="Actor",
)

Could you tell me if I am getting the intended behaviour with this fix?
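
For anyone hitting the same assertion, the failure can be reproduced without building the full agent. The sketch below is a stand-in, not the repository's code: `TinyMLP` and `try_build` are hypothetical names, and the guess that the check only runs when the std argument is "learned" is inferred from the workaround above. Only the assertion line itself is taken from the traceback.

```python
# Hypothetical stand-in for the constructor check; not the code in networks.py.
CONTINUOUS_DISTS = ("tanh_normal", "normal", "trunc_normal", "huber")


class TinyMLP:
    """Keeps only the argument validation, with no layers or tensors."""

    def __init__(self, dist="normal", std=1.0):
        # Assumption: the check is guarded by std == "learned", which would
        # explain why passing 1.0 avoids the error.
        if std == "learned":
            # This line mirrors the assertion shown in the traceback.
            assert dist in CONTINUOUS_DISTS, dist
        self.dist = dist
        self.std = std


def try_build(dist, std):
    """Return "ok" if construction succeeds, else the assertion message."""
    try:
        TinyMLP(dist=dist, std=std)
        return "ok"
    except AssertionError as err:
        return f"AssertionError: {err}"


if __name__ == "__main__":
    print(try_build("onehot", "learned"))
    print(try_build("onehot", 1.0))
    print(try_build("normal", "learned"))
```

A categorical (one-hot) distribution has no standard deviation parameter, so under this assumption the fixed 1.0 is presumably unused for the Atari actor. Checking how networks.MLP consumes the std argument when dist is "onehot" would confirm that.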

NM512 commented 10 months ago

Hello,

Thank you for bringing this to my attention. The issue is related to the "dist" parameter in the networks.MLP constructor. Your temporary fix seems reasonable, but to ensure intended behavior, check this commit.

Thanks for your contribution!
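
One way to avoid hardcoding the argument at the call site is to derive it from the configured distribution. The helper below only illustrates that idea: `actor_std_for` and its `fixed_std` default are hypothetical, and the commit referenced above may solve the problem differently.

```python
# Distribution names taken from the assertion in the traceback.
CONTINUOUS_DISTS = ("tanh_normal", "normal", "trunc_normal", "huber")


def actor_std_for(dist, fixed_std=1.0):
    """Pick the std argument for an actor head (hypothetical helper).

    Continuous distributions get a learned std; anything else, such as
    "onehot", gets a fixed placeholder value.
    """
    if dist in CONTINUOUS_DISTS:
        return "learned"
    return fixed_std


if __name__ == "__main__":
    for dist in ("normal", "onehot"):
        print(dist, "->", actor_std_for(dist))
```

With a helper like this, the positional "learned" in the actor construction could be replaced by `actor_std_for(config.actor["dist"])`, so the same call works for both continuous and discrete action spaces.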

artiom-gesp commented 10 months ago

Thank you for fixing it so fast!