Hey there,

I noticed that starting a training run from a previous checkpoint doubles the size of the replay buffer and therefore also doubles the memory requirements. To use a checkpoint I simply provide a logdir that already exists, as instructed in the README. As done in example.py, I initialize the replay buffer, which loads the data from the previous run as expected. Then, in the training function, the checkpoint is loaded (https://github.com/danijar/dreamerv3/blob/main/dreamerv3/embodied/run/train.py#L94). This sets the correct model parameters, but it also loads the replay buffer again, since 'replay' is a key in the checkpoint dict.

I found that a way to avoid this is to add the following code in example.py after line 25 (https://github.com/danijar/dreamerv3/blob/main/example.py#L25):
```python
if logdir.exists():
  cp_file = logdir / 'checkpoint.ckpt'
  cp_dict = embodied.basics.unpack(cp_file.read('rb'))
  if 'replay' in cp_dict:
    # Drop the replay data so that loading the checkpoint in train.py
    # does not add it to the already-initialized replay buffer again.
    del cp_dict['replay']
    # Rewrite the checkpoint, keeping a backup in case the write fails.
    old = cp_file.parent / (cp_file.name + '.old')
    cp_file.copy(old)
    cp_file.write(embodied.basics.pack(cp_dict), mode='wb')
    old.remove()
```
This avoids loading the replay buffer twice (which would otherwise result in each datapoint being in there two times and memory requirements doubling each time a run is restarted) and works fine for me. Still, it is a rather hacky workaround and I feel like there should be a better way; maybe I also missed something?
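For reference, here is the same strip-and-rewrite pattern as a self-contained sketch, using pickle and pathlib as stand-ins for embodied.basics.pack/unpack and the embodied Path helpers (the actual serialization format of the checkpoint is an assumption here, so this is illustrative only):

```python
import pickle
from pathlib import Path

def strip_key(cp_file: Path, key: str) -> None:
    """Remove `key` from a pickled checkpoint dict, rewriting via a backup."""
    cp_dict = pickle.loads(cp_file.read_bytes())
    if key not in cp_dict:
        return  # nothing to strip
    del cp_dict[key]
    # Keep a backup copy so an interrupted write cannot corrupt the checkpoint.
    backup = cp_file.with_name(cp_file.name + '.old')
    backup.write_bytes(cp_file.read_bytes())
    cp_file.write_bytes(pickle.dumps(cp_dict))
    backup.unlink()

# Usage with a throwaway checkpoint file:
import tempfile
with tempfile.TemporaryDirectory() as d:
    f = Path(d) / 'checkpoint.ckpt'
    f.write_bytes(pickle.dumps({'agent': 1, 'replay': [0] * 5}))
    strip_key(f, 'replay')
    print(sorted(pickle.loads(f.read_bytes())))  # ['agent']
```

The backup-then-rewrite dance mirrors what the snippet above does with `cp_file.copy(old)`; the important part is only that 'replay' is gone before train.py loads the checkpoint.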
Thanks in advance for the help!