lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License

model uses more vram than expected #192

Open rom1504 opened 2 years ago

rom1504 commented 2 years ago

Trying to run a bigger model.

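Number of parameters: 1862185991 results in more than 40GB of VRAM usage.

That works out to about 22 bytes per param, i.e. roughly 5.5x a float32 param (4 bytes). That's not normal, even accounting for the roughly 3x extra that training with Adam adds (gradients plus Adam's two moment buffers).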

    "decoder": {
        "unets": [
            {
                "dim": 384,
                "cond_dim": 512,
                "image_embed_dim": 768,
                "text_embed_dim": 768,
                "cond_on_text_encodings": true,
                "channels": 3,
                "dim_mults": [1, 2, 3, 4],
                "num_resnet_blocks": 4,
                "attn_heads": 8,
                "attn_dim_head": 64,
                "sparse_attn": true,
                "memory_efficient": true,
                "self_attn": [false, true, true, true]
            }
        ],
        "clip": {
            "make": "openai",
            "model": "ViT-L/14"
        },
        "image_sizes": [64],
        "channels": 3,
        "timesteps": 1000,
        "loss_type": "l2",
        "beta_schedule": ["cosine"],
        "learned_variance": true
    }
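A quick sanity check of the arithmetic above, as a minimal sketch. It assumes 1 GB = 1e9 bytes and counts only fp32 "model state" (weights, gradients, Adam's exp_avg and exp_avg_sq); activations, the CUDA context and allocator fragmentation are deliberately ignored, so this is a lower bound rather than a full memory model.

```python
# Rough per-parameter memory accounting for fp32 training with Adam.
# Assumption: decimal gigabytes; only weights, grads and Adam moments counted.

FP32_BYTES = 4


def adam_fp32_state_bytes_per_param() -> int:
    """Bytes per parameter for fp32 weights + grads + Adam's two moments."""
    weights = FP32_BYTES
    grads = FP32_BYTES
    adam_moments = 2 * FP32_BYTES  # exp_avg + exp_avg_sq
    return weights + grads + adam_moments


def observed_bytes_per_param(vram_gb: float, num_params: int) -> float:
    """Bytes per parameter implied by a measured VRAM figure."""
    return vram_gb * 1e9 / num_params


if __name__ == "__main__":
    expected = adam_fp32_state_bytes_per_param()
    observed = observed_bytes_per_param(40, 1_862_185_991)
    print(f"expected model state: {expected} bytes/param")
    print(f"observed: {observed:.1f} bytes/param "
          f"({observed / FP32_BYTES:.1f}x a float32 param)")
```

The gap between the two numbers (about 5.5 bytes/param here) is what has to be explained by activations and other overhead.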

I think we should fix this.

I think a simple script that creates the model, runs one training step, and checks memory usage with the PyTorch profiler would help.
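A minimal sketch of such a script. It assumes a small stand-in `nn.Sequential` model (hypothetical; swap in the decoder built from the config above) and a dummy squared-output loss in place of the real diffusion loss. After one Adam step it counts the bytes held by weights, gradients and optimizer state, reports `torch.cuda.max_memory_allocated` when a GPU is present, and returns the `torch.profiler` memory table. The table is sorted by self CPU memory; on a GPU run you would probably want to sort by the CUDA memory column instead.

```python
import torch
from torch import nn
from torch.profiler import ProfilerActivity, profile


def tensor_bytes(t: torch.Tensor) -> int:
    """Bytes occupied by a tensor's elements."""
    return t.numel() * t.element_size()


def one_step_memory_report(model: nn.Module, batch: torch.Tensor) -> dict:
    """Run one Adam training step and report per-parameter memory figures."""
    device = batch.device
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    if device.type == "cuda":
        torch.cuda.reset_peak_memory_stats(device)

    activities = [ProfilerActivity.CPU]
    if device.type == "cuda":
        activities.append(ProfilerActivity.CUDA)

    with profile(activities=activities, profile_memory=True) as prof:
        # Dummy loss standing in for the decoder's diffusion loss (assumption).
        loss = model(batch).pow(2).mean()
        loss.backward()
        optimizer.step()

    num_params = sum(p.numel() for p in model.parameters())
    weights = sum(tensor_bytes(p) for p in model.parameters())
    grads = sum(tensor_bytes(p.grad) for p in model.parameters() if p.grad is not None)
    # Adam keeps exp_avg and exp_avg_sq per parameter, plus a small step counter.
    adam_state = sum(
        tensor_bytes(v)
        for state in optimizer.state.values()
        for v in state.values()
        if torch.is_tensor(v)
    )

    report = {
        "num_params": num_params,
        "state_bytes_per_param": (weights + grads + adam_state) / num_params,
        "profiler_table": prof.key_averages().table(
            sort_by="self_cpu_memory_usage", row_limit=10
        ),
    }
    if device.type == "cuda":
        # Peak includes activations, which grow with batch size.
        report["peak_bytes_per_param"] = torch.cuda.max_memory_allocated(device) / num_params
    return report


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Hypothetical stand-in model; replace with the real decoder.
    model = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))
    batch = torch.randn(20, 512, device=device)
    report = one_step_memory_report(model, batch)
    print(report["profiler_table"])
    print(f"params: {report['num_params']}")
    print(f"weights+grads+Adam state: {report['state_bytes_per_param']:.2f} bytes/param")
    if "peak_bytes_per_param" in report:
        print(f"CUDA peak: {report['peak_bytes_per_param']:.2f} bytes/param")
```

Comparing `state_bytes_per_param` (about 16 for fp32 + Adam) against the CUDA peak isolates how much of the budget goes to activations at a given batch size.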

rom1504 commented 2 years ago

OK, the biggest I can fit is Number of parameters: 1639607751 (dim reduced from 384 to 352):

            {
                "dim": 352,
                "cond_dim": 512,
                "image_embed_dim": 768,
                "text_embed_dim": 768,
                "cond_on_text_encodings": true,
                "channels": 3,
                "dim_mults": [1, 2, 3, 4],
                "num_resnet_blocks": 4,
                "attn_heads": 8,
                "attn_dim_head": 64,
                "sparse_attn": true,
                "memory_efficient": true,
                "self_attn": [false, true, true, true]
            }

rom1504 commented 2 years ago

Let's compare with https://github.com/lucidrains/DALLE2-pytorch/issues/27#issuecomment-1179367413

rom1504 commented 2 years ago

Looks like they use dropout; maybe we should too.

rom1504 commented 2 years ago

Never mind, I had the batch size set to 20.

With batch size 1, I can fit Number of parameters: 2297405111

    "decoder": {
        "unets": [
            {
                "dim": 440,
                "cond_dim": 512,
                "image_embed_dim": 768,
                "text_embed_dim": 768,
                "cond_on_text_encodings": true,
                "channels": 3,
                "dim_mults": [1, 2, 3, 4],
                "num_resnet_blocks": 4,
                "attn_heads": 8,
                "attn_dim_head": 64,
                "sparse_attn": true,
                "memory_efficient": true,
                "self_attn": [false, true, true, true]
            }
        ],
        "clip": {
            "make": "openai",
            "model": "ViT-L/14"
        },
        "image_sizes": [64],
        "channels": 3,
        "timesteps": 1000,
        "loss_type": "l2",
        "beta_schedule": ["cosine"],
        "learned_variance": true
    },

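That works out to about 17 bytes per param, i.e. roughly 4x a float32 param, which seems much more reasonable: fp32 weights, gradients, and Adam's two moment buffers alone account for 16 bytes per param.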
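A rough sketch of why batch size mattered so much here, assuming a 40 GB budget (an assumption based on the "more than 40GB" figure above), decimal gigabytes, and 16 bytes/param of fp32 + Adam model state. Whatever is left after model state is what activations and other overhead can use; it applies the split to the two largest configs reported in this thread.

```python
# Split a VRAM budget into fp32+Adam model state and the remainder.
# Assumptions: 40 GB budget, 1 GB = 1e9 bytes, 16 bytes/param of model state.

GB = 1e9


def memory_split(num_params: int, budget_gb: float, state_bytes_per_param: int = 16):
    """Return (model state GB, GB left for activations and overhead)."""
    state = num_params * state_bytes_per_param / GB
    return state, budget_gb - state


if __name__ == "__main__":
    runs = [("batch size 20", 1_639_607_751), ("batch size 1", 2_297_405_111)]
    for label, n in runs:
        state, rest = memory_split(n, budget_gb=40)
        print(f"{label}: {n} params -> {state:.1f} GB model state, "
              f"{rest:.1f} GB left for activations and overhead")
```

Under these assumptions the batch-20 model leaves roughly 13.8 GB for activations while the batch-1 model leaves only about 3.2 GB, which is consistent with activations, not parameters, being what the extra bytes per param were paying for.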