lllyasviel / ControlNet

Let us control diffusion models!
Apache License 2.0

How to train an image as a control element? #504

Status: Open. yi0109 opened this issue 1 year ago

yi0109 commented 1 year ago

My sample dataset code is:

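import cv2
import numpy as np

# Excerpt from the dataset's __getitem__; source, target and prompt are loaded
# earlier (images via cv2.imread, which returns BGR). to_sample is a hypothetical name.
def to_sample(source, target, prompt):
    # OpenCV loads images in BGR order; convert both to RGB.
    source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
    target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)

    # Normalize source images to [0, 1].
    source = source.astype(np.float32) / 255.0

    # Normalize target images to [-1, 1].
    target = (target.astype(np.float32) / 127.5) - 1.0

    return dict(jpg=target, txt=prompt, hint=source)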
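One way to build a 6-channel hint from two control images is to normalize each one the same way as the single-image source and stack them along the last (channel) axis. This is a minimal sketch, assuming both images are already RGB uint8 arrays of the same size; `make_two_image_hint` is a hypothetical helper, not part of the ControlNet repo. As far as I can tell, `ControlLDM.get_input` in `cldm/cldm.py` rearranges the hint from (b, h, w, c) to (b, c, h, w), so keeping the hint channels-last like the single-image version should work, with `hint_channels: 6` in `control_stage_config`.

```python
import numpy as np


def make_two_image_hint(img_a, img_b):
    """Hypothetical helper: stack two RGB uint8 images of shape (H, W, 3)
    into one (H, W, 6) float32 hint with values in [0, 1]."""
    if img_a.shape != img_b.shape or img_a.shape[-1] != 3:
        raise ValueError("both hint images must be RGB and the same size")
    # Same [0, 1] normalization as the single-image source hint.
    a = img_a.astype(np.float32) / 255.0
    b = img_b.astype(np.float32) / 255.0
    # Channels-last, like the original hint; the training code is assumed to
    # rearrange (b, h, w, c) -> (b, c, h, w) before the ControlNet sees it.
    return np.concatenate([a, b], axis=-1)


if __name__ == "__main__":
    img_a = np.zeros((512, 512, 3), dtype=np.uint8)
    img_b = np.full((512, 512, 3), 255, dtype=np.uint8)
    hint = make_two_image_hint(img_a, img_b)
    print(hint.shape, hint.dtype)  # (512, 512, 6) float32
```

In the dataset, the returned array would simply replace `source` in `dict(jpg=target, txt=prompt, hint=source)`.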

I tried making the hint a 6-channel input by concatenating two images, and wrote this YAML config:

model:
  target: cldm.cldm.ControlLDM
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    control_key: "hint"
    image_size: 64
    channels: 7
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    only_mid_control: False

    control_stage_config:
      target: cldm.cldm.ControlNet
      params:
        image_size: 32 # unused
        in_channels: 7
        hint_channels: 6
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    unet_config:
      target: cldm.cldm.ControlledUnetModel
      params:
        image_size: 32 # unused
        in_channels: 7
        out_channels: 7
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

But I got this error:

Traceback (most recent call last):
  File "tool_add_control.py", line 49, in <module>
    model.load_state_dict(target_dict, strict=True)
  File "/home/anaconda3/envs/control/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ControlLDM:
        size mismatch for model.diffusion_model.input_blocks.0.0.weight: copying a param with shape torch.Size([320, 4, 3, 3]) from checkpoint, the shape in current model is torch.Size([320, 7, 3, 3]).
        size mismatch for model.diffusion_model.out.2.weight: copying a param with shape torch.Size([4, 320, 3, 3]) from checkpoint, the shape in current model is torch.Size([7, 320, 3, 3]).
        size mismatch for model.diffusion_model.out.2.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([7]).
        size mismatch for control_model.input_blocks.0.0.weight: copying a param with shape torch.Size([320, 4, 3, 3]) from checkpoint, the shape in current model is torch.Size([320, 7, 3, 3]).
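The size mismatches line up with the config: the Stable Diffusion checkpoint's first UNet convolution has weight shape [320, 4, 3, 3] because the UNet operates on 4-channel VAE latents (`embed_dim` and `z_channels` are 4 in `first_stage_config`), while `channels`, `in_channels` and `out_channels` were set to 7. In ControlNet the hint passes through its own input block sized by `hint_channels`, so, as far as I can tell, only `hint_channels` needs to change for a 6-channel hint and the other channel counts should stay at 4. The sketch below is a hypothetical checker (not part of the ControlNet repo) that flags these settings in a config dict, e.g. one parsed with `yaml.safe_load`; `make_cfg` builds just the channel-related subset of the YAML above for illustration.

```python
def make_cfg(channels, hint_channels, latent=4):
    """Hypothetical helper: channel-related subset of the YAML above as a dict."""
    return {"model": {"params": {
        "channels": channels,
        "control_stage_config": {"params": {"in_channels": channels,
                                            "hint_channels": hint_channels}},
        "unet_config": {"params": {"in_channels": channels,
                                   "out_channels": channels}},
        "first_stage_config": {"params": {"embed_dim": latent}},
    }}}


def channel_problems(cfg, hint_channels_in_data):
    """Return human-readable problems; an empty list means the settings look consistent."""
    params = cfg["model"]["params"]
    # Assumption: the diffusion UNet (and the ControlNet copy of it) must match
    # the VAE latent channel count, which is the autoencoder's embed_dim.
    latent = params["first_stage_config"]["params"]["embed_dim"]
    checks = {
        "model.params.channels": params["channels"],
        "control_stage_config.in_channels": params["control_stage_config"]["params"]["in_channels"],
        "unet_config.in_channels": params["unet_config"]["params"]["in_channels"],
        "unet_config.out_channels": params["unet_config"]["params"]["out_channels"],
    }
    problems = [f"{name} is {value}, expected {latent} (latent channels)"
                for name, value in checks.items() if value != latent]
    # The hint block is the only place the control image's channel count matters.
    hint = params["control_stage_config"]["params"]["hint_channels"]
    if hint != hint_channels_in_data:
        problems.append(f"hint_channels is {hint}, but the data provides {hint_channels_in_data}")
    return problems


if __name__ == "__main__":
    for problem in channel_problems(make_cfg(channels=7, hint_channels=6), 6):
        print(problem)
    print(channel_problems(make_cfg(channels=4, hint_channels=6), 6))
```

With the posted settings the checker reports the same four places the traceback points at; with `channels`, `in_channels` and `out_channels` back at 4 and `hint_channels: 6`, it reports nothing.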

Alternatively, I'd like to use an image instead of the text prompt as the condition. Do you know how to set that up?

jiqizaisikao commented 1 year ago

[Images: WechatIMG59, WechatIMG54, WechatIMG55]

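These are results from a model I trained myself. It works like Adobe's Generative Fill or WeChat's new <换影CP> mini-program: from one photo and one template, it swaps the head or the face shape in a single click. Training takes a fair amount of compute, though, and I probably don't have the resources to keep training and optimizing it, so I'm hoping to find people to collaborate with on research and development.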

geroldmeisinger commented 1 year ago

Also see https://github.com/lllyasviel/ControlNet/issues/271

hxy-123-coder commented 1 year ago

You need to generate a new ckpt file with tool_add_control.py.
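For context on why regenerating the checkpoint matters: as I read `tool_add_control.py`, it builds the model from a config, then fills every parameter either from the matching Stable Diffusion weight (`control_model.*` keys are copied from the corresponding `model.diffusion_model.*` keys) or, if no match exists, from the freshly initialized model, and finally calls `model.load_state_dict(target_dict, strict=True)`, which is the line in the traceback above. The sketch below is a simplified, shape-only version of that copying step (shapes stand in for tensors; `build_target_dict` and `size_mismatches` are hypothetical names). It shows that a 6-channel hint block is simply treated as newly added, while 7-channel UNet inputs collide with the 4-channel pretrained weights. In the repo's training guide the script is run as `python tool_add_control.py ./models/v1-5-pruned.ckpt ./models/control_sd15_ini.ckpt`; check which YAML its `config_path` points at, since the modified channel settings need to be in that file.

```python
def get_node_name(name, parent_name):
    # Return (True, suffix) if name starts with parent_name, else (False, "").
    if len(name) <= len(parent_name) or not name.startswith(parent_name):
        return False, ""
    return True, name[len(parent_name):]


def build_target_dict(scratch, pretrained):
    """scratch / pretrained map parameter names to shapes (tuples) instead of tensors."""
    target, newly_added = {}, []
    for k in scratch:
        is_control, rest = get_node_name(k, "control_")
        # control_model.X is initialized from model.diffusion_model.X
        copy_k = "model.diffusion_" + rest if is_control else k
        if copy_k in pretrained:
            target[k] = pretrained[copy_k]
        else:
            target[k] = scratch[k]  # keep the fresh initialization
            newly_added.append(k)
    return target, newly_added


def size_mismatches(scratch, target):
    # Roughly what strict load_state_dict reports: same key, different shape.
    return [k for k in scratch if scratch[k] != target[k]]


if __name__ == "__main__":
    pretrained = {"model.diffusion_model.input_blocks.0.0.weight": (320, 4, 3, 3)}
    scratch = {
        "model.diffusion_model.input_blocks.0.0.weight": (320, 7, 3, 3),
        "control_model.input_blocks.0.0.weight": (320, 7, 3, 3),
        "control_model.input_hint_block.0.weight": (16, 6, 3, 3),
    }
    target, new = build_target_dict(scratch, pretrained)
    print("size mismatches:", size_mismatches(scratch, target))
    print("newly added:", new)
```

The real script also unwraps a `state_dict` entry and clones tensors; those details are omitted here.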

yi0109 commented 1 year ago

I'll try your suggestions. Thanks @hxy-123-coder @geroldmeisinger @jiqizaisikao