crowsonkb / k-diffusion

Karras et al. (2022) diffusion models for PyTorch
MIT License

Config for training other resolutions #9

Closed. eliahuhorwitz closed this issue 2 years ago.

eliahuhorwitz commented 2 years ago

Hello and thanks for the implementation of the paper! I ran the code with the current config and it seems to do very well. How would one go about training a model with images of size 64x64 or 128x128?

Thanks, Eliahu

crowsonkb commented 2 years ago

For 64x64, something along the lines of:

    "model": {
        "type": "image_v1",
        "input_channels": 3,
        "input_size": [64, 64],
        "mapping_out": 256,
        "depths": [2, 2, 4, 4],
        "channels": [128, 256, 256, 512],
        "self_attn_depths": [false, false, true, true],
        "dropout_rate": 0.05,
        "augment_prob": 0.12,
        "sigma_data": 0.5,
        "sigma_min": 1e-2,
        "sigma_max": 80,
        "sigma_sample_density": {
            "type": "lognormal",
            "mean": -1.2,
            "std": 1.2
        }
    },

For each doubling in resolution you need to add one additional U-Net stage, so depths, channels, and self_attn_depths need to be four items long for 64x64 instead of three. The last two stages (16x16 and 8x8) should have self-attention and the others should not. For 128x128, something along the lines of:

    "model": {
        "type": "image_v1",
        "input_channels": 3,
        "input_size": [128, 128],
        "mapping_out": 256,
        "depths": [2, 2, 2, 4, 4],
        "channels": [128, 256, 256, 512, 512],
        "self_attn_depths": [false, false, false, true, true],
        "dropout_rate": 0.05,
        "augment_prob": 0.12,
        "sigma_data": 0.5,
        "sigma_min": 1e-2,
        "sigma_max": 160,
        "sigma_sample_density": {
            "type": "lognormal",
            "mean": -1.2,
            "std": 1.2
        }
    },

As you increase resolution you also need to increase sigma_max; 160 is more usual for higher-resolution models.
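
Extrapolating that pattern one more doubling gives a rough idea of what a 256x256 config might look like. The depths, channel widths, and sigma_max below are guesses that follow the progression above, not values taken from this repo:

    "model": {
        "type": "image_v1",
        "input_channels": 3,
        "input_size": [256, 256],
        "mapping_out": 256,
        "depths": [2, 2, 2, 2, 4, 4],
        "channels": [128, 256, 256, 512, 512, 512],
        "self_attn_depths": [false, false, false, false, true, true],
        "dropout_rate": 0.05,
        "augment_prob": 0.12,
        "sigma_data": 0.5,
        "sigma_min": 1e-2,
        "sigma_max": 160,
        "sigma_sample_density": {
            "type": "lognormal",
            "mean": -1.2,
            "std": 1.2
        }
    },

Here the six stages run from 256x256 down to 8x8, self-attention is again enabled only on the 16x16 and 8x8 stages, and sigma_max should be 160 or higher per the note above.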

eliahuhorwitz commented 2 years ago

Thanks for the quick reply! For the 128x128 variant I will probably need a smaller batch size; any idea how the rest of the optimization-related hyper-parameters should change? (e.g. should the lr be scaled linearly, as is often done in representation learning? Should other hyper-parameters change?)

crowsonkb commented 2 years ago

lr should probably go down roughly linearly if you decrease the batch size. :)
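
As a concrete sketch of that linear scaling rule (the reference batch size and lr below are made-up numbers, not this repo's defaults):

    # Linear lr scaling sketch: lr shrinks in proportion to the batch size.
    # Reference values are hypothetical, not taken from this repo's configs.
    ref_batch_size = 64
    ref_lr = 1e-4
    new_batch_size = 32
    new_lr = ref_lr * new_batch_size / ref_batch_size  # 5e-5 for half the batch
    print(new_lr)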

eliahuhorwitz commented 2 years ago

Thanks!