timothybrooks / instruct-pix2pix


Train on SD2.1 768px #92

Closed andreemic closed 1 year ago

andreemic commented 1 year ago

Hey, how would one go about modifying the config to train on SD2.1 instead of SD1.5?

I've tried pulling in the updated ldm codebase, modifying the YAML config, and using an SD2.1 checkpoint, but I'm getting some errors (see below).

Do you have any idea how to make the model definition match the checkpoints?

My checkpoint

https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.ckpt

My config

# train.yaml
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm_edit_v21.LatentDiffusion
  params:
    ckpt_path: checkpoints/v2-1_768-ema-pruned.ckpt
    # ...

The errors

I get a lot of errors like these when starting training:

size mismatch for model.diffusion_model.input_blocks.2.1.x.y.z: copying a param with shape torch.Size([320, 1024]) from checkpoint, the shape in current model is torch.Size([320, 768]).

The mismatches are either a dimension that is 1024 in the checkpoint vs. 768 in the model definition, or shapes like torch.Size([640, 640]) vs. torch.Size([640, 640, 1, 1]).

Exact errors

(screenshot of the exact error output)

andreemic commented 1 year ago

Update: I was able to eliminate the "768 != 1024" errors by changing context_dim from 768 to 1024.
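
For reference, a minimal sketch of that change in the unet_config (only the cross-attention context width changes; everything else stays as in the SD 1.5 training config):

# unet_config excerpt (sketch): SD 2.1's text encoder (OpenCLIP ViT-H/14)
# produces 1024-dimensional embeddings, vs. 768 for SD 1.5's CLIP ViT-L/14.
unet_config:
  params:
    context_dim: 1024   # was 768
    # ... remaining params unchanged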

The "[x,y] != [x,y, 1, 1]" errors persist

e.g.

size mismatch for model.diffusion_model.input_blocks.7.1.proj_in.weight: copying a param with shape torch.Size([1280, 1280]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 1, 1]).

andreemic commented 1 year ago

Solved:

For the v2.1 UNet an optimization was added that changes the SpatialTransformer block architecture a tiny bit:

(screenshot of the v2 SpatialTransformer code)

It replaces the 1x1 proj_in/proj_out convolutions with linear layers; a 1x1 Conv2d weight has shape [out, in, 1, 1] while an nn.Linear weight is just [out, in], which is why the model definition had two extra weight dimensions compared to the checkpoint.

I just had to switch my config to the new UNetModel (ldm.modules.diffusionmodules.openaimodel_v21.UNetModel), which uses the new SpatialTransformer defined in attention_v21.py.
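
A minimal sketch of what the unet_config ends up looking like, assuming the copied-over v2 module keeps the upstream use_linear_in_transformer flag (the module path is the one mentioned above):

unet_config:
  target: ldm.modules.diffusionmodules.openaimodel_v21.UNetModel  # copied-over v2 UNet
  params:
    context_dim: 1024                # OpenCLIP ViT-H/14 embedding width
    use_linear_in_transformer: True  # proj_in/proj_out as nn.Linear, not 1x1 conv
    # ... remaining params as in the original train.yaml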

sidtandon2014 commented 1 year ago

Hi @andreemic: Did you change the CLIP embeddings while training? I am getting a matrix multiplication error because of [320, 1024] and [768, 2464].

While debugging, I found that the 768 dimension comes from CLIP.

andreemic commented 1 year ago

> Hi @andreemic: Did you change the CLIP embeddings while training? I am getting a matrix multiplication error because of [320, 1024] and [768, 2464].
>
> While debugging, I found that the 768 dimension comes from CLIP.

Hey! I don't believe I had that issue... Can you share your config file and CLI output?

sidtandon2014 commented 1 year ago

@andreemic train.yaml.txt

model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm_edit.LatentDiffusion
  params:
    ckpt_path: stable_diffusion/models/ldm/stable-diffusion-2-1-base/v2-1_512-ema-pruned.ckpt
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: edited
    cond_stage_key: edit
    image_size: 32
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: hybrid
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: true
    load_ema: false

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 0 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 1024
        use_checkpoint: True
        use_linear_in_transformer: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 32
    num_workers: 2
    train:
      target: edit_dataset.EditDataset
      params:
        path: data/clip-filtered-dataset
        split: train
        min_resize_res: 256
        max_resize_res: 256
        crop_res: 256
        flip_prob: 0.5
    validation:
      target: edit_dataset.EditDataset
      params:
        path: data/clip-filtered-dataset
        split: val
        min_resize_res: 256
        max_resize_res: 256
        crop_res: 256

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 2000
        max_images: 2
        increase_log_steps: False

  trainer:
    max_epochs: 2
    benchmark: True
    accumulate_grad_batches: 4
    check_val_every_n_epoch: 4

Error:

mat1 and mat2 shapes cannot be multiplied (2464x768 and 1024x320)
  File "/home/jupyter/sid/GitRepo/instruct-pix2pix/stable_diffusion/ldm/modules/attention.py", line 168, in forward
    k = self.to_k(context)
  File "/home/jupyter/sid/GitRepo/instruct-pix2pix/stable_diffusion/ldm/modules/attention.py", line 273, in _forward
    x = self.attn2(self.norm2(x), context=context) + x
  File "/home/jupyter/sid/GitRepo/instruct-pix2pix/stable_diffusion/ldm/modules/diffusionmodules/util.py", line 136, in forward
    output_tensors = ctx.run_function(*ctx.input_tensors)
  File "/home/jupyter/sid/GitRepo/instruct-pix2pix/stable_diffusion/ldm/modules/diffusionmodules/util.py", line 121, in checkpoint
    return CheckpointFunction.apply(func, len(inputs), *args)
  File "/home/jupyter/sid/GitRepo/instruct-pix2pix/stable_diffusion/ldm/modules/attention.py", line 269, in forward
    return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
  File "/home/jupyter/sid/GitRepo/instruct-pix2pix/stable_diffusion/ldm/modules/attention.py", line 334, in forward
    x = block(x, context=context[i])
  File "/home/jupyter/sid/GitRepo/instruct-pix2pix/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py", line 84, in forward
    x = layer(x, context)
  File "/home/jupyter/sid/GitRepo/instruct-pix2pix/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py", line 797, in forward
    h = module(h, emb, context)
  File "/home/jupyter/sid/GitRepo/instruct-pix2pix/stable_diffusion/ldm/models/diffusion/ddpm_edit.py", line 1428, in forward
    out = self.diffusion_model(xc, t, context=cc)
  File "/home/jupyter/sid/GitRepo/instruct-pix2pix/stable_diffusion/ldm/models/diffusion/ddpm_edit.py", line 1000, in apply_model
    x_recon = self.model(x_noisy, t, **cond)
  File "/home/jupyter/sid/GitRepo/instruct-pix2pix/stable_diffusion/ldm/models/diffusion/ddpm_edit.py", line 1028, in p_losses
    model_output = self.apply_model(x_noisy, t, cond)
  File "/home/jupyter/sid/GitRepo/instruct-pix2pix/stable_diffusion/ldm/models/diffusion/ddpm_edit.py", line 892, in forward
    return self.p_losses(x, c, t, *args, **kwargs)
  File "/home/jupyter/sid/GitRepo/instruct-pix2pix/stable_diffusion/ldm/models/diffusion/ddpm_edit.py", line 880, in shared_step
    loss = self(x, c)
  File "/home/jupyter/sid/GitRepo/instruct-pix2pix/stable_diffusion/ldm/models/diffusion/ddpm_edit.py", line 386, in validation_step
    _, loss_dict_no_ema = self.shared_step(batch)
  File "/home/jupyter/sid/GitRepo/instruct-pix2pix/main.py", line 784, in <module>
    trainer.fit(model, data)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2464x768 and 1024x320)
sidtandon2014 commented 1 year ago

This got resolved. I used FrozenOpenCLIPEmbedder, as mentioned in the SD 2.1 git repo.
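
For context: FrozenCLIPEmbedder wraps OpenAI CLIP ViT-L/14, whose text embeddings are 768-dimensional, while the SD 2.1 UNet expects context_dim: 1024 (OpenCLIP ViT-H/14); the 2464x768 matrix is just batch_size 32 × 77 text tokens of those 768-d embeddings hitting a to_k weight built for 1024. A minimal sketch of the cond_stage_config, mirroring the SD 2.x reference config (the freeze/layer params are taken from there and may need adjusting for this repo):

cond_stage_config:
  target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder  # OpenCLIP ViT-H/14, 1024-d
  params:
    freeze: True
    layer: "penultimate"  # SD 2.x conditions on the penultimate text-encoder layer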

nityanandmathur commented 3 months ago

@sidtandon2014 @andreemic Could you please share the checkpoint for IP2P - SD 2.1? Thank you.