kohya-ss / sd-scripts


Collecting image sizes is very slow. #1493

Open markrmiller opened 3 weeks ago

markrmiller commented 3 weeks ago

When I train with 4000 images, it takes forever to start training because just collecting the image sizes from the npz filenames takes 20 minutes from the drive where I have the data.

This code takes... I didn't measure it exactly, but maybe 10 seconds:

import os
import re
import logging
from typing import List, Tuple, Optional, Dict
from tqdm import tqdm
import multiprocessing as mp
from functools import partial

logger = logging.getLogger(__name__)

# Compile the regex pattern once; it captures the "_{width}x{height}" embedded in the cache filename
size_pattern = re.compile(r'_(\d+)x(\d+)(?:_flux\.npz|\.npz)$')
FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz"

def parse_size_from_filename(filename: str) -> Tuple[Optional[int], Optional[int]]:
    match = size_pattern.search(filename)
    if match:
        return int(match.group(1)), int(match.group(2))
    logger.warning(f"Failed to parse size from filename: {filename}")
    return None, None

def get_all_cache_files(img_paths: List[str]) -> Dict[str, str]:
    # List each directory only once and index the cache files by base path,
    # instead of hitting the filesystem per image.
    cache_files = {}
    base_dirs = set(os.path.dirname(path) for path in img_paths)

    for base_dir in base_dirs:
        for file in os.listdir(base_dir):
            if file.endswith(FLUX_LATENTS_NPZ_SUFFIX):
                # Remove the size and suffix to create the key
                base_name = re.sub(r'_\d+x\d+_flux\.npz$', '', file)
                cache_files[os.path.join(base_dir, base_name)] = file

    return cache_files

def process_batch(batch: List[str], cache_files: Dict[str, str]) -> List[Tuple[Optional[int], Optional[int]]]:
    results = []
    for img_path in batch:
        base_path = os.path.splitext(img_path)[0]
        if base_path in cache_files:
            results.append(parse_size_from_filename(cache_files[base_path]))
        else:
            # Alternatively, append (None, None) here to skip missing files instead of failing
            raise FileNotFoundError(f"Cache file not found for {img_path}")
    return results

def get_image_sizes_from_cache_files(img_paths: List[str]) -> List[Tuple[Optional[int], Optional[int]]]:
    cache_files = get_all_cache_files(img_paths)

    num_cores = mp.cpu_count()
    batch_size = max(1, len(img_paths) // (num_cores * 4))  # Adjust batch size for better load balancing
    batches = [img_paths[i:i + batch_size] for i in range(0, len(img_paths), batch_size)]

    with mp.Pool(num_cores) as pool:
        process_func = partial(process_batch, cache_files=cache_files)
        results = list(tqdm(
            pool.imap(process_func, batches),
            total=len(batches),
            desc="Processing image batches"
        ))

    # Flatten the results
    return [size for batch in results for size in batch]
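
For reference, the entry point is get_image_sizes_from_cache_files; a minimal call would look like this (the img_paths variable here is only illustrative, in practice it comes from wherever the dataset already collects its image paths):

img_paths = [...]  # absolute paths to the training images, gathered by the dataset code
sizes = get_image_sizes_from_cache_files(img_paths)  # (width, height) tuples, same order as img_paths
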
kohya-ss commented 3 weeks ago

Thanks for reporting. In my environment, it takes about 10 seconds for 3,000 images with the current implementation of sd-scripts. Do you know what causes the difference between 20 minutes and 10 seconds?

KujoAI commented 3 weeks ago

Do you have a fast SSD?

kohya-ss commented 3 weeks ago

The data is on the HDD.

                    INFO     get image size from name of cache files                                  train_util.py:1745
100%|███████████████████████████████████████████████████████████████████████████████| 906/906 [00:03<00:00, 254.30it/s]
2024-08-22 12:11:45 INFO     set image size from cache files: 906/906                                 train_util.py:1752
                    INFO     found directory xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx       train_util.py:1754
                             contains 906 image files
2024-08-22 12:11:48 INFO     get image size from name of cache files                                  train_util.py:1745
100%|██████████████████████████████████████████████████████████████████████████████| 136/136 [00:00<00:00, 1606.09it/s]
                    INFO     set image size from cache files: 136/136                                 train_util.py:1752
                    INFO     found directory xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx      train_util.py:1754
                             contains 136 image files
                    INFO     get image size from name of cache files                                  train_util.py:1745
100%|███████████████████████████████████████████████████████████████████████████████| 180/180 [00:00<00:00, 767.91it/s]
                    INFO     set image size from cache files: 180/180                                 train_util.py:1752
                    INFO     found directory xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx      train_util.py:1754
                             contains 180 image files
2024-08-22 12:11:49 INFO     get image size from name of cache files                                  train_util.py:1745
100%|███████████████████████████████████████████████████████████████████████████████| 180/180 [00:00<00:00, 822.88it/s]

and 4 more directories.

kohya-ss commented 3 weeks ago

Thanks for reporting. It's very strange, but the following line may take an extremely long time on certain systems:

        npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX)

I will try changing the processing to avoid using globs as much as possible.
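
The rough idea (just a sketch to illustrate the direction, not the actual change in sd-scripts): read each directory listing once and match the npz filenames in memory, instead of issuing one glob per image.

import os
from typing import List

def list_npz_in_dir(dir_path: str, suffix: str) -> List[str]:
    # One os.listdir call per directory instead of one glob per image.
    return [name for name in os.listdir(dir_path) if name.endswith(suffix)]

def find_npz_for_image(absolute_path: str, dir_listing: List[str], suffix: str) -> List[str]:
    # Equivalent to glob.glob(os.path.splitext(absolute_path)[0] + "_*" + suffix),
    # but matched against the in-memory listing, so the filesystem is touched
    # only once per directory.
    prefix = os.path.splitext(os.path.basename(absolute_path))[0] + "_"
    dir_name = os.path.dirname(absolute_path)
    return [
        os.path.join(dir_name, name)
        for name in dir_listing
        if name.startswith(prefix) and name.endswith(suffix)
    ]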

kohya-ss commented 3 weeks ago

I added a new branch fast_image_sizes. Could you please test with the branch?

RefractAI commented 3 weeks ago

The new branch fixed this issue for me.