bghira / SimpleTuner

A general fine-tuning kit geared toward diffusion models.
GNU Affero General Public License v3.0

Some caption type errors and speed issue #283

Closed janzd closed 9 months ago

janzd commented 9 months ago

Thanks for the recent updates. I found some more issues with the type of the caption variable in the code (sometimes it's bytes while the code expects a string).

Besides that, I have some issues with speed.

The caption is in bytes format, but L166 in sdxl_embeds.py expects a string. I added an if-condition that checks the type and converts the caption to a string if it's bytes (see the sketch below).
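Roughly, the kind of check I mean looks like this (just a sketch; the function and variable names are illustrative, not the actual code in sdxl_embeds.py):

# Decode bytes captions before they reach the tokenizer, which only accepts
# str, List[str], or List[List[str]].
def normalize_caption(caption):
    if isinstance(caption, bytes):
        caption = caption.decode("utf-8")  # assuming the caption files are UTF-8
    return caption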

[2024-01-23 08:12:19,444] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-01-23 08:12:23,296] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
2024-01-23 08:12:24,099 [INFO] (ArgsParser) Default VAE Cache location: /data/src/SimpleTuner/ckpt/<PROJECT NAME>/cache_vae
2024-01-23 08:12:24,099 [INFO] (ArgsParser) Text Cache location: cache
2024-01-23 08:12:24,101 [INFO] (__main__) Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: bf16

2024-01-23 08:12:24,103 [INFO] (__main__) Enabling tf32 precision boost for NVIDIA devices due to --allow_tf32.
2024-01-23 08:12:24,103 [INFO] (__main__) Load tokenizers
2024-01-23 08:12:25,446 [INFO] (__main__) Load text encoder 1..
2024-01-23 08:12:26,613 [INFO] (__main__) Load text encoder 2..
2024-01-23 08:12:31,502 [INFO] (__main__) Load VAE..
2024-01-23 08:12:31,787 [INFO] (__main__) Moving models to GPU. Almost there.
2024-01-23 08:12:32,712 [INFO] (__main__) Creating the U-net..
2024-01-23 08:12:33,938 [INFO] (__main__) Moving the U-net to GPU.
2024-01-23 08:12:36,889 [INFO] (__main__) Enabling xformers memory-efficient attention.
2024-01-23 08:12:37,142 [INFO] (__main__) Initialising VAE in bf16 precision, you may specify a different value if preferred: bf16, fp16, fp32, default
2024-01-23 08:12:37,244 [INFO] (__main__) Loaded VAE into VRAM.
2024-01-23 08:12:37,272 [INFO] (DataBackendFactory) Configuring text embed backend: <TEXT EMBEDS>
2024-01-23 08:12:37,273 [INFO] (TextEmbeddingCache) (id=<TEXT EMBEDS>) Listing all text embed cache entries
2024-01-23 08:12:38,938 [INFO] (DataBackendFactory) Pre-computing null embedding for caption dropout
2024-01-23 08:12:38,940 [INFO] (DataBackendFactory) Completed loading text embed services.                                   
2024-01-23 08:12:38,940 [INFO] (DataBackendFactory) Configuring data backend: <ID>
2024-01-23 08:12:38,941 [INFO] (DataBackendFactory) (id=<ID>) Loading bucket manager.
2024-01-23 08:12:38,952 [INFO] (DataBackendFactory) (id=<ID>) Refreshing aspect buckets.
2024-01-23 08:12:52,135 [INFO] (DataBackendFactory) (id=<ID>) Reloading bucket manager cache.
(Rank: 0)  | Bucket     | Image Count 
------------------------------
(Rank: 0)  | 1.0        | 12504       
2024-01-23 08:12:52,193 [INFO] (DataBackendFactory) (id=<ID>) Initialise text embed pre-computation. We have 12584 captions to process.
2024-01-23 08:12:52,433 [INFO] (DataBackendFactory) (id=<ID>) Completed processing 12584 captions.
2024-01-23 08:12:52,433 [INFO] (DataBackendFactory) (id=<ID>) Pre-computing VAE latent space.
2024-01-23 08:12:52,734 [INFO] (DataBackendFactory) Skipping error scan for dataset <ID>. Set 'scan_for_errors' to True in the dataset config to enable this if your training runs into mismatched latent dimensions.
2024-01-23 08:12:52,879 [INFO] (validation) Precomputing the negative prompt embed for validations.
Waiting for batch write thread to finish, 0 items left in queue.
2024-01-23 08:12:53,346 [INFO] (__main__) Moving text encoders back to CPU, to save VRAM. Currently, we cannot completely unload the text encoder.
2024-01-23 08:12:53,713 [INFO] (__main__) After nuking text encoders from orbit, we freed 1.53 GB of VRAM. The real memories were the friends we trained a model on along the way.
2024-01-23 08:12:53,713 [INFO] (__main__) Collected the following data backends: ['<ID>']
2024-01-23 08:12:53,713 [INFO] (__main__) Loading cosine learning rate scheduler with 100000 warmup steps
2024-01-23 08:12:53,718 [INFO] (__main__) Learning rate: 8e-07
2024-01-23 08:12:53,719 [INFO] (__main__) Using 8bit AdamW optimizer.
2024-01-23 08:12:53,719 [INFO] (__main__) Optimizer arguments, weight_decay=0.01 eps=1e-08
2024-01-23 08:12:53,726 [INFO] (__main__) Loading our accelerator...
2024-01-23 08:12:53,742 [INFO] (__main__) After removing any undesired samples and updating cache entries, we have settled on 1920 epochs and 521 steps per epoch.
2024-01-23 08:12:53,854 [INFO] (__main__) After the VAE from orbit, we freed 0.0 MB of VRAM.
2024-01-23 08:12:53,873 [INFO] (__main__) ***** Running training *****
2024-01-23 08:12:53,874 [INFO] (__main__)  -> Num batches = 2084
2024-01-23 08:12:53,874 [INFO] (__main__)  -> Num Epochs = 1920
2024-01-23 08:12:53,874 [INFO] (__main__)  -> Current Epoch = 1
2024-01-23 08:12:53,874 [INFO] (__main__)  -> Instantaneous batch size per device = 6
2024-01-23 08:12:53,874 [INFO] (__main__)  -> Gradient Accumulation steps = 4
2024-01-23 08:12:53,874 [INFO] (__main__)    -> Total train batch size (w. parallel, distributed & accumulation) = 24
2024-01-23 08:12:53,874 [INFO] (__main__)  -> Total optimization steps = 1000000
2024-01-23 08:12:53,874 [INFO] (__main__)  -> Total optimization steps remaining = 1000000
Epoch 1/1920 Steps:   0%|                                                                        | 0/1000000 [00:00<?, ?it/s]2024-01-23 08:12:53,913 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt: b'The character is a young boy. He has black hair, brown eyes, and is wearing a black hat. He is dressed in a blue and white jacket, a red tie, and black pants. He is also carrying a backpack.'
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/902944e8141052ab8f82b75a4bca7add-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/902944e8141052ab8f82b75a4bca7add-sdxl.pt not found.
2024-01-23 08:12:53,913 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt: b'The character is a girl. She has long blonde hair. Her eyes are blue. She is wearing a yellow bow in her hair. She is dressed in a yellow and white outfit. She is wearing a white shirt and a yellow bow. She is also wearing a white skirt. She has a pink bag and a pink teddy bear. She is wearing shoes.'
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/c3523270ad89e20d5f6189a1bca700e9-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/c3523270ad89e20d5f6189a1bca700e9-sdxl.pt not found.
2024-01-23 08:12:53,914 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt: 
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/d41d8cd98f00b204e9800998ecf8427e-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/d41d8cd98f00b204e9800998ecf8427e-sdxl.pt not found.
2024-01-23 08:12:53,914 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt: b'The character is a girl. She has long, blonde hair. Her eyes are blue. She is wearing a white and black jacket, a white hat, and a white and black skirt. She is also wearing white shoes.'
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/e7387e22acbfca01cc18389574e3a842-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/e7387e22acbfca01cc18389574e3a842-sdxl.pt not found.
2024-01-23 08:12:53,914 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt: b'The character is a girl. She has long white hair and blue eyes. She is wearing a blue dress and holding a white teddy bear. She is also wearing a necklace.'
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/3fd7f7d7bf892b271f6d99e21ea7ae81-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/3fd7f7d7bf892b271f6d99e21ea7ae81-sdxl.pt not found.
2024-01-23 08:12:53,914 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt: b'The character is a female. She has long, white hair. Her eyes are blue. She is wearing a black jacket, a black hat, and a white bow. She is also wearing a necklace.'
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/cef52adf6664da6943bad4edd8259db8-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/cef52adf6664da6943bad4edd8259db8-sdxl.pt not found.
2024-01-23 08:12:53,916 [ERROR] (TextEmbeddingCache) Failed to encode prompt: b'The character is a young boy. He has black hair, brown eyes, and is wearing a black hat. He is dressed in a blue and white jacket, a red tie, and black pants. He is also carrying a backpack.'
-> error: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).
-> traceback: Traceback (most recent call last):
  File "/data/src/SimpleTuner/helpers/caching/sdxl_embeds.py", line 168, in encode_sdxl_prompt
    return positive_prompt
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2798, in __call__
    encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2856, in _call_one
    raise ValueError(
ValueError: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).

2024-01-23 08:12:53,916 [ERROR] (TextEmbeddingCache) Failed to encode prompt: b'The character is a girl. She has long blonde hair. Her eyes are blue. She is wearing a yellow bow in her hair. She is dressed in a yellow and white outfit. She is wearing a white shirt and a yellow bow. She is also wearing a white skirt. She has a pink bag and a pink teddy bear. She is wearing shoes.'
-> error: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).
-> traceback: Traceback (most recent call last):
  File "/data/src/SimpleTuner/helpers/caching/sdxl_embeds.py", line 168, in encode_sdxl_prompt
    return positive_prompt
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2798, in __call__
    encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2856, in _call_one
    raise ValueError(
ValueError: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).

2024-01-23 08:12:53,917 [ERROR] (TextEmbeddingCache) Failed to encode prompt: b'The character is a female. She has long, white hair. Her eyes are blue. She is wearing a black jacket, a black hat, and a white bow. She is also wearing a necklace.'
-> error: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).
-> traceback: Traceback (most recent call last):
  File "/data/src/SimpleTuner/helpers/caching/sdxl_embeds.py", line 168, in encode_sdxl_prompt
    return positive_prompt
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2798, in __call__
    encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2856, in _call_one
    raise ValueError(
ValueError: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).

2024-01-23 08:12:53,917 [ERROR] (TextEmbeddingCache) Failed to encode prompt: b'The character is a girl. She has long, blonde hair. Her eyes are blue. She is wearing a white and black jacket, a white hat, and a white and black skirt. She is also wearing white shoes.'
-> error: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).
-> traceback: Traceback (most recent call last):
  File "/data/src/SimpleTuner/helpers/caching/sdxl_embeds.py", line 168, in encode_sdxl_prompt
    return positive_prompt
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2798, in __call__
    encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2856, in _call_one
    raise ValueError(
ValueError: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).

2024-01-23 08:12:53,917 [ERROR] (TextEmbeddingCache) Failed to encode prompt: b'The character is a girl. She has long white hair and blue eyes. She is wearing a blue dress and holding a white teddy bear. She is also wearing a necklace.'
-> error: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).
-> traceback: Traceback (most recent call last):
  File "/data/src/SimpleTuner/helpers/caching/sdxl_embeds.py", line 168, in encode_sdxl_prompt
    return positive_prompt
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2798, in __call__
    encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2856, in _call_one
    raise ValueError(
ValueError: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).

Waiting for batch write thread to finish, 1 items left in queue.                                                             
Waiting for batch write thread to finish, 1 items left in queue.                                                             
Waiting for batch write thread to finish, 1 items left in queue.                                                             
Waiting for batch write thread to finish, 1 items left in queue.                                                             
Waiting for batch write thread to finish, 1 items left in queue.                                                             
Waiting for batch write thread to finish, 1 items left in queue.                                                             
Waiting for batch write thread to finish, 1 items left in queue.

I was then able to run the training, but for some reason it is extremely slow: about 150 seconds per iteration. I'm running it on an A100, so it should definitely be faster.

Epoch 1/1920, Steps:   0%|                        | 91/1000000 [3:57:37<43151:47:50, 155.36s/it, lr=7.99e-7, step_loss=0.149]2024-01-22 14:36:46,625 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt: b'The character is a girl. She has long black hair. Her eyes are brown. She is wearing a black jacket, a white shirt, and a red tie. She is also wearing a black hat.'
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/e40409501e542a58dcc09434a8c3d83f-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/e40409501e542a58dcc09434a8c3d83f-sdxl.pt not found.
2024-01-22 14:36:46,625 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt: b'The character is a girl. She has long, blonde hair. Her eyes are blue. She is wearing a black and white hat, a black and white jacket, and a pink dress. She is also wearing black shoes and white socks.'
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/de596a67f66d954b002989615dc80689-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/de596a67f66d954b002989615dc80689-sdxl.pt not found.
2024-01-22 14:36:46,626 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt: b'The character is a girl. She has short hair. Her eyes are blue. She is wearing a blue hoodie and a black hat. She is holding a lighted lantern.'
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/1caa570630b83f7cb5bf100fcb4e962d-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/1caa570630b83f7cb5bf100fcb4e962d-sdxl.pt not found.
2024-01-22 14:36:46,626 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt: 
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/d41d8cd98f00b204e9800998ecf8427e-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/d41d8cd98f00b204e9800998ecf8427e-sdxl.pt not found.
2024-01-22 14:36:46,627 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt: b'The character is a girl. She has blonde hair, brown eyes, and is wearing a black hat. She is dressed in black clothing, including a black jacket, and is holding a stuffed animal.'
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/4f2c31189bac80f032dec48ca2a24200-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/4f2c31189bac80f032dec48ca2a24200-sdxl.pt not found.
2024-01-22 14:36:46,627 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt: 
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/d41d8cd98f00b204e9800998ecf8427e-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/d41d8cd98f00b204e9800998ecf8427e-sdxl.pt not found.
Epoch 1/1920, Steps:   0%|                       | 92/1000000 [3:58:16<43154:54:50, 155.37s/it, lr=7.99e-7, step_loss=0.0704]2024-01-22 14:37:25,516 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt: b'The character is a girl. She has brown hair, blue eyes, and is wearing a blue dress. She is also wearing a crown and a necklace.'
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/436519899aeb78024e9ee7594f64294f-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/436519899aeb78024e9ee7594f64294f-sdxl.pt not found.
2024-01-22 14:37:25,517 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt: b'The character is a boy. His hair is black and long. He has purple eyes. He is wearing a green jacket and a black and white shirt. He is also wearing a black mask and a black hat.'
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/5688507b91dd2375cc1858deb2dcffe7-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/5688507b91dd2375cc1858deb2dcffe7-sdxl.pt not found.
2024-01-22 14:37:25,517 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt: b'The character is a girl. She has black hair, blue eyes, and is wearing a white shirt and a white and gray coat. She is also wearing a white hat.'
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/fff0e1510b5633c7170ae3717df495ee-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/fff0e1510b5633c7170ae3717df495ee-sdxl.pt not found.
2024-01-22 14:37:25,517 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt: b'The character is a girl. She has brown hair, red eyes, and is wearing a red jacket. She is holding a skateboard.'
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/b8d7cb52781c14cbf74312457fd078a0-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/b8d7cb52781c14cbf74312457fd078a0-sdxl.pt not found.
2024-01-22 14:37:25,518 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt: b'The character is a girl. She has long black hair. Her eyes are blue. She is wearing a white jacket, a black hat, and black sunglasses. She is also wearing black shoes.'
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/f0ab3166616f7f767b32dd4c4a17eddd-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/f0ab3166616f7f767b32dd4c4a17eddd-sdxl.pt not found.
2024-01-22 14:37:25,519 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt: b'Female, blonde hair, blue eyes, wearing glasses, a white sweater, blue shorts, and holding a bouquet of flowers.'
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/8f9f2c9b81d3cc4a99104b7e2dc9c3a6-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/8f9f2c9b81d3cc4a99104b7e2dc9c3a6-sdxl.pt not found.

My training data directory contains 12586 images and 12586 corresponding text files with captions. After running the script, there are 12504 .pt files in ./data/<DATA_DIR>/vaecache (I suppose some images were skipped for some reason). There are also directories ./ckpt/<PROJECT NAME>/cache and ./ckpt/<PROJECT NAME>/cache_vae, but those are empty.

The logging output says

2024-01-23 06:32:28,568 [INFO] (ArgsParser) Default VAE Cache location: /data/src/SimpleTuner/ckpt/<PROJECT NAME>/cache_vae
2024-01-23 06:32:28,568 [INFO] (ArgsParser) Text Cache location: cache

but the cache is stored elsewhere, so I wonder if that's an issue.

Just for completeness, here's my argument list:

snr_gamma:       5.0
model_type:      full
rank:    4
pretrained_model_name_or_path:   stabilityai/stable-diffusion-xl-base-1.0
pretrained_vae_model_name_or_path:       madebyollin/sdxl-vae-fp16-fix
prediction_type:         epsilon
snr_weight:      1.0
training_scheduler_timestep_spacing:     trailing
inference_scheduler_timestep_spacing:    trailing
timestep_bias_strategy:          none
timestep_bias_multiplier:        1.0
timestep_bias_begin:     0
timestep_bias_end:       1000
timestep_bias_portion:   0.25
rescale_betas_zero_snr:          False
vae_dtype:       bf16
vae_batch_size:          4
vae_cache_behaviour:     recreate
keep_vae_loaded:         False
skip_file_discovery:     
revision:        None
preserve_data_backend_cache:     False
override_dataset_config:         False
cache_dir_text:          cache
cache_dir_vae:   /data/src/SimpleTuner/ckpt/<PROJECT NAME>/cache_vae
data_backend_config:     /data/src/SimpleTuner/data/<CONFIGURATION FILE>
write_batch_size:        64
cache_dir:       /data/src/SimpleTuner/ckpt/<PROJECT NAME>/cache
cache_clear_validation_prompts:          False
caption_strategy:        textfile
instance_prompt:         None
output_dir:      /data/src/SimpleTuner/ckpt/<PROJECT NAME>
seed:    420420420
seed_for_each_device:    True
resolution:      1024.0
resolution_type:         pixel
minimum_image_size:      1024.0
crop:    False
crop_style:      random
crop_aspect:     square
train_text_encoder:      False
train_batch_size:        6
num_train_epochs:        50
max_train_steps:         1000000
checkpointing_steps:     250
checkpoints_total_limit:         2
resume_from_checkpoint:          
gradient_accumulation_steps:     4
gradient_checkpointing:          True
learning_rate:   8e-07
lr_scale:        False
lr_scheduler:    cosine
lr_warmup_steps:         100000
lr_num_cycles:   1
lr_power:        0.8
use_ema:         False
ema_decay:       0.995
non_ema_revision:        None
offload_param_path:      None
use_8bit_adam:   True
use_adafactor_optimizer:         False
use_dadapt_optimizer:    False
dadaptation_learning_rate:       1.0
adam_beta1:      0.9
adam_beta2:      0.999
adam_weight_decay:       0.01
adam_epsilon:    1e-08
max_grad_norm:   1.0
push_to_hub:     False
hub_token:       None
hub_model_id:    None
logging_dir:     logs
validation_torch_compile:        False
validation_torch_compile_mode:   max-autotune
allow_tf32:      True
report_to:       tensorboard
tracker_run_name:        simpletuner-sdxl
tracker_project_name:    sdxl-training
validation_prompt:       <CUSTOM VALIDATION PROMPT>
validation_prompt_library:       False
user_prompt_library:     None
validation_negative_prompt:      blurry, cropped, ugly
num_validation_images:   1
validation_steps:        100
validation_num_inference_steps:          30
validation_resolution:   1024
validation_noise_scheduler:      euler
disable_compel:          False
enable_watermark:        False
mixed_precision:         bf16
local_rank:      -1
enable_xformers_memory_efficient_attention:      True
set_grads_to_none:       True
noise_offset:    0.1
noise_offset_probability:        0.25
validation_guidance:     7.5
validation_guidance_rescale:     0.0
validation_randomize:    False
validation_seed:         42
fully_unload_text_encoder:       False
freeze_encoder_before:   12
freeze_encoder_after:    17
freeze_encoder_strategy:         after
print_filenames:         False
print_sampler_statistics:        False
metadata_update_interval:        65
debug_aspect_buckets:    False
debug_dataset_loader:    False
freeze_encoder:          True
text_encoder_limit:      25
prepend_instance_prompt:         False
only_instance_prompt:    False
caption_dropout_probability:     0.1
input_perturbation:      0
input_perturbation_probability:          0.25
delete_unwanted_images:          False
delete_problematic_images:       False
offset_noise:    False
lr_end:          4e-7

resume_from_checkpoint is False so that the condition on L847 in train_sdxl.py gets triggered and a new training run starts. I commented this block out in train_sdxl.sh so that I can leave resume_from_checkpoint as None:

#if [ -z "${RESUME_CHECKPOINT}" ]; then
#    printf "RESUME_CHECKPOINT not set, exiting.\n"
#    exit 1
#fi
bghira commented 9 months ago

The log line for the VAE cache dir is the default dir location; it's kind of confusing. The vae_cache_prefix in the data backend config is the preferred value to use. So if they're pointing to two locations, it'll create the dir for the global option but store the files in the backend-specific dir.
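To illustrate the precedence (a hypothetical sketch, not the actual DataBackendFactory code):

import os

# The backend-specific vae_cache_prefix wins; the global --cache_dir_vae is only
# a default, though its directory still gets created either way.
def resolve_vae_cache_dir(global_cache_dir_vae, backend_config):
    os.makedirs(global_cache_dir_vae, exist_ok=True)
    return backend_config.get("vae_cache_prefix") or global_cache_dir_vae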

The bytes error, I believe, is related to issue #284, where get_all_captions didn't find these captions during pre-processing. You can see it took half a second to supposedly process 12k captions; that's incorrect, it should take more like 10-30 minutes.

Test the main branch and let me know how that goes.

For RESUME_CHECKPOINT, set it to latest and it will start a new training run anyway if no checkpoints exist.

bghira commented 9 months ago

For completeness: the speed issue is because the text encoders are moved to the CPU before training begins, after all of the embeds are cached. But these embeds aren't found, because get_all_captions didn't find them, and so it ends up caching the embeds during training on the CPU, since magic_prompt does have them.
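In pseudo-code, the slow path looks roughly like this (a hypothetical sketch of the behaviour, not the trainer's actual code):

# Fast path: the embed was precomputed on the GPU before training.
# Slow path: the cache entry is missing, so the prompt is re-encoded mid-training
# with text encoders that have already been offloaded to the CPU.
def get_prompt_embed(prompt, cache, encode_on_cpu):
    embed = cache.get(prompt)
    if embed is not None:
        return embed
    return encode_on_cpu(prompt)  # this is what drives the ~150 s/it reported above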

bghira commented 9 months ago

@janzd is the byte fix still required?

janzd commented 9 months ago

Yeah, it still gives me an error without the type conversion.

2024-01-23 21:20:01,648 [INFO] (__main__) Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: bf16

2024-01-23 21:20:01,649 [INFO] (__main__) Enabling tf32 precision boost for NVIDIA devices due to --allow_tf32.
2024-01-23 21:20:01,650 [INFO] (__main__) Load tokenizers
2024-01-23 21:20:03,275 [INFO] (__main__) Load text encoder 1..
2024-01-23 21:20:04,500 [INFO] (__main__) Load text encoder 2..
2024-01-23 21:20:09,375 [INFO] (__main__) Load VAE..
2024-01-23 21:20:09,659 [INFO] (__main__) Moving models to GPU. Almost there.
2024-01-23 21:20:10,431 [INFO] (__main__) Creating the U-net..
2024-01-23 21:20:11,623 [INFO] (__main__) Moving the U-net to GPU.
2024-01-23 21:20:14,389 [INFO] (__main__) Enabling xformers memory-efficient attention.
2024-01-23 21:20:14,663 [INFO] (__main__) Initialising VAE in bf16 precision, you may specify a different value if preferred: bf16, fp16, fp32, default
2024-01-23 21:20:14,766 [INFO] (__main__) Loaded VAE into VRAM.
2024-01-23 21:20:14,792 [INFO] (DataBackendFactory) Configuring text embed backend: <TEXT EMBEDS>
2024-01-23 21:20:14,794 [INFO] (TextEmbeddingCache) (id=<TEXT EMBEDS>) Listing all text embed cache entries
2024-01-23 21:20:14,795 [INFO] (DataBackendFactory) Pre-computing null embedding for caption dropout
2024-01-23 21:20:14,797 [INFO] (DataBackendFactory) Completed loading text embed services.                                   
2024-01-23 21:20:14,797 [INFO] (DataBackendFactory) Configuring data backend: <ID>
2024-01-23 21:20:14,798 [INFO] (DataBackendFactory) (id=<ID>) Loading bucket manager.
2024-01-23 21:20:14,808 [INFO] (DataBackendFactory) (id=<ID>) Refreshing aspect buckets.
2024-01-23 21:20:26,668 [INFO] (DataBackendFactory) (id=<ID>) Reloading bucket manager cache.
(Rank: 0)  | Bucket     | Image Count 
------------------------------
(Rank: 0)  | 1.0        | 12504       
2024-01-23 21:20:28,393 [INFO] (DataBackendFactory) (id=<ID>) Initialise text embed pre-computation. We have 12584 captions to process.
Processing prompts:   0%|                                                                          | 0/12584 [00:00<?, ?it/s]2024-01-23 21:20:28,766 [ERROR] (TextEmbeddingCache) Failed to encode prompt: b'The character is a young boy. He has short hair and is wearing glasses. His eyes are green. He is wearing a white hat, a blue jacket, and a black backpack.'
-> error: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).
-> traceback: Traceback (most recent call last):
  File "/data/src/SimpleTuner/helpers/caching/sdxl_embeds.py", line 172, in encode_sdxl_prompt
    text_inputs = tokenizer(
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2798, in __call__
    encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2856, in _call_one
    raise ValueError(
ValueError: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).

2024-01-23 21:20:28,766 [ERROR] (__main__) text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples)., traceback: Traceback (most recent call last):
  File "/data/src/SimpleTuner/train_sdxl.py", line 429, in main
    configure_multi_databackend(
  File "/data/src/SimpleTuner/helpers/data_backend/factory.py", line 439, in configure_multi_databackend
    init_backend["text_embed_cache"].compute_embeddings_for_prompts(
  File "/data/src/SimpleTuner/helpers/caching/sdxl_embeds.py", line 294, in compute_embeddings_for_prompts
    return self.compute_embeddings_for_sdxl_prompts(
  File "/data/src/SimpleTuner/helpers/caching/sdxl_embeds.py", line 358, in compute_embeddings_for_sdxl_prompts
    prompt_embeds, pooled_prompt_embeds = self.encode_sdxl_prompts(
  File "/data/src/SimpleTuner/helpers/caching/sdxl_embeds.py", line 234, in encode_sdxl_prompts
    prompt_embeds, pooled_prompt_embeds = self.encode_sdxl_prompt(
  File "/data/src/SimpleTuner/helpers/caching/sdxl_embeds.py", line 217, in encode_sdxl_prompt
    raise e
  File "/data/src/SimpleTuner/helpers/caching/sdxl_embeds.py", line 172, in encode_sdxl_prompt
    text_inputs = tokenizer(
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2798, in __call__
    encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2856, in _call_one
    raise ValueError(
ValueError: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).

Also, when I start the training now and it doesn't fail on the type (when I decode the caption to a string), it seems to get stuck in some writing loop. It just keeps printing "Waiting for batch write thread to finish, 1 items left in queue." and I have to kill the process.

...
Processing prompts:  85%|████████████████████████████████████████████████████▌         | 10679/12584 [04:40<00:49, 38.67it/s]2024-01-23 21:09:24,247 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: [', white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white,']
Processing prompts:  86%|█████████████████████████████████████████████████████▍        | 10839/12584 [04:45<00:45, 38.68it/s]2024-01-23 21:09:28,375 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['white star, white bow, white star, white bow, white star, white bow, white star, white bow, white star, white bow, white star, white bow, white star, white bow, white star, white bow, white star']
Processing prompts:  87%|██████████████████████████████████████████████████████        | 10979/12584 [04:48<00:41, 38.24it/s]2024-01-23 21:09:32,116 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: [', white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white,']
Processing prompts:  88%|██████████████████████████████████████████████████████▍       | 11047/12584 [04:50<00:40, 37.95it/s]2024-01-23 21:09:33,897 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['a brown jacket and a brown hat. she is wearing a brown jacket and a brown hat. she is wearing a brown jacket and a brown hat. she is wearing a brown jacket']
Processing prompts:  89%|███████████████████████████████████████████████████████▏      | 11211/12584 [04:54<00:35, 38.87it/s]2024-01-23 21:09:38,154 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['white jacket. she is wearing a pink and white jacket. she is wearing a pink and white jacket. she is we']
Processing prompts:  92%|████████████████████████████████████████████████████████▊     | 11539/12584 [05:03<00:26, 39.17it/s]2024-01-23 21:09:46,711 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['a white shirt and a blue shirt. she is wearing a white shirt and a blue shirt. she is wearing a white shirt and a blue sh']
Processing prompts:  93%|█████████████████████████████████████████████████████████▊    | 11727/12584 [05:08<00:23, 35.78it/s]2024-01-23 21:09:51,595 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: [', white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white,']
Processing prompts:  94%|██████████████████████████████████████████████████████████▍   | 11865/12584 [05:11<00:18, 39.18it/s]2024-01-23 21:09:55,274 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: [', black, white, pink, black, white, pink, black, white, pink, black, white, pink, black, white, pink, black, white']
Processing prompts:  94%|██████████████████████████████████████████████████████████▍   | 11869/12584 [05:12<00:18, 38.69it/s]2024-01-23 21:09:55,379 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: [', black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black,']
Processing prompts:  97%|████████████████████████████████████████████████████████████▏ | 12226/12584 [05:21<00:09, 37.70it/s]2024-01-23 21:10:04,810 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['she is wearing a red bow on her head. she is wearing a red bow around her neck. she is wearing a white dress. she is wearing a red bow on her head. she is wearing a red']
Processing prompts:  98%|████████████████████████████████████████████████████████████▌ | 12282/12584 [05:22<00:07, 39.08it/s]2024-01-23 21:10:06,301 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['wearing a pink heart shaped hat. she is wearing a pink heart shaped hat. she is wearing a pink heart shaped hat. she is']
Processing prompts:  99%|█████████████████████████████████████████████████████████████▎| 12442/12584 [05:27<00:03, 38.80it/s]2024-01-23 21:10:10,346 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['is wearing a black and purple hooded sweatshirt. she is wearing a black and']
Waiting for batch write thread to finish, 0 items left in queue.                                                             
2024-01-23 21:10:14,233 [INFO] (DataBackendFactory) (id=<ID>) Completed processing 12584 captions.
2024-01-23 21:10:14,233 [INFO] (DataBackendFactory) (id=<ID>) Pre-computing VAE latent space.
2024-01-23 21:10:15,867 [INFO] (DataBackendFactory) Skipping error scan for dataset <ID>. Set 'scan_for_errors' to True in the dataset config to enable this if your training runs into mismatched latent dimensions.
2024-01-23 21:10:16,026 [INFO] (validation) Precomputing the negative prompt embed for validations.
Waiting for batch write thread to finish, 0 items left in queue.
2024-01-23 21:10:16,152 [INFO] (__main__) Moving text encoders back to CPU, to save VRAM. Currently, we cannot completely unload the text encoder.
2024-01-23 21:10:17,207 [INFO] (__main__) After nuking text encoders from orbit, we freed 1.53 GB of VRAM. The real memories were the friends we trained a model on along the way.
2024-01-23 21:10:17,207 [INFO] (__main__) Collected the following data backends: ['<ID>']
2024-01-23 21:10:17,208 [INFO] (__main__) Loading cosine learning rate scheduler with 100000 warmup steps
2024-01-23 21:10:17,214 [INFO] (__main__) Learning rate: 8e-07
2024-01-23 21:10:17,214 [INFO] (__main__) Using 8bit AdamW optimizer.
2024-01-23 21:10:17,214 [INFO] (__main__) Optimizer arguments, weight_decay=0.01 eps=1e-08
2024-01-23 21:10:17,223 [INFO] (__main__) Loading our accelerator...
2024-01-23 21:10:17,242 [INFO] (__main__) After removing any undesired samples and updating cache entries, we have settled on 1920 epochs and 521 steps per epoch.
2024-01-23 21:10:17,370 [INFO] (__main__) After the VAE from orbit, we freed 0.0 MB of VRAM.
2024-01-23 21:10:17,390 [INFO] (__main__) ***** Running training *****
2024-01-23 21:10:17,391 [INFO] (__main__)  -> Num batches = 2084
2024-01-23 21:10:17,391 [INFO] (__main__)  -> Num Epochs = 1920
2024-01-23 21:10:17,391 [INFO] (__main__)  -> Current Epoch = 1
2024-01-23 21:10:17,391 [INFO] (__main__)  -> Instantaneous batch size per device = 6
2024-01-23 21:10:17,391 [INFO] (__main__)  -> Gradient Accumulation steps = 4
2024-01-23 21:10:17,391 [INFO] (__main__)    -> Total train batch size (w. parallel, distributed & accumulation) = 24
2024-01-23 21:10:17,391 [INFO] (__main__)  -> Total optimization steps = 1000000
2024-01-23 21:10:17,391 [INFO] (__main__)  -> Total optimization steps remaining = 1000000
Epoch 1/1920 Steps:   0%|                                                                        | 0/1000000 [00:00<?, ?it/s]2024-01-23 21:10:17,443 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt: 
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/d41d8cd98f00b204e9800998ecf8427e-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/d41d8cd98f00b204e9800998ecf8427e-sdxl.pt not found.
Waiting for batch write thread to finish, 1 items left in queue.                                                             
Waiting for batch write thread to finish, 1 items left in queue.                                                             
Waiting for batch write thread to finish, 1 items left in queue.                                                             
Waiting for batch write thread to finish, 1 items left in queue.                                                             
Waiting for batch write thread to finish, 1 items left in queue.                                                             
...

By the way, there are now 12504 .pt files in <DATA DIR>/vaecache and 12396 .pt files in <DATA DIR>/textembed_cache.

bghira commented 9 months ago

After the last commit referencing this PR (feb02ae), it looks like this for me:

2024-01-23 17:22:41,779 [INFO] (ArgsParser) Default VAE Cache location: /models/training/models/cache_vae
2024-01-23 17:22:41,779 [INFO] (ArgsParser) Text Cache location: cache
2024-01-23 17:22:41,780 [INFO] (__main__) Distributed environment: DistributedType.NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: bf16

2024-01-23 17:22:41,780 [INFO] (__main__) Enabling tf32 precision boost for NVIDIA devices due to --allow_tf32.
2024-01-23 17:22:41,780 [INFO] (__main__) Load tokenizers
2024-01-23 17:22:43,131 [INFO] (__main__) Load text encoder 1..
2024-01-23 17:22:43,568 [INFO] (__main__) Load text encoder 2..
2024-01-23 17:22:44,643 [INFO] (__main__) Load VAE..
2024-01-23 17:22:44,862 [INFO] (__main__) Moving models to GPU. Almost there.
2024-01-23 17:22:45,356 [INFO] (__main__) Creating the U-net..
2024-01-23 17:22:46,570 [INFO] (__main__) Moving the U-net to GPU.
2024-01-23 17:22:48,088 [INFO] (__main__) Initialising VAE in bf16 precision, you may specify a different value if preferred: bf16, fp16, fp32, default
2024-01-23 17:22:48,136 [INFO] (__main__) Loaded VAE into VRAM.
2024-01-23 17:22:48,249 [INFO] (DataBackendFactory) Configuring text embed backend: default-text-embeds
2024-01-23 17:22:48,249 [INFO] (TextEmbeddingCache) (id=default-text-embeds) Listing all text embed cache entries
2024-01-23 17:22:48,360 [INFO] (DataBackendFactory) Pre-computing null embedding for caption dropout
2024-01-23 17:22:48,772 [INFO] (DataBackendFactory) Completed loading text embed services.
2024-01-23 17:22:48,772 [INFO] (DataBackendFactory) Configuring data backend: a-piece-of-my-heart
2024-01-23 17:22:48,773 [INFO] (DataBackendFactory) (id=a-piece-of-my-heart) Loading bucket manager.
2024-01-23 17:22:48,779 [INFO] (DataBackendFactory) (id=a-piece-of-my-heart) Refreshing aspect buckets.
2024-01-23 17:22:48,779 [INFO] (BucketManager) Discovering new files...
2024-01-23 17:22:48,971 [INFO] (BucketManager) No new files discovered. Doing nothing.
2024-01-23 17:22:48,981 [INFO] (DataBackendFactory) (id=a-piece-of-my-heart) Reloading bucket manager cache.
(Rank: 0)  | Bucket     | Image Count 
------------------------------
(Rank: 0)  | 1.0        | 8565        
Loading captions: 100%|██████████████████| 9997/9997 [00:07<00:00, 1350.86it/s]
2024-01-23 17:22:56,387 [INFO] (DataBackendFactory) (id=a-piece-of-my-heart) Initialise text embed pre-computation using the textfile caption strategy. We have 9997 captions to process.
Waiting for batch write thread to finish, 0 items left in queue.                                                             
2024-01-23 17:27:25,064 [INFO] (DataBackendFactory) (id=a-piece-of-my-heart) Completed processing 9997 captions.
2024-01-23 17:27:25,064 [INFO] (DataBackendFactory) (id=a-piece-of-my-heart) Pre-computing VAE latent space.
2024-01-23 17:27:25,832 [INFO] (DataBackendFactory) Skipping error scan for dataset a-piece-of-my-heart. Set 'scan_for_errors' to True in the dataset config to enable this if your training runs into mismatched latent dimensions.
Processing bucket 1.0:   8%|█████▏                                                        | 675/8115 [02:59<24:54,  4.98it/s]

You can see my dataset also excludes some images, but it's not really clearly reported why. Maybe it's possible to improve that, but it's quite difficult.

However, the write thread isn't getting stuck here, though it's possible my str conversion fix works differently from (or better than) yours. There were a couple of other changes I had to make.

janzd commented 9 months ago

I'm not sure the string conversion is the reason I got stuck, because it converted all the captions without an issue and created embeddings for them; it only got stuck once the actual training process started.

When I manage to stop it with Ctrl+C, this is what the console output looks like.

After all the prompts are processed, the training starts, but it immediately gets stuck in compute_prompt_embeddings() in collate.py, which calls compute_single_embedding(), which in turn calls compute_embeddings_for_sdxl_prompts() from sdxl_embeds.py.

Processing prompts:  94%|██████████████████████████████████████████████████████████▍   | 11866/12584 [04:59<00:17, 39.93it/s]2024-01-24 08:35:43,932 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: [', black, white, pink, black, white, pink, black, white, pink, black, white, pink, black, white, pink, black, white']
Processing prompts:  94%|██████████████████████████████████████████████████████████▍   | 11870/12584 [05:00<00:17, 39.82it/s]2024-01-24 08:35:44,034 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: [', black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black,']
Processing prompts:  97%|████████████████████████████████████████████████████████████▏ | 12226/12584 [05:09<00:09, 39.69it/s]2024-01-24 08:35:52,978 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['she is wearing a red bow on her head. she is wearing a red bow around her neck. she is wearing a white dress. she is wearing a red bow on her head. she is wearing a red']
Processing prompts:  98%|████████████████████████████████████████████████████████████▌ | 12284/12584 [05:10<00:07, 40.02it/s]2024-01-24 08:35:54,444 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['wearing a pink heart shaped hat. she is wearing a pink heart shaped hat. she is wearing a pink heart shaped hat. she is']
Processing prompts:  99%|█████████████████████████████████████████████████████████████▎| 12440/12584 [05:14<00:03, 39.40it/s]2024-01-24 08:35:58,420 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['is wearing a black and purple hooded sweatshirt. she is wearing a black and']
Waiting for batch write thread to finish, 0 items left in queue.                                                             
2024-01-24 08:36:02,147 [INFO] (DataBackendFactory) (id=<ID>) Completed processing 12584 captions.
2024-01-24 08:36:02,147 [INFO] (DataBackendFactory) (id=<ID>) Pre-computing VAE latent space.
2024-01-24 08:36:03,831 [INFO] (DataBackendFactory) Skipping error scan for dataset <ID>. Set 'scan_for_errors' to True in the dataset config to enable this if your training runs into mismatched latent dimensions.
2024-01-24 08:36:03,977 [INFO] (validation) Precomputing the negative prompt embed for validations.
Waiting for batch write thread to finish, 0 items left in queue.
2024-01-24 08:36:04,102 [INFO] (__main__) Moving text encoders back to CPU, to save VRAM. Currently, we cannot completely unload the text encoder.
2024-01-24 08:36:04,537 [INFO] (__main__) After nuking text encoders from orbit, we freed 1.53 GB of VRAM. The real memories were the friends we trained a model on along the way.
2024-01-24 08:36:04,537 [INFO] (__main__) Collected the following data backends: ['<ID>']
2024-01-24 08:36:04,537 [INFO] (__main__) Loading cosine learning rate scheduler with 100000 warmup steps
2024-01-24 08:36:04,544 [INFO] (__main__) Learning rate: 8e-07
2024-01-24 08:36:04,544 [INFO] (__main__) Using 8bit AdamW optimizer.
2024-01-24 08:36:04,544 [INFO] (__main__) Optimizer arguments, weight_decay=0.01 eps=1e-08
2024-01-24 08:36:04,553 [INFO] (__main__) Loading our accelerator...
2024-01-24 08:36:04,570 [INFO] (__main__) After removing any undesired samples and updating cache entries, we have settled on 1920 epochs and 521 steps per epoch.
2024-01-24 08:36:04,705 [INFO] (__main__) After the VAE from orbit, we freed 0.0 MB of VRAM.
2024-01-24 08:36:04,724 [INFO] (__main__) ***** Running training *****
2024-01-24 08:36:04,724 [INFO] (__main__)  -> Num batches = 2084
2024-01-24 08:36:04,725 [INFO] (__main__)  -> Num Epochs = 1920
2024-01-24 08:36:04,725 [INFO] (__main__)  -> Current Epoch = 1
2024-01-24 08:36:04,725 [INFO] (__main__)  -> Instantaneous batch size per device = 6
2024-01-24 08:36:04,725 [INFO] (__main__)  -> Gradient Accumulation steps = 4
2024-01-24 08:36:04,725 [INFO] (__main__)    -> Total train batch size (w. parallel, distributed & accumulation) = 24
2024-01-24 08:36:04,725 [INFO] (__main__)  -> Total optimization steps = 1000000
2024-01-24 08:36:04,725 [INFO] (__main__)  -> Total optimization steps remaining = 1000000
Epoch 1/1920 Steps:   0%|                                                                        | 0/1000000 [00:00<?, ?it/s]2024-01-24 08:36:04,767 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt: 
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/d41d8cd98f00b204e9800998ecf8427e-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/d41d8cd98f00b204e9800998ecf8427e-sdxl.pt not found.
Waiting for batch write thread to finish, 1 items left in queue.
Waiting for batch write thread to finish, 1 items left in queue.
Waiting for batch write thread to finish, 1 items left in queue.
[... the same message repeats for roughly 40 more seconds until I hit Ctrl-C ...]
Epoch 1/1920 Steps:   0%|                                                                        | 0/1000000 [00:43<?, ?it/s]^C
Traceback (most recent call last):
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/subprocess.py", line 1209, in wait
    return self._wait(timeout=timeout)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/subprocess.py", line 1959, in _wait
    (pid, sts) = self._try_wait(0)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/subprocess.py", line 1917, in _try_wait
    (pid, sts) = os.waitpid(self.pid, wait_flags)
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/data/miniconda3/envs/simpletuner/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 45, in main
    args.func(args)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/accelerate/commands/launch.py", line 979, in launch_command
    simple_launcher(args)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/accelerate/commands/launch.py", line 625, in simple_launcher
    process.wait()
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/subprocess.py", line 1222, in wait
    self._wait(timeout=sigint_timeout)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/subprocess.py", line 1953, in _wait
    time.sleep(delay)
KeyboardInterrupt

Traceback (most recent call last):
  File "/data/src/SimpleTuner/helpers/training/collate.py", line 140, in compute_prompt_embeddings
    embeddings = list(
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
    yield _result_or_cancel(fs.pop())
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
    return fut.result(timeout)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/concurrent/futures/_base.py", line 453, in result
    self._condition.wait(timeout)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/threading.py", line 320, in wait
    waiter.acquire()
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/data/src/SimpleTuner/train_sdxl.py", line 1504, in <module>
    main()
  File "/data/src/SimpleTuner/train_sdxl.py", line 1014, in main
    for step, batch in random_dataloader_iterator(train_backends):
  File "/data/src/SimpleTuner/helpers/data_backend/factory.py", line 648, in random_dataloader_iterator
    yield (step, next(chosen_iter))
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 674, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/data/src/SimpleTuner/helpers/data_backend/factory.py", line 417, in <lambda>
    collate_fn=lambda examples: collate_fn(examples),
  File "/data/src/SimpleTuner/helpers/training/collate.py", line 231, in collate_fn
    prompt_embeds_all, add_text_embeds_all = compute_prompt_embeddings(
  File "/data/src/SimpleTuner/helpers/training/collate.py", line 139, in compute_prompt_embeddings
    with ThreadPoolExecutor() as executor:
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/concurrent/futures/_base.py", line 649, in __exit__
    self.shutdown(wait=True)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/concurrent/futures/thread.py", line 235, in shutdown
    t.join()
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/threading.py", line 1096, in join
    self._wait_for_tstate_lock()
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/threading.py", line 1116, in _wait_for_tstate_lock
    if lock.acquire(block, timeout):
KeyboardInterrupt

It works and doesn't get stuck when I set CAPTION_DROPOUT_PROBABILITY to 0, so it seems to be related to the empty caption that it tries to encode an embedding for at the start of training.

I'll try to play with the empty caption encoding a bit.
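
For reference, the hash in that warning (d41d8cd98f00b204e9800998ecf8427e) is just the MD5 of an empty string, which fits the caption-dropout theory: the dropped (empty) caption maps to a cache file that apparently doesn't exist on disk. This is the kind of thing I mean by playing with it; the cache helpers below are only illustrative, not SimpleTuner's actual API:

import hashlib

# The missing cache file is named after the MD5 of an empty string,
# i.e. the key produced for a dropped-out (empty) caption.
assert hashlib.md5(b"").hexdigest() == "d41d8cd98f00b204e9800998ecf8427e"

def precompute_all_embeds(cache, captions):
    # Include the empty prompt in the pre-computation pass so caption
    # dropout never forces an encode during training.
    prompts = set(captions)
    prompts.add("")
    for prompt in prompts:
        if not cache.has(prompt):          # hypothetical cache helpers
            cache.encode_and_store(prompt)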

bghira commented 9 months ago

well, it shouldn't be encoding anything during training at all. one idea i had is that the filename / hash generated at pre-caching time and at training time are different, so it can't locate the already-generated embeds because their lookup name is wrong.

i think i'm going to put some exceptions in that keep it from encoding embeds during training
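
roughly what i have in mind, purely as a sketch (the method and attribute names here are illustrative, not the final code):

# sketch: during training, a cache miss should be a hard error rather than
# a silent attempt to re-encode the prompt on the fly
def retrieve_embed(self, prompt):
    filename = self.hash_prompt(prompt)        # illustrative helper
    if filename in self.cache_entries:
        return self.load(filename)
    if not self.precompute_mode:
        raise RuntimeError(
            f"Text embed for prompt {prompt!r} not found in cache "
            f"(expected {filename}); embeds must be pre-computed before training."
        )
    return self.encode_and_store(prompt)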

bghira commented 9 months ago

i don't know if you already did, but try clearing out the embeds and retrying on the current master branch.

bghira commented 9 months ago

solved in v0.9.0-rc4

janzd commented 9 months ago

It seems to work okay now! Thanks!