harubaru / waifu-diffusion

stable diffusion finetuned on weeb stuff
GNU Affero General Public License v3.0
1.94k stars 175 forks

Simpler bucketing code #59

Closed lopho closed 1 year ago

lopho commented 1 year ago

The only thing I'd still improve here is dropping fewer samples. Right now, if a bucket has fewer samples than the batch size, it gets dropped. I might put those samples in the next best bucket instead.
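A rough sketch of what I mean, assuming "next best" means the bucket with the closest aspect ratio among the surviving buckets (the helper name and the distance metric are made up, not code from this PR):

```python
def fold_small_buckets(buckets: dict, batch_size: int) -> dict:
    """Instead of dropping buckets smaller than batch_size, move their
    samples into the kept bucket with the closest aspect ratio."""
    small = {k: v for k, v in buckets.items() if len(v) < batch_size}
    kept = {k: list(v) for k, v in buckets.items() if len(v) >= batch_size}
    for (w, h), samples in small.items():
        if not kept:
            break  # nothing left to fold into; samples are still dropped
        ratio = w / h
        # "next best" bucket = closest aspect ratio among the kept buckets
        nearest = min(kept, key=lambda k: abs(k[0] / k[1] - ratio))
        kept[nearest].extend(samples)
    return kept
```

Folded samples would of course still need to be resized to the target bucket's dimensions before training.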

lopho commented 1 year ago

Improved migrations code. It's now 80x faster. :framed_picture::dash:

lopho commented 1 year ago

validation is now parallel too.

lopho commented 1 year ago

I think I'll split the parallelism off into another PR. It's getting a bit much.

lopho commented 1 year ago

parallelism is now in a separate PR #60

lopho commented 1 year ago

There seems to be some confusion as to why this exists: the current bucketing code only works correctly if the dataset has been resized beforehand by wd migrations, using the sizes given by the buckets. It can fail by creating buckets of unevenly sized samples if you use an already appropriately resized dataset. This bucketing code creates correct buckets for any correctly sized dataset (or even for wrong sizes, e.g. sides that are not multiples of 64, though that is useless since those can't be used to train).

The current bucket implementation (AspectBucket) starts off by trying to create evenly spaced buckets, without actually considering the data distribution. It gets complex because all operations after the linear spacing try to fix up the buckets so the data fits, using linear algebra. It is never confirmed that data landing in a bucket is actually the size intended for that bucket. Buckets are hashed by image ratio, which introduces floating-point errors and makes hashing non-deterministic between environments. In an attempt to fix these rounding errors, the ratio floats are truncated. This fixes the hashing problem but introduces wrong scaling ratios for some input images.
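To illustrate the ratio-key problem (this is a toy demo, not code from the repo): truncating float ratios can make two genuinely different image shapes collide under the same key, while exact (w, h) tuples always stay distinct.

```python
def trunc(x: float, digits: int = 2) -> float:
    """Truncate (not round) a float to `digits` decimal places."""
    factor = 10 ** digits
    return int(x * factor) / factor

# two distinct widths at the same height yield the same truncated ratio
print(trunc(681 / 512), trunc(683 / 512))  # both 1.33
print((681, 512) == (683, 512))            # False: exact sizes stay distinct
```

Any image scaled by the truncated ratio instead of its true ratio ends up slightly off its intended size.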

My approach (SimpleBucket) is the opposite. It first creates buckets that perfectly fit the input data, then reexamines them. Buckets are uniquely hashed by the exact size (w, h) of the samples in the bucket, using the builtin hashing of dict keys. It considers every sample and creates a bucket for each image size in the set (after scaling to the optimal size: sides that are multiples of 64, max area, etc.). It then redistributes samples to the next best fit (if resizing), or drops them if a bucket is smaller than the batch size.

Optimal sizes are determined by computing a scale factor that scales the sides of the input image so it exactly fills the given maximum area:

max_area = 512*512
scale = (max_area / (w * h))**0.5
# if you scale w and h by this float scale, without rounding, then w*h == max_area:
# scale = (max_area / (w * h))**0.5
# scale**2 = max_area / (w * h)
# max_area = (w * h) * scale**2
# max_area = (w * scale) * (h * scale)

Then round each side to the closest multiple of 64 (or the given side divisor):

w2 = round((w * scale) / 64) * 64
h2 = round((h * scale) / 64) * 64

This results in an image filling as much of the maximum area as possible, while retaining an aspect ratio as close as possible to the input image, with sides divisible by 64. There is only one edge case: rounding up can produce an image larger than max area. In that case, round down instead (correct for non-negative sizes, which is given).

if w2*h2 > max_area:
  w2 = int((w * scale) / 64) * 64
  h2 = int((h * scale) / 64) * 64

Then bucket the image using the optimal size

bucket[(w2, h2)] = ...
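Putting the steps above together, a runnable sketch of the optimal-size computation (the function name is illustrative, not this PR's API):

```python
def optimal_size(w: int, h: int, max_area: int = 512 * 512, divisor: int = 64):
    """Scale (w, h) to fill max_area as closely as possible, keeping the
    aspect ratio and making both sides multiples of `divisor`."""
    # without rounding, (w * scale) * (h * scale) == max_area
    scale = (max_area / (w * h)) ** 0.5
    w2 = round((w * scale) / divisor) * divisor
    h2 = round((h * scale) / divisor) * divisor
    if w2 * h2 > max_area:
        # rounding up overshot the area budget: round down instead
        w2 = int((w * scale) / divisor) * divisor
        h2 = int((h * scale) / divisor) * divisor
    return w2, h2
```

For example, a 1024x512 input scales to 704x320: both sides are multiples of 64, the area stays within the 512x512 budget, and the aspect ratio stays close to 2:1.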

All this is only done if the user wants to resize images (--resize=True). Otherwise, just take the image sizes as they are and create buckets from those.

bucket[(image.width, image.height)] = ...

Of course this requires the image sizes to have sides valid for training (divisible by 64).

Lastly, AspectBucket does not honor the --shuffle flag and shuffles indiscriminately. SimpleBucket only shuffles when the flag is set (e.g. passed to __init__(..., shuffle=True)). It shuffles both the samples within a batch and the batch order.
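The shuffle behavior described above can be sketched like this (function name is made up; this stands in for SimpleBucket's batch emission, not its actual code):

```python
import random

def make_batches(buckets: dict, batch_size: int, shuffle: bool = False):
    """Build batches of sample indices from size-keyed buckets.

    With shuffle=True, both the order of samples inside each bucket and
    the order of the emitted batches are randomized; with shuffle=False,
    nothing is reordered.
    """
    batches = []
    for samples in buckets.values():
        samples = list(samples)
        if shuffle:
            random.shuffle(samples)
        # emit full batches only; undersized remainders are handled earlier
        for i in range(0, len(samples) - batch_size + 1, batch_size):
            batches.append(samples[i:i + batch_size])
    if shuffle:
        random.shuffle(batches)
    return batches
```

Shuffling within buckets (rather than globally) is what keeps every batch uniformly sized while still decorrelating sample order between epochs.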

lopho commented 1 year ago

Also, training-time performance is 2-6x better. Which is negligible (IMHO), since it is dwarfed by training itself, where running a step of gradient descent on a batch would take 3 seconds.

lopho commented 1 year ago

Benchmark results: https://pastebin.com/w9GVkZVb Benchmark code to reproduce:

def benchmark(sampler, args, time_samples = 1000, output_file = None):
    from time import perf_counter
    print("Benchmarking bucket sampler")
    bs = args.batch_size
    batches = len(sampler)
    total = bs * batches
    tt = time_samples
    results = {
            'name': str(type(sampler)),
            'shuffle': args.shuffle,
            'time_samples' : tt,
            'num_batches': batches,
            'num_samples': total,
            'results': {}
    }
    # estimate the self-time of a perf_counter() call pair (warmup, then measure)
    pself = 0
    for _ in range(tt):
        pself0 = perf_counter()
        pself1 = perf_counter()
        pself += (pself1 - pself0)
    pself = 0
    for _ in range(tt):
        pself0 = perf_counter()
        pself1 = perf_counter()
        pself += (pself1 - pself0)
    pself = pself / tt
    # warmup
    x = []
    for _ in range(tt):
        x.append(sampler.__iter__())
    # timed: cost of constructing the iterator
    x = []
    now = perf_counter()
    for _ in range(tt):
        x.append(sampler.__iter__())
    end = perf_counter()
    took = (end - now) - pself
    print(len(x))
    results['results']['iterator'] = {
            'total': took,
            'per_epoch': took / tt
    }
    # warmup
    x = 0
    for _ in range(tt):
        for b in sampler:
            x += len(b)
    # timed: cost of iterating over batches
    x = 0
    now = perf_counter()
    for _ in range(tt):
        for b in sampler:
            x += len(b)
    end = perf_counter()
    took = (end - now) - pself
    print(x)
    results['results']['batches'] = {
            'total': took,
            'per_epoch': took / tt,
            'per_batch': took / (tt * batches)
    }
    # warmup
    x = 0
    for _ in range(tt):
        for b in sampler:
            for idx, w, h in b:
                x += (idx + w + h)
    # timed: cost of iterating over individual samples
    x = 0
    now = perf_counter()
    for _ in range(tt):
        for b in sampler:
            for idx, w, h in b:
                x += (idx + w + h)
    end = perf_counter()
    took = (end - now) - pself
    print(x)
    results['results']['samples'] = {
            'total': took,
            'per_epoch': took / tt,
            'per_batch': took / (tt * batches),
            'per_sample': took / (tt * total)
    }
    if output_file is not None:
        from json import dump
        with open(output_file, 'w') as f:
            dump(results, f)
    return results

lopho commented 1 year ago

idc