Closed: janzd closed this issue 9 months ago.
The log line for the VAE cache dir shows the default dir location, which is admittedly confusing. The `vae_cache_prefix` in the data backend config is the preferred value to use, so if the two point at different locations, it will create the dir for the global option but store the files in the backend-specific dir.
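For illustration, a backend entry might set the cache location explicitly (the key name comes from this thread; the exact schema may vary between SimpleTuner versions, so treat this as a sketch):

```python
# Hypothetical data backend entry that pins the VAE cache location explicitly,
# so the global default dir and the backend-specific dir can't diverge.
# "vae_cache_prefix" is the key name used in this discussion; verify it against
# your SimpleTuner version's multi-databackend config documentation.
backend_entry = {
    "id": "my-dataset",                               # assumed dataset id
    "vae_cache_prefix": "/data/my-dataset/vaecache",  # backend-specific cache dir
}
print(backend_entry["vae_cache_prefix"])
```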
The bytes error, I believe, is related to issue #284, where `get_all_captions` didn't find these during pre-processing. You can see it took half a second to supposedly process 12k captions; that's incorrect, it should be more like 10-30 minutes.
Test the main branch and let me know how that goes.
For RESUME_CHECKPOINT, set it to `latest`; it will start a new training run anyway if none exist.
For completeness, the speed issue is because the text encoder is moved to the CPU before training begins, after all of the embeds are cached. But these embeds aren't found, because `get_all_captions` didn't find them, and so it begins caching the embeds during training with the CPU, because magic_prompt does have them.
@janzd is the byte fix still required?
Yeah, it still gives me an error without the type conversion.
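A minimal sketch of the kind of type conversion being discussed (a stand-in, not the actual committed fix): captions coming back from the cache as `bytes` are decoded before they reach the tokenizer, which only accepts `str`, `List[str]`, or `List[List[str]]`.

```python
# Sketch of the "byte fix": normalize cached captions to str before tokenizing.
# This mirrors the workaround described above, not SimpleTuner's real patch.
def normalize_caption(caption):
    if isinstance(caption, bytes):
        return caption.decode("utf-8")
    if isinstance(caption, (list, tuple)):
        # handle batched captions recursively
        return [normalize_caption(c) for c in caption]
    return caption

print(normalize_caption(b"The character is a young boy."))
```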
2024-01-23 21:20:01,648 [INFO] (__main__) Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda
Mixed precision type: bf16
2024-01-23 21:20:01,649 [INFO] (__main__) Enabling tf32 precision boost for NVIDIA devices due to --allow_tf32.
2024-01-23 21:20:01,650 [INFO] (__main__) Load tokenizers
2024-01-23 21:20:03,275 [INFO] (__main__) Load text encoder 1..
2024-01-23 21:20:04,500 [INFO] (__main__) Load text encoder 2..
2024-01-23 21:20:09,375 [INFO] (__main__) Load VAE..
2024-01-23 21:20:09,659 [INFO] (__main__) Moving models to GPU. Almost there.
2024-01-23 21:20:10,431 [INFO] (__main__) Creating the U-net..
2024-01-23 21:20:11,623 [INFO] (__main__) Moving the U-net to GPU.
2024-01-23 21:20:14,389 [INFO] (__main__) Enabling xformers memory-efficient attention.
2024-01-23 21:20:14,663 [INFO] (__main__) Initialising VAE in bf16 precision, you may specify a different value if preferred: bf16, fp16, fp32, default
2024-01-23 21:20:14,766 [INFO] (__main__) Loaded VAE into VRAM.
2024-01-23 21:20:14,792 [INFO] (DataBackendFactory) Configuring text embed backend: <TEXT EMBEDS>
2024-01-23 21:20:14,794 [INFO] (TextEmbeddingCache) (id=<TEXT EMBEDS>) Listing all text embed cache entries
2024-01-23 21:20:14,795 [INFO] (DataBackendFactory) Pre-computing null embedding for caption dropout
2024-01-23 21:20:14,797 [INFO] (DataBackendFactory) Completed loading text embed services.
2024-01-23 21:20:14,797 [INFO] (DataBackendFactory) Configuring data backend: <ID>
2024-01-23 21:20:14,798 [INFO] (DataBackendFactory) (id=<ID>) Loading bucket manager.
2024-01-23 21:20:14,808 [INFO] (DataBackendFactory) (id=<ID>) Refreshing aspect buckets.
2024-01-23 21:20:26,668 [INFO] (DataBackendFactory) (id=<ID>) Reloading bucket manager cache.
(Rank: 0) | Bucket | Image Count
------------------------------
(Rank: 0) | 1.0 | 12504
2024-01-23 21:20:28,393 [INFO] (DataBackendFactory) (id=<ID>) Initialise text embed pre-computation. We have 12584 captions to process.
Processing prompts: 0%| | 0/12584 [00:00<?, ?it/s]2024-01-23 21:20:28,766 [ERROR] (TextEmbeddingCache) Failed to encode prompt: b'The character is a young boy. He has short hair and is wearing glasses. His eyes are green. He is wearing a white hat, a blue jacket, and a black backpack.'
-> error: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).
-> traceback: Traceback (most recent call last):
File "/data/src/SimpleTuner/helpers/caching/sdxl_embeds.py", line 172, in encode_sdxl_prompt
text_inputs = tokenizer(
File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2798, in __call__
encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2856, in _call_one
raise ValueError(
ValueError: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).
2024-01-23 21:20:28,766 [ERROR] (__main__) text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples)., traceback: Traceback (most recent call last):
File "/data/src/SimpleTuner/train_sdxl.py", line 429, in main
configure_multi_databackend(
File "/data/src/SimpleTuner/helpers/data_backend/factory.py", line 439, in configure_multi_databackend
init_backend["text_embed_cache"].compute_embeddings_for_prompts(
File "/data/src/SimpleTuner/helpers/caching/sdxl_embeds.py", line 294, in compute_embeddings_for_prompts
return self.compute_embeddings_for_sdxl_prompts(
File "/data/src/SimpleTuner/helpers/caching/sdxl_embeds.py", line 358, in compute_embeddings_for_sdxl_prompts
prompt_embeds, pooled_prompt_embeds = self.encode_sdxl_prompts(
File "/data/src/SimpleTuner/helpers/caching/sdxl_embeds.py", line 234, in encode_sdxl_prompts
prompt_embeds, pooled_prompt_embeds = self.encode_sdxl_prompt(
File "/data/src/SimpleTuner/helpers/caching/sdxl_embeds.py", line 217, in encode_sdxl_prompt
raise e
File "/data/src/SimpleTuner/helpers/caching/sdxl_embeds.py", line 172, in encode_sdxl_prompt
text_inputs = tokenizer(
File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2798, in __call__
encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2856, in _call_one
raise ValueError(
ValueError: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).
Also, when I start the training now and it doesn't fail on the type (when I decode it to a string), it seems to get stuck in some write loop. It just keeps printing `Waiting for batch write thread to finish, 1 items left in queue.` and I have to kill the process.
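For context, that message comes from a writer thread draining a queue. A minimal producer/consumer sketch (assuming a standard `queue.Queue` setup, not SimpleTuner's exact implementation) shows why the count can stay stuck at 1 if the worker dies or an item is never marked done:

```python
import queue
import threading

# Minimal sketch of a batch write thread. If the worker raises (or never calls
# task_done()), qsize() stays at 1 and a loop polling it prints
# "1 items left in queue" forever -- matching the stuck output above.
write_queue = queue.Queue()

def writer():
    while True:
        item = write_queue.get()
        if item is None:              # sentinel tells the thread to exit
            write_queue.task_done()
            return
        # ... write the embed to disk here ...
        write_queue.task_done()

thread = threading.Thread(target=writer, daemon=True)
thread.start()
write_queue.put(b"embedding-bytes")
write_queue.put(None)
write_queue.join()                    # blocks until every item is task_done()
print(write_queue.qsize())            # 0 once the queue has drained
```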
...
Processing prompts: 85%|████████████████████████████████████████████████████▌ | 10679/12584 [04:40<00:49, 38.67it/s]2024-01-23 21:09:24,247 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: [', white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white,']
Processing prompts: 86%|█████████████████████████████████████████████████████▍ | 10839/12584 [04:45<00:45, 38.68it/s]2024-01-23 21:09:28,375 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['white star, white bow, white star, white bow, white star, white bow, white star, white bow, white star, white bow, white star, white bow, white star, white bow, white star, white bow, white star']
Processing prompts: 87%|██████████████████████████████████████████████████████ | 10979/12584 [04:48<00:41, 38.24it/s]2024-01-23 21:09:32,116 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: [', white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white,']
Processing prompts: 88%|██████████████████████████████████████████████████████▍ | 11047/12584 [04:50<00:40, 37.95it/s]2024-01-23 21:09:33,897 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['a brown jacket and a brown hat. she is wearing a brown jacket and a brown hat. she is wearing a brown jacket and a brown hat. she is wearing a brown jacket']
Processing prompts: 89%|███████████████████████████████████████████████████████▏ | 11211/12584 [04:54<00:35, 38.87it/s]2024-01-23 21:09:38,154 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['white jacket. she is wearing a pink and white jacket. she is wearing a pink and white jacket. she is we']
Processing prompts: 92%|████████████████████████████████████████████████████████▊ | 11539/12584 [05:03<00:26, 39.17it/s]2024-01-23 21:09:46,711 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['a white shirt and a blue shirt. she is wearing a white shirt and a blue shirt. she is wearing a white shirt and a blue sh']
Processing prompts: 93%|█████████████████████████████████████████████████████████▊ | 11727/12584 [05:08<00:23, 35.78it/s]2024-01-23 21:09:51,595 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: [', white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white,']
Processing prompts: 94%|██████████████████████████████████████████████████████████▍ | 11865/12584 [05:11<00:18, 39.18it/s]2024-01-23 21:09:55,274 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: [', black, white, pink, black, white, pink, black, white, pink, black, white, pink, black, white, pink, black, white']
Processing prompts: 94%|██████████████████████████████████████████████████████████▍ | 11869/12584 [05:12<00:18, 38.69it/s]2024-01-23 21:09:55,379 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: [', black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black,']
Processing prompts: 97%|████████████████████████████████████████████████████████████▏ | 12226/12584 [05:21<00:09, 37.70it/s]2024-01-23 21:10:04,810 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['she is wearing a red bow on her head. she is wearing a red bow around her neck. she is wearing a white dress. she is wearing a red bow on her head. she is wearing a red']
Processing prompts: 98%|████████████████████████████████████████████████████████████▌ | 12282/12584 [05:22<00:07, 39.08it/s]2024-01-23 21:10:06,301 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['wearing a pink heart shaped hat. she is wearing a pink heart shaped hat. she is wearing a pink heart shaped hat. she is']
Processing prompts: 99%|█████████████████████████████████████████████████████████████▎| 12442/12584 [05:27<00:03, 38.80it/s]2024-01-23 21:10:10,346 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['is wearing a black and purple hooded sweatshirt. she is wearing a black and']
Waiting for batch write thread to finish, 0 items left in queue.
2024-01-23 21:10:14,233 [INFO] (DataBackendFactory) (id=<ID>) Completed processing 12584 captions.
2024-01-23 21:10:14,233 [INFO] (DataBackendFactory) (id=<ID>) Pre-computing VAE latent space.
2024-01-23 21:10:15,867 [INFO] (DataBackendFactory) Skipping error scan for dataset <ID>. Set 'scan_for_errors' to True in the dataset config to enable this if your training runs into mismatched latent dimensions.
2024-01-23 21:10:16,026 [INFO] (validation) Precomputing the negative prompt embed for validations.
Waiting for batch write thread to finish, 0 items left in queue.
2024-01-23 21:10:16,152 [INFO] (__main__) Moving text encoders back to CPU, to save VRAM. Currently, we cannot completely unload the text encoder.
2024-01-23 21:10:17,207 [INFO] (__main__) After nuking text encoders from orbit, we freed 1.53 GB of VRAM. The real memories were the friends we trained a model on along the way.
2024-01-23 21:10:17,207 [INFO] (__main__) Collected the following data backends: ['<ID>']
2024-01-23 21:10:17,208 [INFO] (__main__) Loading cosine learning rate scheduler with 100000 warmup steps
2024-01-23 21:10:17,214 [INFO] (__main__) Learning rate: 8e-07
2024-01-23 21:10:17,214 [INFO] (__main__) Using 8bit AdamW optimizer.
2024-01-23 21:10:17,214 [INFO] (__main__) Optimizer arguments, weight_decay=0.01 eps=1e-08
2024-01-23 21:10:17,223 [INFO] (__main__) Loading our accelerator...
2024-01-23 21:10:17,242 [INFO] (__main__) After removing any undesired samples and updating cache entries, we have settled on 1920 epochs and 521 steps per epoch.
2024-01-23 21:10:17,370 [INFO] (__main__) After the VAE from orbit, we freed 0.0 MB of VRAM.
2024-01-23 21:10:17,390 [INFO] (__main__) ***** Running training *****
2024-01-23 21:10:17,391 [INFO] (__main__) -> Num batches = 2084
2024-01-23 21:10:17,391 [INFO] (__main__) -> Num Epochs = 1920
2024-01-23 21:10:17,391 [INFO] (__main__) -> Current Epoch = 1
2024-01-23 21:10:17,391 [INFO] (__main__) -> Instantaneous batch size per device = 6
2024-01-23 21:10:17,391 [INFO] (__main__) -> Gradient Accumulation steps = 4
2024-01-23 21:10:17,391 [INFO] (__main__) -> Total train batch size (w. parallel, distributed & accumulation) = 24
2024-01-23 21:10:17,391 [INFO] (__main__) -> Total optimization steps = 1000000
2024-01-23 21:10:17,391 [INFO] (__main__) -> Total optimization steps remaining = 1000000
Epoch 1/1920 Steps: 0%| | 0/1000000 [00:00<?, ?it/s]2024-01-23 21:10:17,443 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt:
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/d41d8cd98f00b204e9800998ecf8427e-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/d41d8cd98f00b204e9800998ecf8427e-sdxl.pt not found.
Waiting for batch write thread to finish, 1 items left in queue.
[the line above repeats until the process is killed]
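Note the hash in the missing filename above: `d41d8cd98f00b204e9800998ecf8427e` is the MD5 of the empty string, so the failing cache lookup is for an empty prompt (presumably the caption-dropout/null embed). This assumes the cache filename is simply the MD5 of the caption text, which is inferred from the log rather than verified against the code:

```python
import hashlib

# The MD5 of an empty string matches the hash in the missing cache filename,
# suggesting the failed lookup is for an empty prompt.
print(hashlib.md5(b"").hexdigest())  # d41d8cd98f00b204e9800998ecf8427e
```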
By the way, there are now 12504 .pt files in <DATA DIR>/vaecache and 12396 .pt files in <DATA DIR>/textembed_cache.
After the last commit referencing this PR (feb02ae), it looks like this for me:
2024-01-23 17:22:41,779 [INFO] (ArgsParser) Default VAE Cache location: /models/training/models/cache_vae
2024-01-23 17:22:41,779 [INFO] (ArgsParser) Text Cache location: cache
2024-01-23 17:22:41,780 [INFO] (__main__) Distributed environment: DistributedType.NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda
Mixed precision type: bf16
2024-01-23 17:22:41,780 [INFO] (__main__) Enabling tf32 precision boost for NVIDIA devices due to --allow_tf32.
2024-01-23 17:22:41,780 [INFO] (__main__) Load tokenizers
2024-01-23 17:22:43,131 [INFO] (__main__) Load text encoder 1..
2024-01-23 17:22:43,568 [INFO] (__main__) Load text encoder 2..
2024-01-23 17:22:44,643 [INFO] (__main__) Load VAE..
2024-01-23 17:22:44,862 [INFO] (__main__) Moving models to GPU. Almost there.
2024-01-23 17:22:45,356 [INFO] (__main__) Creating the U-net..
2024-01-23 17:22:46,570 [INFO] (__main__) Moving the U-net to GPU.
2024-01-23 17:22:48,088 [INFO] (__main__) Initialising VAE in bf16 precision, you may specify a different value if preferred: bf16, fp16, fp32, default
2024-01-23 17:22:48,136 [INFO] (__main__) Loaded VAE into VRAM.
2024-01-23 17:22:48,249 [INFO] (DataBackendFactory) Configuring text embed backend: default-text-embeds
2024-01-23 17:22:48,249 [INFO] (TextEmbeddingCache) (id=default-text-embeds) Listing all text embed cache entries
2024-01-23 17:22:48,360 [INFO] (DataBackendFactory) Pre-computing null embedding for caption dropout
2024-01-23 17:22:48,772 [INFO] (DataBackendFactory) Completed loading text embed services.
2024-01-23 17:22:48,772 [INFO] (DataBackendFactory) Configuring data backend: a-piece-of-my-heart
2024-01-23 17:22:48,773 [INFO] (DataBackendFactory) (id=a-piece-of-my-heart) Loading bucket manager.
2024-01-23 17:22:48,779 [INFO] (DataBackendFactory) (id=a-piece-of-my-heart) Refreshing aspect buckets.
2024-01-23 17:22:48,779 [INFO] (BucketManager) Discovering new files...
2024-01-23 17:22:48,971 [INFO] (BucketManager) No new files discovered. Doing nothing.
2024-01-23 17:22:48,981 [INFO] (DataBackendFactory) (id=a-piece-of-my-heart) Reloading bucket manager cache.
(Rank: 0) | Bucket | Image Count
------------------------------
(Rank: 0) | 1.0 | 8565
Loading captions: 100%|██████████████████| 9997/9997 [00:07<00:00, 1350.86it/s]
2024-01-23 17:22:56,387 [INFO] (DataBackendFactory) (id=a-piece-of-my-heart) Initialise text embed pre-computation using the textfile caption strategy. We have 9997 captions to process.
Waiting for batch write thread to finish, 0 items left in queue.
2024-01-23 17:27:25,064 [INFO] (DataBackendFactory) (id=a-piece-of-my-heart) Completed processing 9997 captions.
2024-01-23 17:27:25,064 [INFO] (DataBackendFactory) (id=a-piece-of-my-heart) Pre-computing VAE latent space.
2024-01-23 17:27:25,832 [INFO] (DataBackendFactory) Skipping error scan for dataset a-piece-of-my-heart. Set 'scan_for_errors' to True in the dataset config to enable this if your training runs into mismatched latent dimensions.
Processing bucket 1.0: 8%|█████▏ | 675/8115 [02:59<24:54, 4.98it/s]
You can see my dataset also excludes some images, but it's not clearly reported why. Maybe that could be improved, but it's quite difficult.
However, the write thread isn't getting stuck here; it's possible my str conversion fix works differently/better than yours. There were a couple of other changes I had to make.
I'm not sure the reason it got stuck was the string conversion, because it converted all the captions without an issue and created embeddings for them; it got stuck when the actual training process started.
When I manage to stop it with CTRL+C, this is what the console output looks like.
After all the prompts are processed, the training starts, but it immediately gets stuck in `compute_prompt_embeddings()` in `collate.py`, which calls `compute_single_embedding()`, and that calls `compute_embeddings_for_sdxl_prompts()` from `sdxl_embeds.py`.
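The shape of that call chain, with a stand-in cache (the function names come from this thread; the bodies are illustrative only, not SimpleTuner's real code), looks roughly like this:

```python
from concurrent.futures import ThreadPoolExecutor

# Illustrative stand-in for the embedding cache; the real one hits the text
# encoders and the on-disk cache, which is where the hang occurs.
class DummyCache:
    def compute_embeddings_for_sdxl_prompts(self, prompts):
        return [f"embed({p})" for p in prompts]

def compute_single_embedding(prompt, cache):
    return cache.compute_embeddings_for_sdxl_prompts([prompt])[0]

def compute_prompt_embeddings(prompts, cache):
    # A hang inside any worker surfaces as the concurrent.futures /
    # threading wait visible in the CTRL+C traceback.
    with ThreadPoolExecutor(max_workers=4) as pool:
        return list(pool.map(lambda p: compute_single_embedding(p, cache), prompts))

print(compute_prompt_embeddings(["a cat", "a dog"], DummyCache()))
```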
Processing prompts: 94%|██████████████████████████████████████████████████████████▍ | 11866/12584 [04:59<00:17, 39.93it/s]2024-01-24 08:35:43,932 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: [', black, white, pink, black, white, pink, black, white, pink, black, white, pink, black, white, pink, black, white']
Processing prompts: 94%|██████████████████████████████████████████████████████████▍ | 11870/12584 [05:00<00:17, 39.82it/s]2024-01-24 08:35:44,034 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: [', black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black, white, black,']
Processing prompts: 97%|████████████████████████████████████████████████████████████▏ | 12226/12584 [05:09<00:09, 39.69it/s]2024-01-24 08:35:52,978 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['she is wearing a red bow on her head. she is wearing a red bow around her neck. she is wearing a white dress. she is wearing a red bow on her head. she is wearing a red']
Processing prompts: 98%|████████████████████████████████████████████████████████████▌ | 12284/12584 [05:10<00:07, 40.02it/s]2024-01-24 08:35:54,444 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['wearing a pink heart shaped hat. she is wearing a pink heart shaped hat. she is wearing a pink heart shaped hat. she is']
Processing prompts: 99%|█████████████████████████████████████████████████████████████▎| 12440/12584 [05:14<00:03, 39.40it/s]2024-01-24 08:35:58,420 [WARNING] (TextEmbeddingCache) The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['is wearing a black and purple hooded sweatshirt. she is wearing a black and']
Waiting for batch write thread to finish, 0 items left in queue.
2024-01-24 08:36:02,147 [INFO] (DataBackendFactory) (id=<ID>) Completed processing 12584 captions.
2024-01-24 08:36:02,147 [INFO] (DataBackendFactory) (id=<ID>) Pre-computing VAE latent space.
2024-01-24 08:36:03,831 [INFO] (DataBackendFactory) Skipping error scan for dataset <ID>. Set 'scan_for_errors' to True in the dataset config to enable this if your training runs into mismatched latent dimensions.
2024-01-24 08:36:03,977 [INFO] (validation) Precomputing the negative prompt embed for validations.
Waiting for batch write thread to finish, 0 items left in queue.
2024-01-24 08:36:04,102 [INFO] (__main__) Moving text encoders back to CPU, to save VRAM. Currently, we cannot completely unload the text encoder.
2024-01-24 08:36:04,537 [INFO] (__main__) After nuking text encoders from orbit, we freed 1.53 GB of VRAM. The real memories were the friends we trained a model on along the way.
2024-01-24 08:36:04,537 [INFO] (__main__) Collected the following data backends: ['<ID>']
2024-01-24 08:36:04,537 [INFO] (__main__) Loading cosine learning rate scheduler with 100000 warmup steps
2024-01-24 08:36:04,544 [INFO] (__main__) Learning rate: 8e-07
2024-01-24 08:36:04,544 [INFO] (__main__) Using 8bit AdamW optimizer.
2024-01-24 08:36:04,544 [INFO] (__main__) Optimizer arguments, weight_decay=0.01 eps=1e-08
2024-01-24 08:36:04,553 [INFO] (__main__) Loading our accelerator...
2024-01-24 08:36:04,570 [INFO] (__main__) After removing any undesired samples and updating cache entries, we have settled on 1920 epochs and 521 steps per epoch.
2024-01-24 08:36:04,705 [INFO] (__main__) After the VAE from orbit, we freed 0.0 MB of VRAM.
2024-01-24 08:36:04,724 [INFO] (__main__) ***** Running training *****
2024-01-24 08:36:04,724 [INFO] (__main__) -> Num batches = 2084
2024-01-24 08:36:04,725 [INFO] (__main__) -> Num Epochs = 1920
2024-01-24 08:36:04,725 [INFO] (__main__) -> Current Epoch = 1
2024-01-24 08:36:04,725 [INFO] (__main__) -> Instantaneous batch size per device = 6
2024-01-24 08:36:04,725 [INFO] (__main__) -> Gradient Accumulation steps = 4
2024-01-24 08:36:04,725 [INFO] (__main__) -> Total train batch size (w. parallel, distributed & accumulation) = 24
2024-01-24 08:36:04,725 [INFO] (__main__) -> Total optimization steps = 1000000
2024-01-24 08:36:04,725 [INFO] (__main__) -> Total optimization steps remaining = 1000000
Epoch 1/1920 Steps: 0%| | 0/1000000 [00:00<?, ?it/s]2024-01-24 08:36:04,767 [WARNING] (TextEmbeddingCache) Failed retrieving prompt from cache:
-> prompt:
-> filename: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/d41d8cd98f00b204e9800998ecf8427e-sdxl.pt
-> error: /data/src/SimpleTuner/data/<DATA DIR>/textembed_cache/d41d8cd98f00b204e9800998ecf8427e-sdxl.pt not found.
Waiting for batch write thread to finish, 1 items left in queue.
[the line above repeats until I hit CTRL+C]
Epoch 1/1920 Steps:   0%| | 0/1000000 [00:43<?, ?it/s]^CTraceback (most recent call last):
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/subprocess.py", line 1209, in wait
    return self._wait(timeout=timeout)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/subprocess.py", line 1959, in _wait
    (pid, sts) = self._try_wait(0)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/subprocess.py", line 1917, in _try_wait
    (pid, sts) = os.waitpid(self.pid, wait_flags)
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/data/miniconda3/envs/simpletuner/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 45, in main
    args.func(args)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/accelerate/commands/launch.py", line 979, in launch_command
    simple_launcher(args)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/accelerate/commands/launch.py", line 625, in simple_launcher
    process.wait()
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/subprocess.py", line 1222, in wait
    self._wait(timeout=sigint_timeout)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/subprocess.py", line 1953, in _wait
    time.sleep(delay)
KeyboardInterrupt

Traceback (most recent call last):
  File "/data/src/SimpleTuner/helpers/training/collate.py", line 140, in compute_prompt_embeddings
    embeddings = list(
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
    yield _result_or_cancel(fs.pop())
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
    return fut.result(timeout)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/concurrent/futures/_base.py", line 453, in result
    self._condition.wait(timeout)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/threading.py", line 320, in wait
    waiter.acquire()
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/data/src/SimpleTuner/train_sdxl.py", line 1504, in <module>
    main()
  File "/data/src/SimpleTuner/train_sdxl.py", line 1014, in main
    for step, batch in random_dataloader_iterator(train_backends):
  File "/data/src/SimpleTuner/helpers/data_backend/factory.py", line 648, in random_dataloader_iterator
    yield (step, next(chosen_iter))
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 674, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/data/src/SimpleTuner/helpers/data_backend/factory.py", line 417, in <lambda>
    collate_fn=lambda examples: collate_fn(examples),
  File "/data/src/SimpleTuner/helpers/training/collate.py", line 231, in collate_fn
    prompt_embeds_all, add_text_embeds_all = compute_prompt_embeddings(
  File "/data/src/SimpleTuner/helpers/training/collate.py", line 139, in compute_prompt_embeddings
    with ThreadPoolExecutor() as executor:
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/concurrent/futures/_base.py", line 649, in __exit__
    self.shutdown(wait=True)
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/concurrent/futures/thread.py", line 235, in shutdown
    t.join()
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/threading.py", line 1096, in join
    self._wait_for_tstate_lock()
  File "/data/miniconda3/envs/simpletuner/lib/python3.10/threading.py", line 1116, in _wait_for_tstate_lock
    if lock.acquire(block, timeout):
KeyboardInterrupt
It works and doesn't get stuck when I set CAPTION_DROPOUT_PROBABILITY to 0, so it's something to do with the empty caption it tries to encode an embedding for at the start of training.
I'll try to play with the empty caption encoding a bit.
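One way that kind of hang gets avoided, as a minimal sketch: make sure the dropout caption (an empty string) is part of the set handed to the embed pre-caching pass, so nothing is left to encode once training starts. The function name `captions_to_cache` is hypothetical, not SimpleTuner's actual API.

```python
def captions_to_cache(captions, caption_dropout_probability=0.0):
    """Return the full set of prompts the embed cache must contain.

    If caption dropout is enabled, the trainer will occasionally swap a
    sample's caption for the empty string, so "" must be pre-cached too;
    otherwise the text encoder gets invoked mid-training.
    """
    unique = set(captions)
    if caption_dropout_probability > 0.0:
        unique.add("")
    return sorted(unique)
```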
well, it shouldn't be encoding during training at all. an idea i had was that the filename / hash generated at pre-processing time and at training time are different, so it's unable to locate the generated embeds because their lookup name is wrong.
i think i'm going to put some exceptions in that keep it from encoding embeds during training
i don't know if you already did, but, try clearing out the embeds and trying on the current master branch.
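Those "exceptions that keep it from encoding during training" could look something like the sketch below (hypothetical names, not the actual SimpleTuner code): at train time the cache is read-only, and a miss raises instead of silently falling back to the CPU-resident text encoder.

```python
class CacheMissError(RuntimeError):
    """Raised when a text embed is requested at train time but was never pre-cached."""


def fetch_cached_embed(cache: dict, caption: str):
    # During training we must only *read* the cache; encoding here would
    # silently run the text encoder (already moved to CPU) and crawl.
    if caption not in cache:
        raise CacheMissError(
            f"Embed for caption {caption!r} not found in cache; "
            "was it skipped during pre-processing?"
        )
    return cache[caption]
```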
solved in v0.9.0-rc4
It seems to work okay now! Thanks!
Thanks for the recent updates. I found some more issues with types of the caption variable in the code (sometimes it's bytes while the code expects a string).
Besides that, I have some issues with speed.
The caption is in bytes format but L166 in `sdxl_embeds.py` expects a string. I added an if-condition to check the type and convert it to a string if it's bytes. I was able to run the training then, but for some reason it is extremely slow, about 150 seconds per iteration. I'm running it on an A100, so it should definitely be faster.
My training data directory contains 12586 images and 12586 corresponding text files with captions. After running the script, there are 12504 .pt files in `./data/<DATA_DIR>/vaecache` (I suppose some images were skipped for some reason). There are also directories `./ckpt/<PROJECT NAME>/cache` and `./ckpt/<PROJECT NAME>/cache_vae`, but those are empty. The logging output points at the default VAE cache directory, but the cache is stored elsewhere, so I wonder if that's an issue.
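The type-conversion workaround described above amounts to a small guard like this (a sketch; the real fix would live in `sdxl_embeds.py`):

```python
def ensure_str(caption) -> str:
    """Normalise a caption to str; some data backends hand captions back as bytes."""
    if isinstance(caption, bytes):
        # "replace" avoids crashing on the odd mis-encoded caption file
        return caption.decode("utf-8", errors="replace")
    return caption
```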
Just for completion, here's my argument list.
`resume_from_checkpoint` is False so that the condition on L847 in `train_sdxl.py` gets activated and a new training run starts. I comment this out in `train_sdxl.sh` so that I can set `resume_from_checkpoint` to None.
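For reference, the suggested `RESUME_CHECKPOINT=latest` behaviour boils down to something like this sketch (hypothetical helper, not the trainer's actual code): pick the newest `checkpoint-<step>` directory if one exists, otherwise return nothing and start a fresh run.

```python
import os
import re


def resolve_resume_checkpoint(output_dir: str):
    """Return the newest 'checkpoint-<step>' subdirectory name, or None to start fresh."""
    pattern = re.compile(r"^checkpoint-(\d+)$")
    candidates = []
    for name in os.listdir(output_dir):
        m = pattern.match(name)
        if m and os.path.isdir(os.path.join(output_dir, name)):
            candidates.append((int(m.group(1)), name))
    if not candidates:
        return None  # no checkpoints yet -> new training run
    return max(candidates)[1]  # highest step number wins
```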