lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch
MIT License
5.57k stars 642 forks

Resuming training size mismatch #211

Open afiaka87 opened 3 years ago

afiaka87 commented 3 years ago

I'm getting size mismatches across the entire checkpoint. This sort of thing:

        size mismatch for transformer.layers.blocks.12.g.net.fn.fn.net.0.weight: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([8192, 1024]).
        size mismatch for transformer.layers.blocks.12.g.net.fn.fn.net.3.weight: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([1024, 4096]).
        size mismatch for transformer.layers.blocks.13.f.net.fn.fn.to_qkv.weight: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([1536, 1024]).
        size mismatch for transformer.layers.blocks.13.f.net.fn.fn.to_out.0.weight: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([1024, 512]).
        size mismatch for transformer.layers.blocks.13.g.net.fn.fn.net.0.weight: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([8192, 1024]).
        size mismatch for transformer.layers.blocks.13.g.net.fn.fn.net.3.weight: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([1024, 4096]).
        size mismatch for transformer.layers.blocks.14.f.net.fn.fn.to_qkv.weight: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([1536, 1024]).
        size mismatch for transformer.layers.blocks.14.f.net.fn.fn.to_out.0.weight: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([1024, 512]).
        size mismatch for transformer.layers.blocks.14.g.net.fn.fn.net.0.weight: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([8192, 1024]).
        size mismatch for transformer.layers.blocks.14.g.net.fn.fn.net.3.weight: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([1024, 4096]).
        size mismatch for transformer.layers.blocks.15.f.net.fn.fn.to_qkv.weight: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([1536, 1024]).
        size mismatch for transformer.layers.blocks.15.f.net.fn.fn.to_out.0.weight: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([1024, 512]).
        size mismatch for transformer.layers.blocks.15.g.net.fn.fn.net.0.weight: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([8192, 1024]).
        size mismatch for transformer.layers.blocks.15.g.net.fn.fn.net.3.weight: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([1024, 4096]).
        size mismatch for to_logits.1.weight: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([50688, 1024]).
Killing subprocess 676

whenever I resume from a checkpoint. Were the checkpoint keys changed recently?

afiaka87 commented 3 years ago

It seems my checkpoint wasn't being saved properly. Maybe this is related to DeepSpeed requiring you to use its own methods to load and save PyTorch models?
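
For reference, a minimal sketch of the DeepSpeed-side pattern I mean (the `dalle`, `deepspeed_config`, `save_dir`, and `ckpt_tag` names are placeholders, not what train_dalle.py uses):

    import deepspeed

    # Wrap the model once; DeepSpeed returns an engine that owns the (possibly
    # partitioned / offloaded) parameters and optimizer state.
    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=dalle,                           # placeholder: the DALLE module
        model_parameters=dalle.parameters(),
        config=deepspeed_config,               # placeholder: ds_config dict or path
    )

    # Save through the engine, not plain torch.save(), so partitioned state is gathered.
    model_engine.save_checkpoint(save_dir, tag=ckpt_tag)

    # Resume through the engine as well; this restores module and optimizer state.
    model_engine.load_checkpoint(save_dir, tag=ckpt_tag)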

janEbert commented 3 years ago

Yeah, we can't avoid that now if we want to support offloading and partitioning. I'll fix it.

afiaka87 commented 3 years ago

@janEbert Thanks!

janEbert commented 3 years ago

@lucidrains I'd rework checkpointing so that we always save checkpoints that are resumable for training (i.e. including the optimizer state), no matter whether we run distributed or not, instead of only inference checkpoints, which is the current behavior.

Is that fine with you, or would you rather keep the old behavior? I could work around it, but it would be less clean.

EDIT: Actually, never mind, it's not as cleanly solvable as I thought. I'd still suggest you think about whether you'd like to save/restore the optimizer state, though!
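
For the non-distributed path, a resumable checkpoint is mostly a matter of bundling the optimizer state alongside the weights; a rough sketch (key names, `opt`, and the file name are only illustrative, not what the training script currently writes):

    import torch

    # Save everything needed to resume training, not just to run inference.
    torch.save({
        'weights': dalle.state_dict(),
        'optimizer': opt.state_dict(),
        'epoch': epoch,
    }, 'dalle_resumable.pt')

    # Later, to resume:
    ckpt = torch.load('dalle_resumable.pt', map_location='cpu')
    dalle.load_state_dict(ckpt['weights'])
    opt.load_state_dict(ckpt['optimizer'])
    start_epoch = ckpt['epoch']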

janEbert commented 3 years ago

For now, you can use the janEbert/deepspeed branch, which has temporary fixes for partitioned models. (The calls may still fail if the VAE is partitioned.)

I noticed some underlying issues as well, which is why I'm not opening a PR for the fix yet. We'll need to split the VAE out of the DALLE model. You also can't load a DeepSpeed checkpoint of the VAE for the DALLE model, because you can't "merge" the VAE into it.
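
Roughly what I mean by splitting them, as a sketch (hyperparameters and file names are placeholders, and the VAE checkpoint here is assumed to be a plain PyTorch state dict rather than a DeepSpeed one):

    import torch
    from dalle_pytorch import DiscreteVAE, DALLE

    # Restore the VAE from its own, separate checkpoint first.
    vae = DiscreteVAE(image_size=256, num_layers=3, num_tokens=8192,
                      codebook_dim=512, hidden_dim=64)
    vae.load_state_dict(torch.load('vae.pt', map_location='cpu'))

    # Build DALLE around the already-restored VAE, so its checkpoint only has to
    # cover the transformer and there is nothing to "merge" back in.
    dalle = DALLE(dim=1024, vae=vae, num_text_tokens=10000,
                  text_seq_len=256, depth=16, heads=16)
    dalle.load_state_dict(torch.load('dalle_transformer.pt', map_location='cpu'),
                          strict=False)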

afiaka87 commented 3 years ago

@lucidrains can we get your eyeballs on this? Could use guidance on how to store optimizer state etc.

janEbert commented 3 years ago

The possible issue I'm seeing is that DeepSpeed may not handle multiple ZeRO-enabled models at once, for whatever reason (e.g. both wanting to take all of the GPU memory, unhandled shared global state, ...). I'm not sure, though; I'd need to look into it. If, and only if, that's the case, splitting the models wouldn't help either. I haven't figured out how to handle that case quite yet. :)
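
If that turns out to be the problem, one direction would be to keep only a single ZeRO engine alive and leave the VAE as a frozen, plain PyTorch module on every rank; a sketch under that assumption (`vae`, `dalle`, and `deepspeed_config` are placeholders):

    import deepspeed

    # Freeze the VAE and keep it outside DeepSpeed entirely.
    vae.eval()
    for p in vae.parameters():
        p.requires_grad = False

    # Only the DALLE transformer gets wrapped, so a single ZeRO engine exists.
    dalle_engine, opt, _, _ = deepspeed.initialize(
        model=dalle,
        model_parameters=[p for p in dalle.parameters() if p.requires_grad],
        config=deepspeed_config,
    )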

janEbert commented 3 years ago

Partly solved by #231.