cloneofsimo / lora

Using Low-rank adaptation to quickly fine-tune diffusion models.
https://arxiv.org/abs/2106.09685
Apache License 2.0

Fine-tune a Stable Diffusion model on a customized text-image dataset #27

Closed JunMa11 closed 1 year ago

JunMa11 commented 1 year ago

Hi @cloneofsimo,

Thanks again for sharing the awesome work.

Would it be possible for you to share an example of fine-tuning the model on a custom dataset?

For example, the Pokemon dataset: https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions

Alternatively, do you have any suggestions on how to modify this file for a custom text-image dataset, where each image has its own text?

https://github.com/cloneofsimo/lora/blob/e558da067efc88811d26ec9ee47bcd20ecdd5281/run_lora_db_w_text.sh#L10
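One way to prepare such a dataset is to flatten it into a folder of images plus a CSV of captions, which is the layout used later in this thread. The sketch below is only an illustration under assumptions: `export_image_caption_pairs` is a hypothetical helper, the zero-padded file names are arbitrary, and each record is assumed to be a dict with an `image` object offering a PIL-style `save(path)` method and a `text` caption.

```python
import csv
from pathlib import Path


def export_image_caption_pairs(records, image_dir, csv_path):
    """Save each record's image into image_dir and list it in a CSV.

    records: iterable of dicts with an "image" (any object with a
    PIL-style save(path) method) and a "text" caption (assumed keys).
    Returns the number of image/caption pairs written.
    """
    image_dir = Path(image_dir)
    image_dir.mkdir(parents=True, exist_ok=True)

    count = 0
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        # Column names follow the CSV layout described later in the thread.
        writer.writerow(["img_id", "description"])
        for index, record in enumerate(records):
            file_name = f"{index:05d}.png"  # arbitrary naming scheme
            record["image"].save(image_dir / file_name)
            writer.writerow([file_name, record["text"].strip()])
            count += 1
    return count
```

With the Hugging Face `datasets` library installed, the records could plausibly come from `load_dataset("lambdalabs/pokemon-blip-captions", split="train")`, assuming the dataset is still hosted and its rows expose `image` and `text` fields.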

cloneofsimo commented 1 year ago

Well, one thing you can do is return a custom prompt for each sample in the DreamBooth dataloader, without prior preservation.

brian6091 commented 1 year ago

If either of you is interested in making a LoRA comparison, I'd be happy to help run it.

JunMa11 commented 1 year ago

Hi @cloneofsimo,

Thank you very much for your guidance.

When you get the chance, would it be possible for you to check the following dataset class?

I use a CSV file to store each image's file name together with its description.

img_id    description
a.png     'description of a'
b.png     'description of b'
...       ...

from pathlib import Path

import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class DreamBoothTextImgDataset(Dataset):
    """
    A dataset that pairs each image with its own text description (read from a CSV file) for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt_csv,
        tokenizer,
        size=512,
        center_crop=False,
        color_jitter=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        # Read the CSV once here instead of on every __getitem__ call.
        self.instance_prompt_csv = instance_prompt_csv
        self.instance_csv = pd.read_csv(instance_prompt_csv)
        # The CSV rows, not the directory listing, define the samples, so
        # stray files in the image folder cannot cause an out-of-range index.
        self.num_instance_images = len(self.instance_csv)
        self._length = self.num_instance_images

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(
                    size, interpolation=transforms.InterpolationMode.BILINEAR
                ),
                transforms.CenterCrop(size)
                if center_crop
                else transforms.RandomCrop(size),
                transforms.ColorJitter(0.2, 0.1)
                if color_jitter
                else transforms.Lambda(lambda x: x),
                transforms.ToTensor(),
                # Map pixel values from [0, 1] to [-1, 1].
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        # Look up the row by position so a non-default index cannot break it.
        row = self.instance_csv.iloc[index]
        instance_image = Image.open(self.instance_data_root / row["img_id"])
        if instance_image.mode != "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        # Token ids are left unpadded here; batches larger than one need a
        # collate function that pads them to a common length.
        example["instance_prompt_ids"] = self.tokenizer(
            row["description"],
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        return example
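
Because the dataset looks up each image through the `img_id` stored in the CSV, a row that names a missing file only fails once training reaches it. A small pre-flight check can catch such mismatches before launching a long run. This is a rough sketch using only the standard library; `check_pairs` is a hypothetical helper, and it assumes a comma-separated file with an `img_id` header.

```python
import csv
from pathlib import Path


def check_pairs(image_dir, csv_path):
    """Compare the img_id column of the CSV with the files in image_dir.

    Returns (missing, unlisted): sorted file names that the CSV references
    but the folder lacks, and files in the folder the CSV never mentions.
    """
    with open(csv_path, newline="", encoding="utf-8") as f:
        listed = {row["img_id"] for row in csv.DictReader(f)}
    on_disk = {p.name for p in Path(image_dir).iterdir() if p.is_file()}
    return sorted(listed - on_disk), sorted(on_disk - listed)
```

If both returned lists are empty, the folder and the CSV describe the same set of images.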

The training script is:

accelerate launch train_lora_text_img.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --instance_prompt_csv=$TEXTIMG_CSV \
  --output_dir=$OUTPUT_DIR \
  --train_text_encoder \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=30000

korakoe commented 1 year ago

@JunMa11 How'd you go about implementing this?

cloneofsimo commented 1 year ago

Sorry for not checking in sooner. @JunMa11, I've implemented this as follows: https://github.com/cloneofsimo/lora/blob/27145c3bd02f1240ab10de6a8c00fc37c6fcadc2/lora_diffusion/dataset.py#L187

aseemkhanduja commented 1 year ago

@cloneofsimo Can we use a model fine-tuned on top of SD1.5 for inpainting, or does an inpainting model have to be fine-tuned separately?

nattaran commented 1 year ago

Sorry for not checking in sooner. @JunMa11, I've implemented this as follows:

https://github.com/cloneofsimo/lora/blob/27145c3bd02f1240ab10de6a8c00fc37c6fcadc2/lora_diffusion/dataset.py#L187

@cloneofsimo, could you please explain what this class does?