danijar / dreamerv3

Mastering Diverse Domains through World Models
https://danijar.com/dreamerv3
MIT License

Checkpoint loading - memory doubles #103

Closed alinab1809 closed 5 months ago

alinab1809 commented 9 months ago

Hey there,

I noticed that starting a training run from a previous checkpoint doubles the size of the replay buffer and therefore also doubles the memory requirements. To use a checkpoint, I simply provide a logdir that already exists, as instructed in the README. As done in example.py, I initialize the replay buffer, which loads the data from the previous run as expected. Then, in the training function, the checkpoint is loaded (https://github.com/danijar/dreamerv3/blob/main/dreamerv3/embodied/run/train.py#L94). This restores the correct model parameters, but it also loads the replay buffer again, since 'replay' is a key in the checkpoint dict.
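
For context, the two load paths look roughly like this (a simplified sketch based on example.py and train.py; the exact constructor arguments and attribute names may differ slightly from the current source):

# 1) Creating the replay buffer with an existing directory already reads
#    the stored chunks from disk into memory.
replay = embodied.replay.Uniform(
    config.batch_length, config.replay_size, logdir / 'replay')

# 2) train.py then registers the same replay object with the checkpoint,
#    so restoring the checkpoint inserts the same data a second time.
checkpoint = embodied.Checkpoint(logdir / 'checkpoint.ckpt')
checkpoint.agent = agent
checkpoint.replay = replay
checkpoint.load_or_save()  # the 'replay' key is restored on top of step 1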

I found that one way to avoid this is to add the following code to example.py after line 25 (https://github.com/danijar/dreamerv3/blob/main/example.py#L25):

if logdir.exists():
    cp_file = logdir / 'checkpoint.ckpt'
    cp_dict = embodied.basics.unpack(cp_file.read('rb'))
    if 'replay' in cp_dict:
        # Strip the stale replay entry so the later checkpoint load does not
        # re-insert data that the replay buffer already read from disk.
        del cp_dict['replay']
        # Keep a temporary backup of the original checkpoint while rewriting it.
        old = cp_file.parent / (cp_file.name + '.old')
        cp_file.copy(old)
        cp_file.write(embodied.basics.pack(cp_dict), mode='wb')
        old.remove()

This avoids loading the replay buffer twice (which would otherwise put every datapoint in there two times and double the memory requirements each time a run is restarted) and works fine for me. Still, it is a rather hacky workaround and I feel like there should be a better way; maybe I also missed something?

Thanks in advance for the help!

danijar commented 5 months ago

This is fixed now; the replay buffer is only loaded in load() and no longer on object creation.
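
In other words (a rough sketch of the changed behavior, not the actual diff):

# Constructing the replay buffer no longer touches the data on disk.
replay = embodied.replay.Uniform(length, capacity, directory)

# The stored chunks are only read once, when the checkpoint restores
# the 'replay' entry via the replay's load().
checkpoint = embodied.Checkpoint(logdir / 'checkpoint.ckpt')
checkpoint.replay = replay
checkpoint.load_or_save()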