cloneofsimo / lora

Using Low-rank adaptation to quickly fine-tune diffusion models.
https://arxiv.org/abs/2106.09685
Apache License 2.0

Inpainting not working on custom dataset but image generation works. #258

Status: Open (opened by soumik-kanad 10 months ago)

soumik-kanad commented 10 months ago

Can someone please help me figure out why inpainting is not working for me while basic image generation seems to be working?

I have a folder with 500 images of an identity sampled from a few videos.

I trained the basic LoRA model with the following flags, and it works:

```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
lora_pti \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --train_text_encoder \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --scale_lr \
  --learning_rate_unet=1e-4 \
  --learning_rate_text=1e-5 \
  --learning_rate_ti=5e-4 \
  --color_jitter \
  --lr_scheduler="linear" \
  --lr_warmup_steps=0 \
  --placeholder_tokens="<s1>|<s2>" \
  --use_template="object" \
  --save_steps=100 \
  --max_train_steps_ti=1000 \
  --max_train_steps_tuning=1000 \
  --perform_inversion=True \
  --clip_ti_decay \
  --weight_decay_ti=0.000 \
  --weight_decay_lora=0.001 \
  --continue_inversion \
  --continue_inversion_lr=1e-4 \
  --device="cuda:0" \
  --lora_rank=1
```

lora_scale=0.5, prompt="style of <s1><s2>"

[image]

But when I train an inpainting model on the same dataset with the default inpainting flags, it produces garbage. I made one small change: I added --use_template="object", so that the object text templates are used and --placeholder_token_at_data="<krk>|<s1><s2>" does not strip the custom tokens from the captions.
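
The reasoning behind --use_template="object" can be illustrated with a simplified, hypothetical re-implementation of how a token map built from --placeholder_token_at_data might be applied to training captions. The helper names, the template strings, and the caption-mode behaviour below are assumptions for illustration, not lora_diffusion's actual code: in template mode the learned tokens fill a template slot, while in caption mode they only appear where a caption already contains "<krk>".

```python
import random

# Hypothetical stand-ins for the "object" templates; lora_diffusion's real
# template list is longer and may be worded differently.
OBJECT_TEMPLATES = [
    "a photo of a {}",
    "a rendering of a {}",
    "a close-up photo of the {}",
]


def parse_token_map(placeholder_token_at_data):
    """Turn "<krk>|<s1><s2>" into {"<krk>": "<s1><s2>"} (assumed format)."""
    data_token, learned_tokens = placeholder_token_at_data.split("|")
    return {data_token: learned_tokens}


def make_caption(token_map, use_template, caption=None, rng=random):
    """Build one training caption under the assumptions stated above."""
    if use_template:
        # Template mode: the learned tokens fill the template slot directly.
        learned_tokens = next(iter(token_map.values()))
        return rng.choice(OBJECT_TEMPLATES).format(learned_tokens)
    # Caption mode: learned tokens appear only where the data token occurs.
    for data_token, learned_tokens in token_map.items():
        caption = caption.replace(data_token, learned_tokens)
    return caption


token_map = parse_token_map("<krk>|<s1><s2>")
print(make_caption(token_map, use_template=True, rng=random.Random(0)))
print(make_caption(token_map, use_template=False, caption="a photo of <krk>"))  # a photo of <s1><s2>
print(make_caption(token_map, use_template=False, caption="frame_0001"))  # frame_0001
```

In this sketch, a caption without "<krk>" (such as the hypothetical "frame_0001") never mentions <s1><s2>, which is the situation the template change is meant to avoid.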

```bash
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
lora_pti \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --train_text_encoder \
  --train_inpainting \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=2 \
  --gradient_checkpointing \
  --scale_lr \
  --learning_rate_unet=2e-4 \
  --learning_rate_text=1e-6 \
  --learning_rate_ti=5e-4 \
  --color_jitter \
  --lr_scheduler="linear" \
  --lr_warmup_steps=0 \
  --lr_scheduler_lora="constant" \
  --lr_warmup_steps_lora=100 \
  --placeholder_tokens="<s1>|<s2>" \
  --placeholder_token_at_data="<krk>|<s1><s2>" \
  --save_steps=100 \
  --max_train_steps_ti=3000 \
  --max_train_steps_tuning=3000 \
  --perform_inversion=True \
  --clip_ti_decay \
  --weight_decay_ti=0.000 \
  --weight_decay_lora=0.000 \
  --device="cuda:0" \
  --lora_rank=8 \
  --use_face_segmentation_condition \
  --lora_dropout_p=0.1 \
  --lora_scale=2.0 \
  --use_template="object" \
  --cached_latents=False
```
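
Because both commands pass --scale_lr, the learning rates actually used differ from the ones on the command line. The sketch below assumes lora_pti multiplies each base rate by gradient_accumulation_steps * train_batch_size (the usual convention in diffusers-style training scripts, single GPU); effective_lrs is a hypothetical helper, not part of lora_pti. It only puts the two runs side by side and does not show that any difference is the cause of the problem.

```python
def effective_lrs(lrs, gradient_accumulation_steps, train_batch_size):
    """Scale each base learning rate by the effective batch size.

    Assumes --scale_lr multiplies every base rate by
    gradient_accumulation_steps * train_batch_size (single GPU).
    """
    factor = gradient_accumulation_steps * train_batch_size
    return {name: lr * factor for name, lr in lrs.items()}


# Base rates and batch settings copied from the two commands above.
basic = effective_lrs(
    {"unet": 1e-4, "text": 1e-5, "ti": 5e-4},
    gradient_accumulation_steps=4,
    train_batch_size=1,
)
inpainting = effective_lrs(
    {"unet": 2e-4, "text": 1e-6, "ti": 5e-4},
    gradient_accumulation_steps=2,
    train_batch_size=1,
)

for name in basic:
    print(f"{name:5s} basic={basic[name]:.0e} inpainting={inpainting[name]:.0e}")
```

Under that assumption, the UNet rate is the same in both runs (4e-4), the text-encoder rate is 20 times lower in the inpainting run (2e-6 vs 4e-5), and the TI rate is halved (1e-3 vs 2e-3).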

Inputs: [image]
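
For context on --train_inpainting: the runwayml/stable-diffusion-inpainting UNet takes 9 input channels instead of 4. In diffusers' StableDiffusionInpaintPipeline, the noisy latents, the mask downsampled to latent resolution, and the VAE latents of the masked image are concatenated along the channel axis. The NumPy sketch below only reproduces those shapes and that channel order; random arrays stand in for the real VAE encodings, and the mask region is an arbitrary example.

```python
import numpy as np

# 512x512 images; the SD VAE downsamples by 8 and produces 4 latent channels.
batch, latent_ch, height, width = 1, 4, 512, 512
h, w = height // 8, width // 8

rng = np.random.default_rng(0)
# Placeholders for VAE outputs; real code would encode images with the VAE.
noisy_latents = rng.standard_normal((batch, latent_ch, h, w)).astype(np.float32)
masked_image_latents = rng.standard_normal((batch, latent_ch, h, w)).astype(np.float32)

# Pixel-space mask: 1 where the model repaints, 0 where the image is kept.
mask = np.zeros((batch, 1, height, width), dtype=np.float32)
mask[:, :, 128:384, 128:384] = 1.0

# Nearest-neighbour downsample to latent resolution (every 8th pixel).
mask_latent = mask[:, :, ::8, ::8]

# noisy latents (4) + mask (1) + masked-image latents (4) = 9 channels.
unet_input = np.concatenate([noisy_latents, mask_latent, masked_image_latents], axis=1)
print(unet_input.shape)  # (1, 9, 64, 64)
```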

Without LoRA patching, the output looks fine (prompt="photo of <s2>"):

[image]

But as soon as I patch the pipeline with this LoRA, it gives garbage outputs. With lora_scale=0.0, prompt="photo of <s1>":

[image]

lora_scale=0.0, prompt="photo of <s2>"

[image]

lora_scale=0.5, prompt="photo of <s1><s2>"

[image]

lora_scale=0.5, prompt="photo of <s1>"

[image]

lora_scale=0.5, prompt="photo of <s2>"

[image]

I also tried varying lora_scale (the results above go from 0.0 to 0.5) and using different prompts, but neither helped.
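
In LoRA (the paper linked at the top), a frozen weight W is adapted by adding a low-rank product B A. Assuming lora_diffusion's injected layers multiply that product by a scale and that tune_lora_scale only sets this scale, the NumPy sketch below shows what lora_scale does to a single layer. LoraLinear is a hypothetical, simplified stand-in (no dropout, no bias), not the library's module.

```python
import numpy as np


class LoraLinear:
    """Simplified LoRA layer: y = W x + scale * B (A x)."""

    def __init__(self, in_features, out_features, rank, rng):
        self.W = rng.standard_normal((out_features, in_features))  # frozen base weight
        self.A = rng.standard_normal((rank, in_features))  # "down" projection
        self.B = rng.standard_normal((out_features, rank))  # "up" projection
        self.scale = 1.0

    def __call__(self, x):
        return self.W @ x + self.scale * (self.B @ (self.A @ x))


rng = np.random.default_rng(0)
layer = LoraLinear(in_features=16, out_features=8, rank=8, rng=rng)
x = rng.standard_normal(16)
base = layer.W @ x  # output of the unpatched layer

for scale in (0.0, 0.5, 1.0):
    layer.scale = scale  # analogous to what tune_lora_scale is assumed to set
    print(scale, np.linalg.norm(layer(x) - base))  # distance from the base output
```

At scale 0.0 the distance is exactly zero, so under these assumptions lora_scale=0.0 reduces the injected UNet and text-encoder layers to their base weights; lora_scale does not touch the token embeddings loaded with patch_ti=True.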

soumik-kanad commented 10 months ago

This is what I use for inference:

```python
import torch
from PIL import Image
from IPython.display import display
from diffusers import StableDiffusionInpaintPipeline
from lora_diffusion import patch_pipe, tune_lora_scale

device = "cuda"
model_path = "runwayml/stable-diffusion-inpainting"

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
).to(device)
lora_scale = 0.5
prompt = "photo of <s1><s2>"
lora_model_path = "outputs/final_lora.safetensors"

# Inject the trained LoRA weights and learned token embeddings.
patch_pipe(
    pipe,
    lora_model_path,
    patch_text=True,
    patch_ti=True,
    patch_unet=True,
)

torch.manual_seed(0)
tune_lora_scale(pipe.unet, lora_scale)
tune_lora_scale(pipe.text_encoder, lora_scale)

image = Image.open("image2.jpg").convert("RGB").resize((512, 512))
mask_image = Image.open("mask2.png").convert("RGB").resize((512, 512))

# The safety checker kept flagging outputs as NSFW and returning black images,
# so replace it with a pass-through that reports no NSFW content per image.
def dummy(images, **kwargs):
    return images, [False] * len(images)

pipe.safety_checker = dummy
image = pipe(
    prompt=prompt,
    image=image,
    mask_image=mask_image,
    num_inference_steps=50,
    guidance_scale=7,
).images[0]

display(image)
```

joe-zxh commented 1 month ago

Similar result here. Have you solved the problem?

soumik-kanad commented 1 month ago

Nope. I'm not sure how to solve this.