XavierXiao / Dreambooth-Stable-Diffusion

Implementation of Dreambooth (https://arxiv.org/abs/2208.12242) with Stable Diffusion
MIT License

Training multiple subjects #12

Open JoePenna opened 2 years ago

JoePenna commented 2 years ago

I've seen this mentioned in some other threads here, but figured I'd start a new issue.

This is what my trained model looked like before: [image]

And after training someone else: [image]

Every setting was different for the second person (class word, regularization images, etc.). Yet there's clearly some melding happening with the originally trained token. The second token looked mostly fine (though there were some of the colorful cfg_scale artifacts).

Does anyone have any info on how to train multiple tokens at once (which seems possible when looking deep into the code)?

Or any info on how to preserve the originally trained token while adding another?

binaryninja commented 2 years ago

When training the second person did you change the name of the hardcoded token "sks"?

https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/blob/main/ldm/data/personalized.py#L10
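For illustration, a minimal sketch of taking the identifier as a parameter instead of hardcoding it. It assumes the linked template has the shape `photo of a sks {}` (identifier, then class word); `build_caption` and the second identifier `zwx` are hypothetical, not part of the repo.

```python
def build_caption(identifier: str, class_word: str,
                  template: str = "photo of a {} {}") -> str:
    """Fill a two-slot caption template with an identifier token and a class word.

    Hypothetical helper: the repo bakes the identifier ("sks") into its
    template string rather than taking it as an argument.
    """
    return template.format(identifier, class_word)


# Two subjects, two distinct identifiers, same class word.
print(build_caption("sks", "man"))  # photo of a sks man
print(build_caption("zwx", "man"))  # photo of a zwx man
```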

ExponentialML commented 2 years ago

This happens even when changing the hardcoded token. There's a problem somewhere that's pushing all embeddings to the same space during training. I thought it was the global seed, but that isn't the case.

JoePenna commented 2 years ago

@binaryninja @ExponentialML Absolutely, I changed every single thing I could, including the training prompts.

ExponentialML commented 2 years ago

I haven't tried this as I'm a bit busy this week, but this could be a possible solution.

After training, prune the model checkpoints from 11GB to 2GB using this script. This can be done before or after training presumably: https://github.com/harubaru/waifu-diffusion/blob/main/scripts/prune.py
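For context, a PyTorch Lightning checkpoint stores more than the model weights (for example `optimizer_states`, `epoch` and `global_step` next to `state_dict`), which accounts for much of the extra size. The linked script's exact behavior is not reproduced here; this is a minimal sketch assuming a pruned checkpoint only needs the `state_dict` entry, optionally cast to a smaller dtype. `prune_checkpoint` and its `cast` parameter are hypothetical names.

```python
from typing import Any, Callable, Dict, Optional


def prune_checkpoint(ckpt: Dict[str, Any],
                     cast: Optional[Callable[[Any], Any]] = None) -> Dict[str, Any]:
    """Keep only the model weights from a Lightning-style checkpoint dict.

    Everything outside "state_dict" (optimizer states, epoch counters, ...)
    is dropped. If `cast` is given, it is applied to every weight.
    """
    weights = ckpt["state_dict"]
    if cast is not None:
        weights = {name: cast(value) for name, value in weights.items()}
    return {"state_dict": weights}
```

With real files this would look roughly like `torch.save(prune_checkpoint(torch.load(path, map_location="cpu"), cast=lambda t: t.half() if t.is_floating_point() else t), out_path)`; the guard leaves integer buffers untouched. Whether half precision is acceptable for further training is a separate question.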

Then, merge all of the model checkpoints as stated in this repository: https://github.com/Jack000/glid-3-xl-stable#trainingfine-tuning
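The glid-3-xl-stable README is the reference for its own merge step. As a hedged illustration of one common way to combine checkpoints (not necessarily what that repository does), here is a plain weighted average of two state dicts with identical keys; `merge_state_dicts` and `alpha` are hypothetical names.

```python
from typing import Any, Dict


def merge_state_dicts(a: Dict[str, Any], b: Dict[str, Any],
                      alpha: float = 0.5) -> Dict[str, Any]:
    """Linearly interpolate two state dicts: alpha * a + (1 - alpha) * b.

    Works for any values supporting scalar multiplication and addition
    (torch tensors, NumPy arrays, plain floats). Keys must match exactly.
    """
    if a.keys() != b.keys():
        mismatched = a.keys() ^ b.keys()
        raise KeyError(f"state dicts differ in keys: {sorted(mismatched)}")
    return {name: alpha * a[name] + (1 - alpha) * b[name] for name in a}
```

Interpolating keeps every tensor's shape unchanged, so the merged dict still loads into the same model architecture, which concatenating tensors would not.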

I don't know what effect it would have, but this might be a step in the right direction if there aren't any viable fixes to the problem.

EDIT:

Tried and doesn't work. I've tried concatenating the tensors, but no dice.

ExponentialML commented 2 years ago

Okay, so I came up with a solution to train multiple subjects into one model. The idea is to assign a placeholder token to a specific set of images, so that each token is only trained against its own image set. With this method you no longer need to change the hardcoded token, since the token lives in the filename.

For ease of use, you prefix each filename with a placeholder token, then parse it during training. Ideally you would create a dict or a custom dataloader, but this solution works.

For example: dog_1.png, dog_2.png, cat_1.png, with one prefix for each different thing you want to train, all in your training folder.
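The filename convention can be sketched as follows: the token is everything before the first underscore, matching `pathname.split("_")[0]` in the script below, and grouping files per token is a quick way to check how many images each subject has. `token_from_filename` and `group_by_token` are hypothetical helpers, not part of the script.

```python
from collections import defaultdict
from pathlib import Path
from typing import Dict, Iterable, List


def token_from_filename(path: str) -> str:
    """Return the placeholder token encoded before the first "_" in the filename."""
    return Path(path).name.split("_")[0]


def group_by_token(paths: Iterable[str]) -> Dict[str, List[str]]:
    """Group image paths by filename token, e.g. dog_1.png -> "dog"."""
    groups: Dict[str, List[str]] = defaultdict(list)
    for path in paths:
        groups[token_from_filename(path)].append(path)
    return dict(groups)


files = ["train/dog_1.png", "train/dog_2.png", "train/cat_1.png"]
print(group_by_token(files))
# {'dog': ['train/dog_1.png', 'train/dog_2.png'], 'cat': ['train/cat_1.png']}
```

A file without an underscore, such as `dog.png`, yields the token `dog.png` (extension included), so every training image needs the prefix.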

This all happens in the personalized.py script: the token before the first _ in the filename is parsed, and each image is trained with its filename token against the class you set in the training parameters (e.g. "toy").
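Tracing the script below, a training caption is the template `'a {} {}'` filled with the filename token and then the class word (which the dataset receives as `placeholder_token`, optionally prefixed by `coarse_class_text`); regularization images get `'a {}'` with the class alone. A stand-alone sketch of that logic, where `make_caption` is a hypothetical name:

```python
from pathlib import Path
from typing import Optional


def make_caption(image_path: str, class_word: str, reg: bool = False,
                 coarse_class_text: Optional[str] = None) -> str:
    """Mirror the caption logic of PersonalizedBase.__getitem__ in the script below."""
    class_string = class_word
    if coarse_class_text:
        class_string = f"{coarse_class_text} {class_string}"
    if reg:
        # Regularization images: class only, e.g. "a toy".
        return "a {}".format(class_string)
    # Training images: filename token, then class, e.g. "a dog toy".
    token = Path(image_path).name.split("_")[0]
    return "a {} {}".format(token, class_string)


print(make_caption("train/dog_1.png", "toy"))         # a dog toy
print(make_caption("reg/0001.png", "toy", reg=True))  # a toy
```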

This is set up for just the placeholder token, but the same method can be applied to classes, so you could have multiple classes ("animal_", "car_", "person_") and tokens in one go. I haven't tested this yet, but I see no reason why it shouldn't work if you wish to implement it.
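Extending the filename idea to classes, as suggested above (and, as noted, untested), could look like this sketch. It assumes a hypothetical `<class>_<token>_<n>.png` naming scheme such as `person_jane_1.png`; neither the scheme nor `parse_class_and_token` comes from the script.

```python
from pathlib import Path
from typing import Tuple


def parse_class_and_token(image_path: str) -> Tuple[str, str]:
    """Split a '<class>_<token>_<n>.ext' filename into (class_word, token).

    Hypothetical naming scheme: person_jane_1.png -> ("person", "jane").
    """
    parts = Path(image_path).stem.split("_")
    if len(parts) < 3:
        raise ValueError(f"expected '<class>_<token>_<n>' filename, got {image_path!r}")
    return parts[0], parts[1]


def caption_for(image_path: str) -> str:
    """Build an 'a <token> <class>' caption from the parsed filename."""
    class_word, token = parse_class_and_token(image_path)
    return f"a {token} {class_word}"


print(caption_for("train/person_jane_1.png"))   # a jane person
print(caption_for("train/animal_buddy_2.png"))  # a buddy animal
```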

You also need to increase the number of epochs according to how many images are in your dataset. In my tests, I used just three images per token. This also increases training time: if a good fine-tune of one subject takes 15 minutes and 5 images, a rough estimate is to multiply that by the number of image sets you have.
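Two bits of arithmetic sit behind this paragraph: the script sets the training length to `num_images * repeats`, so each epoch already grows with the dataset, and the time estimate scales linearly with the number of image sets. A small sketch with hypothetical function names; the 15-minute baseline is just the example from the text, not a measured figure.

```python
def epoch_length(num_images: int, repeats: int = 100) -> int:
    """Training-set length as computed in PersonalizedBase: images * repeats."""
    return num_images * repeats


def estimated_minutes(minutes_per_set: float, num_sets: int) -> float:
    """Linear rule of thumb: time for one subject times the number of subjects."""
    return minutes_per_set * num_sets


# Three subjects with three images each.
print(epoch_length(3 * 3))       # 900
print(estimated_minutes(15, 3))  # 45
```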

Here's the script that gets this working. Replace your personalized.py with it, follow the instructions above, and it should work. I've also changed the templates to "a {}" instead of "photo of a {}", as that seems to give me good results, but feel free to change it back.

import os
import numpy as np
import PIL
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

import random

training_templates_smallest = [
    'a {} {}',
]

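reg_templates_smallest = [
    'a {}',
]

# Characters available as per-image tokens when per_image_tokens=True;
# the check in __init__ needs this list to exist.
per_img_token_list = [
    'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ',
    'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
]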

class PersonalizedBase(Dataset):
    def __init__(self,
                 data_root,
                 size=None,
                 repeats=100,
                 interpolation="bicubic",
                 flip_p=0.5,
                 set="train",
                 placeholder_token="dog",
                 per_image_tokens=False,
                 center_crop=False,
                 mixing_prob=0.25,
                 coarse_class_text=None,
                 reg = False
                 ):

        self.data_root = data_root

        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]

        # self._length = len(self.image_paths)
        self.num_images = len(self.image_paths)
        self._length = self.num_images 

        self.placeholder_token = placeholder_token

        self.per_image_tokens = per_image_tokens
        self.center_crop = center_crop
        self.mixing_prob = mixing_prob

        self.coarse_class_text = coarse_class_text

        if per_image_tokens:
            assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."

        if set == "train":
            self._length = self.num_images * repeats

        self.size = size
        # PIL.Image.LINEAR was an alias of BILINEAR and was removed in Pillow 10
        self.interpolation = {"linear": PIL.Image.BILINEAR,
                              "bilinear": PIL.Image.BILINEAR,
                              "bicubic": PIL.Image.BICUBIC,
                              "lanczos": PIL.Image.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        self.reg = reg

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
        image = Image.open(self.image_paths[i % self.num_images])

        if not image.mode == "RGB":
            image = image.convert("RGB")

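        # Class word from the training parameters (e.g. "toy"), passed in as placeholder_token
        placeholder_string = self.placeholder_token
        # Per-subject token: the filename prefix before the first "_" (dog_1.png -> "dog")
        pathname = Path(self.image_paths[i % self.num_images]).name
        placeholder_token = pathname.split("_")[0]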

        if self.coarse_class_text:
            placeholder_string = f"{self.coarse_class_text} {placeholder_string}"

        if not self.reg:
            # Dreambooth only needs one template; more can be tried, but it isn't recommended
            text = random.choice(training_templates_smallest).format(placeholder_token, placeholder_string)
        else:
            text = random.choice(reg_templates_smallest).format(placeholder_string)

        example["caption"] = text

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            h, w = img.shape[:2]
            img = img[(h - crop) // 2:(h + crop) // 2,
                      (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        return example
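
The preprocessing at the end of `__getitem__` does two things: an optional center crop to the largest square, and a rescale of 8-bit pixels from [0, 255] to [-1, 1] via `x / 127.5 - 1`. The index arithmetic can be checked in isolation with this small sketch; `center_crop_box` and `normalize_pixel` are hypothetical helpers, not part of the script.

```python
from typing import Tuple


def center_crop_box(h: int, w: int) -> Tuple[int, int, int, int]:
    """Return (top, bottom, left, right) slice bounds for the central square crop."""
    crop = min(h, w)
    return (h - crop) // 2, (h + crop) // 2, (w - crop) // 2, (w + crop) // 2


def normalize_pixel(value: int) -> float:
    """Map an 8-bit value to [-1, 1], as in `image / 127.5 - 1.0`."""
    return value / 127.5 - 1.0


print(center_crop_box(512, 768))                 # (0, 512, 128, 640)
print(normalize_pixel(0), normalize_pixel(255))  # -1.0 1.0
```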

eadnams22 commented 2 years ago

@ExponentialML How do you handle multiple classes so that each one is regularized with its own images? Say I have "Woman", "Man", "Cat", and "Dog" pools of regularization images, and new training images of "Jane", "John", "Buddy", and "Ollie" to train into each of those classes.

How do I match them to their respective classes in the same model training session?