bghira / SimpleTuner

A general fine-tuning kit geared toward diffusion models.
GNU Affero General Public License v3.0

Multi-caption strategy from parquet might not work #1092

Open AmericanPresidentJimmyCarter opened 3 days ago

AmericanPresidentJimmyCarter commented 3 days ago

The caption is checked for existence with `not caption`, but when the caption is a list (an array once loaded from parquet), that check crashes:

        if not caption and fallback_caption_column:
    ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
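A minimal reproduction of the error, assuming the multi-caption cell reaches this code as a NumPy array of strings (which is how pandas typically materializes a parquet list column). The caption values are made up for illustration:

```python
import numpy as np

# A multi-caption cell as pandas typically returns it from a parquet list column.
caption = np.array(["a photo of a cat", "a cat sitting on a sofa"])

# `not caption` calls bool() on the array, which NumPy rejects for more than one element.
try:
    if not caption:
        pass
    raised = False
except ValueError:
    raised = True

# Converting to a plain list first makes the emptiness check well-defined.
caption_list = [str(item) for item in caption if item is not None]
is_empty = len(caption_list) == 0

print(raised, is_empty)  # True False
```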

You should normalize array captions into plain lists before checking them, for example:

    import os

    import numpy
    import pandas as pd


    def _extract_captions_to_fast_list(self):
        """
        Pull the captions from the parquet table into a dict with the format {filename: caption}.

        This helps because parquet's columnar format sucks for searching.

        Returns:
            dict: A dictionary mapping each filename to a caption string or a list of captions.
        """

        def _normalize_caption(value):
            # Parquet list columns are loaded as numpy arrays (or Series). Turn them into
            # plain lists so the emptiness checks below never call bool() on an array.
            if isinstance(value, (numpy.ndarray, pd.Series)):
                value = list(value)
            if isinstance(value, bytes):
                value = value.decode("utf-8")
            if isinstance(value, (list, tuple)):
                # Decode, stringify and strip each entry, dropping None and blank ones.
                items = [
                    (item.decode("utf-8") if isinstance(item, bytes) else str(item)).strip()
                    for item in value
                    if item is not None
                ]
                return [item for item in items if item]
            if isinstance(value, str):
                return value.strip()
            return value

        def _is_empty(value):
            # Explicit comparisons instead of `not value`, which is ambiguous for arrays.
            return value is None or value == "" or value == []

        if self.parquet_database is None:
            raise ValueError("Parquet database is not loaded.")
        filename_column = self.parquet_config.get("filename_column")
        caption_column = self.parquet_config.get("caption_column")
        fallback_caption_column = self.parquet_config.get("fallback_caption_column")
        identifier_includes_extension = self.parquet_config.get(
            "identifier_includes_extension", False
        )
        captions = {}
        for index, row in self.parquet_database.iterrows():
            if filename_column in row:
                filename = str(row[filename_column])
            else:
                filename = str(index)
            if not identifier_includes_extension:
                filename = os.path.splitext(filename)[0]

            if isinstance(caption_column, list):
                # Several caption columns: collect one caption per column.
                caption = [row[c] for c in caption_column] if caption_column else None
            else:
                caption = row.get(caption_column)
            caption = _normalize_caption(caption)

            if _is_empty(caption) and fallback_caption_column:
                caption = _normalize_caption(row.get(fallback_caption_column, None))
            if _is_empty(caption):
                raise ValueError(
                    f"Could not locate caption for image {filename} in sampler_backend {self.id} with filename column {filename_column}, caption column {caption_column}, and a parquet database with {len(self.parquet_database)} entries."
                )
            captions[filename] = caption
        return captions
AmericanPresidentJimmyCarter commented 3 days ago

It also crashes later on, at the empty-string check:

    # Check for empty strings
    if (df[caption_column] == "").sum() > 0 and not fallback_caption_column:
        raise ValueError(
            f"Parquet file {parquet_path} contains empty strings in the '{caption_column}' column."
        )
    if (df[filename_column] == "").sum() > 0:
        raise ValueError(
            f"Parquet file {parquet_path} contains empty strings in the '{filename_column}' column."
        )
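When `caption_column` is a list of column names, `df[caption_column]` is a DataFrame, so `.sum()` returns a Series and `Series > 0 and ...` raises the same ambiguity error. A list-aware check could look like the sketch below. `has_empty_captions` is a hypothetical helper, not an existing SimpleTuner function, and it assumes list-valued cells arrive as Python lists or NumPy arrays:

```python
import numpy as np
import pandas as pd


def has_empty_captions(df: pd.DataFrame, caption_column) -> bool:
    """Hypothetical helper: True if any caption cell is empty."""
    # Accept either a single column name or a list of column names.
    columns = caption_column if isinstance(caption_column, list) else [caption_column]

    def is_empty(value) -> bool:
        if isinstance(value, (list, tuple, np.ndarray)):
            # A list-valued cell is empty when it holds no non-empty entries.
            return not any(item is not None and str(item) != "" for item in value)
        # Scalar cells: only the empty string counts, matching the original check.
        return isinstance(value, str) and value == ""

    # Check each cell individually, so bool() is never taken of a whole Series.
    return any(is_empty(value) for column in columns for value in df[column])
```

At the call site it would replace the first condition, e.g. `if has_empty_captions(df, caption_column) and not fallback_caption_column: raise ValueError(...)`.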
AmericanPresidentJimmyCarter commented 2 days ago

The same handling is also needed in prompts.py:

        if type(image_caption) == bytes:
            image_caption = image_caption.decode("utf-8")
        if type(image_caption) == str:
            image_caption = image_caption.strip()
+       if type(image_caption) in (list, tuple, numpy.ndarray, pd.Series):
+           image_caption = [str(item).strip() for item in image_caption if item is not None]
        if prepend_instance_prompt:
            if type(image_caption) == list:
                image_caption = [instance_prompt + " " + x for x in image_caption]
            else:
                image_caption = instance_prompt + " " + image_caption
        return image_caption
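To try the prompts.py change in isolation, here is a standalone sketch that mirrors the diff. The function name `normalize_image_caption` and its parameters are hypothetical; the surrounding SimpleTuner code is not shown in this thread:

```python
import numpy as np
import pandas as pd


def normalize_image_caption(image_caption, prepend_instance_prompt=False, instance_prompt=""):
    """Hypothetical standalone version of the prompts.py caption handling."""
    if isinstance(image_caption, bytes):
        image_caption = image_caption.decode("utf-8")
    if isinstance(image_caption, str):
        image_caption = image_caption.strip()
    elif isinstance(image_caption, (list, tuple, np.ndarray, pd.Series)):
        # Array-like captions become a plain list of stripped strings.
        image_caption = [str(item).strip() for item in image_caption if item is not None]
    if prepend_instance_prompt:
        if isinstance(image_caption, list):
            image_caption = [instance_prompt + " " + x for x in image_caption]
        else:
            image_caption = instance_prompt + " " + image_caption
    return image_caption
```

This sketch uses `isinstance` instead of `type(...) == ...`: `numpy.str_` is a subclass of `str`, so a scalar caption of that type is stripped here, whereas `type(image_caption) == str` would skip it.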