fpgaminer / bigasp-training

Various training scripts used to train bigasp
MIT License

Guide for dataset #1

Closed: baizh0u closed this issue 4 days ago

baizh0u commented 1 week ago

Thank you for writing this training code project; it's very inspiring. I still have a few small questions, though. How should I organize the dataset? I see that your code uses a dataset in .parquet format. What columns and fields does it need to include? And would you mind posting an image of your training loss curve here as a reference?

fpgaminer commented 1 week ago

> I see that your code uses a dataset in .parquet format. What columns and fields does it need to include?

Here is my build-dataset.py file, which includes the parquet schema and, of course, how the parquet file gets built. The training scripts use the datasets library, so anything that load_dataset can read will work as long as it has the right columns. I just use parquet because I found it the easiest to work with.
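If you build your own dataset instead of using the script, it may help to sanity-check rows against the six columns that build-dataset.py (below) writes. A minimal stdlib-only sketch, assuming rows are plain Python dicts and that the training scripts need exactly these columns; `REQUIRED_COLUMNS` and `check_row` are hypothetical names, not part of the repo. The script's schema marks only `tag_string` and `caption` as explicitly nullable, and it always fills in the other four.

```python
# Hypothetical helper mirroring the parquet schema in build-dataset.py.
# Maps column name -> (expected Python type, whether None is allowed).
REQUIRED_COLUMNS = {
    "index": (int, False),          # row id in the source SQLite database
    "tag_string": (str, True),      # comma-separated tags, may be missing
    "caption": (str, True),         # natural-language caption, may be missing
    "score": (int, False),          # quality score (the script keeps score > 0)
    "latent_width": (int, False),   # width of the precomputed VAE latent
    "latent_height": (int, False),  # height of the precomputed VAE latent
}


def check_row(row: dict) -> list[str]:
    """Return the problems found in one row; an empty list means it looks OK."""
    problems = []
    for name, (expected_type, nullable) in REQUIRED_COLUMNS.items():
        if name not in row:
            problems.append(f"missing column: {name}")
            continue
        value = row[name]
        if value is None:
            if not nullable:
                problems.append(f"{name} must not be None")
        # bool is a subclass of int in Python, so reject it explicitly
        elif not isinstance(value, expected_type) or isinstance(value, bool):
            problems.append(
                f"{name} should be {expected_type.__name__}, got {type(value).__name__}"
            )
    return problems


good = {"index": 1, "tag_string": "dog,outdoors", "caption": None,
        "score": 3, "latent_width": 128, "latent_height": 96}
print(check_row(good))  # []
print(check_row({"index": 2, "score": "high"}))
```

Running this over every row before writing the parquet catches missing or mistyped columns early, rather than when the training script first touches them.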

FYI, the code published in this repo isn't intended to be a ready-to-use training pipeline like, say, OneTrainer. I just dumped it raw, for whatever value it may provide as a reference to others, so it's not batteries-included. For example, the script below pulls the data out of a SQLite database, which isn't documented either.
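Since the database isn't documented, the only hint to its shape is the SELECT in build-dataset.py. A table like the one below should satisfy that query. This is a guess reconstructed from the query alone: the column types, the BLOB type for `embedding` (presumably a CLIP embedding, judging by the default file name clip-embeddings.sqlite3), and the sample rows are all assumptions.

```python
import sqlite3

# Hypothetical layout: only the column names come from the script's SELECT;
# the types are guesses.
SCHEMA = """
CREATE TABLE images (
    id         INTEGER PRIMARY KEY,
    path       TEXT,
    tag_string TEXT,
    score      INTEGER,
    subreddit  TEXT,
    caption    TEXT,
    watermark  INTEGER,  -- the script treats 1 as "has a watermark"
    source     TEXT,
    username   TEXT,
    embedding  BLOB      -- the query only checks that it is NOT NULL
)
"""

# The same filter build-dataset.py uses.
QUERY = (
    "SELECT id, path, tag_string, score, subreddit, caption, watermark, source, username "
    "FROM images WHERE embedding IS NOT NULL AND score IS NOT NULL AND score > 0"
)

conn = sqlite3.connect(":memory:")
conn.execute(SCHEMA)
conn.executemany(
    "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
    [
        (1, "a.jpg", "dog,outdoors", 3, "", None, 0, "", "", b"\x00"),     # kept
        (2, "b.jpg", "landscape", 0, None, None, 0, None, None, b"\x00"),  # score 0: dropped
        (3, "c.jpg", "cat", 2, None, "A cat.", 1, None, None, None),       # no embedding: dropped
    ],
)
rows = conn.execute(QUERY).fetchall()
conn.close()
print(rows)  # [(1, 'a.jpg', 'dog,outdoors', 3, '', None, 0, '', '')]
```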

#!/usr/bin/env python3
"""
Build the parquet dataset that training will use.
This uses the data from the SQLite database. 
"""
import sqlite3
import pyarrow as pa
import pyarrow.parquet as pq
import math
from tqdm import tqdm
from pathlib import Path
import gzip
import struct
from typing import Iterable, Iterator, TypeVar
import itertools
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--database", type=str, default="data/clip-embeddings.sqlite3")
parser.add_argument("--output", type=str, default="data/dataset.parquet")
parser.add_argument("--vae-path", type=str, default="data/vaes")

schema = pa.schema([
    pa.field("index", pa.int64()),  # Index in the SQLite database
    pa.field("tag_string", pa.string(), nullable=True),
    pa.field("caption", pa.string(), nullable=True),
    pa.field("score", pa.int32()),
    pa.field("latent_width", pa.int32()),
    pa.field("latent_height", pa.int32()),
])

def main():
    args = parser.parse_args()

    conn = sqlite3.connect(args.database)
    cur = conn.cursor()

    # Fetch records, keeping only images that have an embedding and a
    # positive score (NULL, zero, and negative scores are discarded)
    print("Fetching records...")
    cur.execute("SELECT id, path, tag_string, score, subreddit, caption, watermark, source, username FROM images WHERE embedding IS NOT NULL AND score IS NOT NULL AND score > 0")
    records = [{
        'index': row[0],
        'path': row[1],
        'tag_string': row[2],
        'score': row[3],
        'subreddit': row[4],
        'caption': row[5],
        'watermark': row[6],
        'source': row[7],
        'username': row[8],
    } for row in cur.fetchall()]

    # Read latent sizes
    for record in tqdm(records, desc="Reading latent sizes", dynamic_ncols=True):
        precomputed_path = Path(args.vae_path) / f"{record['index'] % 1000:03d}" / f"{record['index']}.bin.gz"
        if not precomputed_path.exists():
            print(f"Missing precomputed file for index {record['index']}, {precomputed_path} - skipping")
            continue

        with gzip.open(precomputed_path, "rb") as f:
            precomputed_index, original_width, original_height, crop_x, crop_y, latent_width, latent_height = struct.unpack("<IIIIIII", f.read(28))
            assert precomputed_index == record['index'], f"Expected index {record['index']}, got {precomputed_index}"
            record['latent_width'] = latent_width
            record['latent_height'] = latent_height

    # Remove records with missing latents
    print(f"Records before latent filter: {len(records)}")
    records = [record for record in records if 'latent_width' in record]
    print(f"Records after: {len(records)}")

    # Append subreddit, watermark, source, and username to tag_string
    for record in records:
        if record['tag_string'] is None:
            continue

        if record['subreddit'] is not None and record['subreddit'] != "":
            record['tag_string'] += f",{record['subreddit']},reddit"

        if record['watermark'] is not None and record['watermark'] == 1:
            record['tag_string'] += ",watermark"

        if record['source'] is not None and record['source'] != "":
            record['tag_string'] += f",{record['source']}"

        if record['username'] is not None and record['username'] != "":
            record['tag_string'] += f",{record['username']}"

    dataset_writer(args.output, records)

def dataset_writer(dest_path: Path | str, records: list):
    with pq.ParquetWriter(dest_path, schema) as writer:
        for batch in tqdm(batcher(records, 1000), total=math.ceil(len(records) / 1000), dynamic_ncols=True):
            indexes = [int(row['index']) for row in batch]
            tag_strings = [row['tag_string'] for row in batch]
            captions = [row['caption'] for row in batch]
            scores = [int(row['score']) for row in batch]
            latent_widths = [int(row['latent_width']) for row in batch]
            latent_heights = [int(row['latent_height']) for row in batch]

            record_batch = pa.RecordBatch.from_arrays([
                pa.array(indexes, type=pa.int64()),
                pa.array(tag_strings, type=pa.string()),
                pa.array(captions, type=pa.string()),
                pa.array(scores, type=pa.int32()),
                pa.array(latent_widths, type=pa.int32()),
                pa.array(latent_heights, type=pa.int32()),
            ], schema=schema)
            writer.write(record_batch)

T = TypeVar("T")
def batcher(iterable: Iterable[T], n: int) -> Iterator[list[T]]:
    iterator = iter(iterable)
    while batch := list(itertools.islice(iterator, n)):
        yield batch

if __name__ == "__main__":
    main()
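The precomputed `.bin.gz` latent files are also undocumented. From the reading code above, each file starts with a 28-byte header of seven little-endian uint32 values (index, original width and height, crop x and y, latent width and height) and lives in a directory sharded by `index % 1000`. A stdlib-only round-trip sketch of just that header; `latent_path`, `write_header`, and `read_latent_size` are hypothetical helpers, and whatever follows the header (presumably the latent data itself) is not covered because the script never reads it.

```python
import gzip
import struct
import tempfile
from pathlib import Path

# Header layout read by build-dataset.py: seven little-endian uint32 values.
HEADER = struct.Struct("<IIIIIII")  # 7 * 4 = 28 bytes


def latent_path(vae_root: Path, index: int) -> Path:
    """Sharded path used by the script: <root>/<index % 1000 as 3 digits>/<index>.bin.gz"""
    return vae_root / f"{index % 1000:03d}" / f"{index}.bin.gz"


def write_header(path: Path, index: int, original_width: int, original_height: int,
                 crop_x: int, crop_y: int, latent_width: int, latent_height: int) -> None:
    # Hypothetical writer: emits only the header; a real file would carry more data.
    path.parent.mkdir(parents=True, exist_ok=True)
    with gzip.open(path, "wb") as f:
        f.write(HEADER.pack(index, original_width, original_height,
                            crop_x, crop_y, latent_width, latent_height))


def read_latent_size(path: Path, expected_index: int) -> tuple[int, int]:
    """Read the header the same way the script does and return (width, height)."""
    with gzip.open(path, "rb") as f:
        index, _, _, _, _, latent_width, latent_height = HEADER.unpack(f.read(HEADER.size))
    assert index == expected_index, f"Expected index {expected_index}, got {index}"
    return latent_width, latent_height


with tempfile.TemporaryDirectory() as tmp:
    p = latent_path(Path(tmp), 12345)
    # e.g. a 1024x768 image -> 128x96 latent (SDXL's VAE downsamples by 8)
    write_header(p, 12345, 1024, 768, 0, 0, 128, 96)
    rel = p.relative_to(tmp).as_posix()
    size = read_latent_size(p, 12345)

print(rel)   # 345/12345.bin.gz
print(size)  # (128, 96)
```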

> And would you mind posting an image of your training loss curve here as a reference?

You can find my validation loss curve documented here: https://www.reddit.com/r/StableDiffusion/comments/1dbasvx/the_gory_details_of_finetuning_sdxl_for_30m/

And here's the training loss curve:

[Image: W&B chart of the training loss, exported 7/5/2024 12:41:52 PM]

baizh0u commented 6 days ago

Thank you so much... That helps a lot! I think you have made a great project for training SDXL. By the way, do you think there is a connection between the number of training epochs and the size of the training data? I see that you trained on a 1.5M-image dataset for about 20 epochs. If I only have a small dataset, say 30K or 100K images, should I train for more epochs, or are 20 epochs enough?
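One way to frame this question is in total samples seen rather than epochs. Going by the numbers in the comment above, 1.5M images × 20 epochs is 30M samples. The sketch below only does that arithmetic; matching the sample count is one possible heuristic, not a recommendation from the repo author, and repeating a 30K-image dataset 1,000 times may well overfit, so the result is better read as an upper bound to check against validation loss than as a target. The function names are hypothetical.

```python
def samples_seen(dataset_size: int, epochs: float) -> int:
    """Total training samples processed: dataset size times epochs."""
    return int(dataset_size * epochs)


def epochs_for_budget(sample_budget: int, dataset_size: int) -> float:
    """Epochs needed on a dataset of `dataset_size` to see `sample_budget` samples."""
    return sample_budget / dataset_size


budget = samples_seen(1_500_000, 20)  # 30,000,000 samples
for size in (30_000, 100_000):
    print(size, epochs_for_budget(budget, size))
# 30000 1000.0
# 100000 300.0
```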

baizh0u commented 6 days ago

I think it works. However, when I tried running this script on a 68-image test dataset, the loss became NaN, as shown in the screenshot below. Why is that? Is it because I used fp32 for VAE encoding, or because my dataset is too small? Training with torchrun --nproc_per_node=NUM_GPUS train.py has the same issue.

[Screenshot: training loss going NaN]

baizh0u commented 6 days ago

I tried using fp16 for VAE encoding and still got the same issue...
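When the loss goes NaN, it usually helps to find the first step where it happens and stop there, so you can inspect that batch, the learning rate, and the optimizer state. A minimal stdlib-only sketch of such a check; `NonFiniteLossGuard` is a hypothetical helper, not something in train.py. In a PyTorch loop you would pass it `loss.item()`, or use `torch.isfinite` on the loss tensor directly.

```python
import math


class NonFiniteLossGuard:
    """Hypothetical debugging helper: stop training once the loss stops being finite."""

    def __init__(self, max_bad_steps: int = 1):
        self.max_bad_steps = max_bad_steps  # consecutive NaN/inf steps tolerated
        self.bad_steps = 0

    def check(self, step: int, loss: float) -> bool:
        """Return True for a finite loss; raise after too many bad steps in a row."""
        if math.isfinite(loss):
            self.bad_steps = 0
            return True
        self.bad_steps += 1
        if self.bad_steps >= self.max_bad_steps:
            raise FloatingPointError(f"loss became {loss} at step {step}")
        return False


guard = NonFiniteLossGuard(max_bad_steps=2)
message = None
try:
    for step, loss in enumerate([0.12, 0.11, float("nan"), float("inf")]):
        guard.check(step, loss)
except FloatingPointError as err:
    message = str(err)
print(message)  # loss became inf at step 3
```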

baizh0u commented 6 days ago

Here are the training settings. I am using an L20 GPU.

[Screenshot: training settings]

baizh0u commented 4 days ago

I found that the problem was the optimizer type. The default is adamW; after I changed it to adam8bit, training works well. By the way, thanks for this great work.