sayakpaul / dreambooth-keras

Implementation of DreamBooth in KerasCV and TensorFlow.
https://keras.io/examples/generative/dreambooth/
Apache License 2.0

Experimentation #11

Closed · sayakpaul closed this issue 1 year ago

sayakpaul commented 1 year ago

- Without fine-tuning the text encoder (Sayak)
- With fine-tuning the text encoder (Chansung)

Keep all the other hyperparameters the same. Use mixed-precision training and WandB logging.
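
For reference, here is a minimal sketch of what the --mp and --log_wandb flags are assumed to enable inside train_dreambooth.py (the actual wiring in the script may differ; the project name is an assumption):

import tensorflow as tf
import wandb
from wandb.keras import WandbCallback

# Mixed precision: compute in float16 while keeping variables in float32.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

# WandB logging: start a run and attach the standard Keras callback,
# later passed to model.fit(..., callbacks=callbacks).
wandb.init(project="dreambooth-keras")  # project name is an assumption
callbacks = [WandbCallback(save_model=False)]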

sayakpaul commented 1 year ago

In case it helps, I created a short Python script like this to start the experiments:

import os

lrs = [1e-6, 3e-6, 5e-6]
max_train_steps = [800, 1000, 1200]

for lr in lrs:
    for steps in max_train_steps:
        print(f"Executing with lr: {lr} and max_train_steps: {steps}.")
        command = f"python train_dreambooth.py --mp --log_wandb --lr {lr} --max_train_steps {steps}"
        os.system(command)
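
If it's useful, the same sweep can be launched with subprocess.run() so that a failed run aborts the sweep instead of being silently ignored by os.system() (just a sketch, same flags as above):

import subprocess

lrs = [1e-6, 3e-6, 5e-6]
max_train_steps = [800, 1000, 1200]

for lr in lrs:
    for steps in max_train_steps:
        print(f"Executing with lr: {lr} and max_train_steps: {steps}.")
        # check=True raises if any run exits with a non-zero status.
        subprocess.run(
            ["python", "train_dreambooth.py", "--mp", "--log_wandb",
             "--lr", str(lr), "--max_train_steps", str(steps)],
            check=True,
        )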

sayakpaul commented 1 year ago

Here's my report: https://wandb.ai/sayakpaul/dreambooth-keras/reports/DreamthBooth-training-in-Keras-without-fine-tuning-the-text-encoder--VmlldzozMzc0ODI3

deep-diver commented 1 year ago

The results without the shuffling fix: https://wandb.ai/chansung18/dreambooth-keras?workspace=user-chansung18

The dashboard will be updated continuously. I am running the experiments with the fix again in the background :)

sayakpaul commented 1 year ago

> The dashboard will be updated continuously. I am running the experiments with the fix again in the background :)

I suggest we first merge #12 and then start the experiments from there for sanity.

sayakpaul commented 1 year ago

Now that we don't have any outstanding PRs, we can start from main and run the experiments.

deep-diver commented 1 year ago

I was training from that branch, so it should be fine, I think.

sayakpaul commented 1 year ago

Sure.

sayakpaul commented 1 year ago

@deep-diver FYI, I am running the code from the debug branch without text encoder fine-tuning. After those experiments have run, I will experiment with each of the fine-tuned checkpoints alongside num_steps and unconditional_guidance_scale to see how the results are affected.

sayakpaul commented 1 year ago

@deep-diver I am using the following script to generate images with various hyperparameters:

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")

import glob
import os

import keras_cv
import numpy as np
import PIL.Image  # import the submodule explicitly so PIL.Image is available
import wandb
from tqdm import tqdm

def download_unet_params(run_id, run_name) -> str:
    run = wandb.init(name=run_name)
    run_artifact_id = f"sayakpaul/dreambooth-keras/run_{run_id}_model:v0"
    artifact = run.use_artifact(run_artifact_id, type="model")
    artifact_dir = artifact.download()
    unet_params_path = glob.glob(f"{artifact_dir}/*.h5")[0]
    return unet_params_path

# Initialize the SD model.
img_height = img_width = 512
sd_model = keras_cv.models.StableDiffusion(
    img_width=img_width, img_height=img_height, jit_compile=True
)

# Download run data.
api = wandb.Api()
runs = api.runs("sayakpaul/dreambooth-keras")

# Initialize variables.
num_steps = [25, 50, 75, 100]
num_images_to_gen = 3
caption = "A photo of sks dog in a bucket"
unconditional_guidance_scales = [7.5, 15, 30]

# Generate example results.
for run in tqdm(runs):
    run_id = run.id
    run_name = run.name

    print(f"Generating images for {run_name}.")
    unet_params_path = download_unet_params(run_id, run_name)
    sd_model.diffusion_model.load_weights(unet_params_path)
    os.makedirs(run_name, exist_ok=True)

    for steps in num_steps:
        for scale in unconditional_guidance_scales:
            images = sd_model.text_to_image(
                caption,
                batch_size=num_images_to_gen,
                num_steps=steps,
                unconditional_guidance_scale=scale,
            )

            wandb.log(
                {
                    f"num_steps@{steps}-ugs@{scale}": [
                        wandb.Image(
                            PIL.Image.fromarray(image), caption=f"{i}: {caption}"
                        )
                        for i, image in enumerate(images)
                    ]
                }
            )

See example: https://wandb.ai/sayakpaul/uncategorized/runs/2lk44tjf
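
Side note: the script creates a per-run directory with os.makedirs but never writes to it. If local copies of the generated images are also wanted, something like this inside the inner loop would do it (a sketch; the file naming is arbitrary):

from PIL import Image

for i, image in enumerate(images):
    # Save each generated image into the per-run directory created above.
    Image.fromarray(image).save(
        os.path.join(run_name, f"steps_{steps}-ugs_{scale}-{i}.png")
    )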

sayakpaul commented 1 year ago

@deep-diver this one seems to be a good one in comparison to the others: https://wandb.ai/sayakpaul/experimentation_images/runs/5jm2txyc

sayakpaul commented 1 year ago

Closing as we have run a couple of experiments.