JunMa11 closed this issue 1 year ago.
Well, one thing you can do is have the DreamBooth dataloader return a custom prompt for each example, and train without prior preservation.
If either of you is interested in making a LoRA comparison, I'd be happy to help run it.
Hi @cloneofsimo,
Thank you very much for your guidance.
When you get the chance, could you check the following dataset class?
I use a CSV file to store the pairs of images and their corresponding descriptions:
img_id | description |
---|---|
a.png | 'description of a' |
b.png | 'description of b' |
... | ... |
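Before wiring the CSV into a `Dataset`, it can help to check that every listed image actually exists. A minimal sketch, assuming the CSV has a header row `img_id,description` as in the table above; the function name `load_caption_pairs` is hypothetical. Standard CSV quoting (used by both Python's `csv` module and `pandas.read_csv` by default) uses double quotes, so a caption containing a comma should be written as `"a cat, sitting"`; single quotes such as `'description of a'` are kept as literal characters in the caption.

```python
import csv
from pathlib import Path


def load_caption_pairs(csv_path, image_root):
    """Return [(img_id, description), ...] from a CSV and check the images exist.

    Hypothetical helper; assumes a header row with `img_id` and `description` columns.
    """
    image_root = Path(image_root)
    pairs, missing = [], []
    with open(csv_path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        required = {"img_id", "description"}
        # fieldnames is None for an empty file, so fall back to an empty list.
        if not required.issubset(reader.fieldnames or []):
            raise ValueError(f"expected columns {sorted(required)}, got {reader.fieldnames}")
        for row in reader:
            img_id = row["img_id"].strip()
            if (image_root / img_id).is_file():
                pairs.append((img_id, row["description"].strip()))
            else:
                missing.append(img_id)
    if missing:
        # Fail early rather than crashing halfway through training.
        raise FileNotFoundError(
            f"{len(missing)} image(s) listed in {csv_path} not found under {image_root}: {missing[:3]}"
        )
    return pairs
```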
from pathlib import Path

import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class DreamBoothTextImgDataset(Dataset):
    """
    A dataset that pairs each image with its own text description (read from a CSV file)
    for fine-tuning the model. It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt_csv,
        tokenizer,
        size=512,
        center_crop=False,
        color_jitter=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        # Read the CSV once here instead of on every __getitem__ call, and take the
        # dataset length from its rows rather than from the folder listing, so files
        # in the folder that are not listed in the CSV do not shift the indices.
        self.instance_prompt_csv = instance_prompt_csv
        self.instance_csv = pd.read_csv(instance_prompt_csv)
        self._length = len(self.instance_csv)

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ColorJitter(0.2, 0.1) if color_jitter else transforms.Lambda(lambda x: x),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        row = self.instance_csv.iloc[index]  # positional lookup of the index-th CSV row
        instance_image = Image.open(self.instance_data_root / row["img_id"])
        if instance_image.mode != "RGB":
            instance_image = instance_image.convert("RGB")
        instance_prompt = str(row["description"])

        example["instance_images"] = self.image_transforms(instance_image)
        # Ids are not padded here, so their lengths vary per prompt; they must be
        # padded in the DataLoader's collate_fn when the batch size is greater than 1.
        example["instance_prompt_ids"] = self.tokenizer(
            instance_prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids
        return example
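Because the tokenizer is called with `padding="do_not_pad"`, prompts of different lengths produce `instance_prompt_ids` lists of different lengths, which cannot be stacked into one tensor when `--train_batch_size` is greater than 1. A minimal plain-Python sketch of the padding step a collate function would need; the helper name `pad_token_ids` is hypothetical, and a Hugging Face tokenizer's own `pad` method or `torch.nn.utils.rnn.pad_sequence` can do the same job.

```python
from typing import List, Optional, Tuple


def pad_token_ids(
    batch_ids: List[List[int]],
    pad_id: int,
    max_length: Optional[int] = None,
) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad token-id lists to one common length so they can be stacked.

    Pads to the longest sequence in the batch, or to `max_length` if given
    (for example the text encoder's fixed context length).
    Returns (padded_ids, attention_mask).
    """
    longest = max(len(ids) for ids in batch_ids)
    target = longest if max_length is None else max_length
    if longest > target:
        raise ValueError(f"sequence of length {longest} exceeds max_length={target}")
    padded, mask = [], []
    for ids in batch_ids:
        n_pad = target - len(ids)
        padded.append(list(ids) + [pad_id] * n_pad)
        mask.append([1] * len(ids) + [0] * n_pad)  # 1 = real token, 0 = padding
    return padded, mask
```

In a `collate_fn`, `pad_id` would typically be `tokenizer.pad_token_id`, the padded ids can be converted with `torch.tensor(padded)`, and the images stacked with `torch.stack([ex["instance_images"] for ex in examples])`. Padding to a fixed `max_length=tokenizer.model_max_length` keeps the text-encoder input shape constant across batches.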
The training command is:
accelerate launch train_lora_text_img.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --instance_prompt_csv=$TEXTIMG_CSV \
  --output_dir=$OUTPUT_DIR \
  --train_text_encoder \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=30000
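`train_lora_text_img.py` is the poster's own script, so its argument parser is not shown in the thread. As a hypothetical sketch, the flags used in the command could be declared with `argparse` as below; the defaults simply mirror the values passed in the command, and `--instance_prompt_csv` presumably takes the place of the usual `--instance_prompt` of a DreamBooth-style script.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical parser for the flags used in the launch command."""
    parser = argparse.ArgumentParser(
        description="LoRA fine-tuning on per-image captions (sketch)"
    )
    parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
    parser.add_argument("--instance_data_dir", type=str, required=True,
                        help="Folder containing the training images.")
    parser.add_argument("--instance_prompt_csv", type=str, required=True,
                        help="CSV with img_id and description columns.")
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--train_text_encoder", action="store_true")
    # Defaults below mirror the values passed in the command.
    parser.add_argument("--resolution", type=int, default=512)
    parser.add_argument("--train_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--lr_scheduler", type=str, default="constant")
    parser.add_argument("--lr_warmup_steps", type=int, default=0)
    parser.add_argument("--max_train_steps", type=int, default=30000)
    return parser.parse_args(argv)
```

`argparse` accepts the `--flag=value` form used in the command, so the arguments can be passed through unchanged.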
@JunMa11 How'd you go about implementing this?
Sorry for not checking in sooner. @JunMa11, I've implemented it like this: https://github.com/cloneofsimo/lora/blob/27145c3bd02f1240ab10de6a8c00fc37c6fcadc2/lora_diffusion/dataset.py#L187
@cloneofsimo Can we use a model fine-tuned on SD 1.5 for inpainting, or does an inpainting model have to be fine-tuned separately?
@cloneofsimo, could you please explain what this class does?
Hi @cloneofsimo,
Thanks again for sharing this awesome work.
Would it be possible for you to share an example of fine-tuning the model on a custom dataset?
For example, the Pokémon dataset: https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions
Or do you have any suggestions on how to modify this file for a custom text-image dataset in which each image has its own caption?
https://github.com/cloneofsimo/lora/blob/e558da067efc88811d26ec9ee47bcd20ecdd5281/run_lora_db_w_text.sh#L10
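For a Hub dataset such as the Pokémon one, one option is to export it into the image-folder-plus-CSV layout used earlier in this thread. A minimal sketch, assuming each row is a mapping with an `image` object that has a `.save(path)` method (as PIL images do) and a `text` caption; the function name `export_image_caption_pairs` and the zero-padded file names are hypothetical choices.

```python
import csv
from pathlib import Path


def export_image_caption_pairs(rows, image_dir, csv_path):
    """Save each row's image into image_dir and write an img_id/description CSV.

    Hypothetical helper; assumes every row has an "image" object exposing
    .save(path) (e.g. a PIL image) and a "text" caption string.
    """
    image_dir = Path(image_dir)
    csv_path = Path(csv_path)
    image_dir.mkdir(parents=True, exist_ok=True)
    csv_path.parent.mkdir(parents=True, exist_ok=True)

    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["img_id", "description"])  # same columns as the dataset CSV
        for i, row in enumerate(rows):
            img_id = f"{i:05d}.png"  # hypothetical naming scheme
            row["image"].save(str(image_dir / img_id))
            # csv.writer quotes captions that contain commas or quote characters.
            writer.writerow([img_id, row["text"]])
    return csv_path
```

With the `datasets` library, rows of that shape would come from something like `load_dataset("lambdalabs/pokemon-blip-captions", split="train")`, assuming the dataset still exposes `image` and `text` columns. Writing the CSV outside the image folder keeps it from being picked up as an image by code that lists the folder's contents.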