bghira / SimpleTuner

A general fine-tuning kit geared toward diffusion models.
GNU Affero General Public License v3.0

Training crashing for images that used to work fine, with error: latent shape mismatch: torch.Size([16, 88, 136]) != torch.Size([16, 88, 144]) #977

Closed: a-l-e-x-d-s-9 closed this issue 3 weeks ago

a-l-e-x-d-s-9 commented 1 month ago

Using the latest version of SimpleTuner, I'm getting the error: "latent shape mismatch: torch.Size([16, 88, 136]) != torch.Size([16, 88, 144])". Examples: error_crash_latent_mismatch_01.txt error_crash_latent_mismatch_02.txt

Both images have been used in past trainings without issues. I have the option "--delete_problematic_images": true in my settings. Tested on two separate servers, each with a clean cache. multidatabackend: s01_multidatabackend.json. Attaching the images that have crashed so far; the zip password is "password": images_crashed.zip

Interestingly, I just had a crash on step 607, with a checkpoint made at step 600. I resumed training from the last checkpoint to see whether it would crash on the same step, but training passed step 607 without crashing.
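
For reference, a quick sketch of what those latent shapes correspond to in pixels, assuming the usual x8 VAE spatial downscale (an assumption on my part; the factor is not stated in the error):

VAE_SCALE = 8  # assumed spatial downscale factor of the VAE

cached = (88, 136)    # latent (H, W) found in the cache
expected = (88, 144)  # latent (H, W) the collate step expected

for name, (h, w) in {"cached": cached, "expected": expected}.items():
    print(f"{name}: {h * VAE_SCALE}x{w * VAE_SCALE} px")
# cached: 704x1088 px
# expected: 704x1152 px -> the widths differ by exactly one 64px bucket step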

a-l-e-x-d-s-9 commented 1 month ago

config.json: config.json

a-l-e-x-d-s-9 commented 1 month ago

I've changed multidatabackend.json to use pixel_area: s01_multidatabackend.json. I also deleted all the old cache, and training re-cached everything. I'm now getting an error on a different image: error_crash_latent_mismatch_03.txt

bghira commented 1 month ago

you should disable disable-bucket-pruning and see

a-l-e-x-d-s-9 commented 1 month ago

I assume that the code generating the cache and the code verifying the cache differ in how they round and calculate sizes, which is causing the mismatch. From what I see in the code related to disable_bucket_pruning, it just stops images from being removed; it doesn't calculate image sizes differently. Maybe not using disable_bucket_pruning can hide the issue, but I don't think it's part of the problem or the solution.
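
To illustrate the kind of rounding difference meant here (a hypothetical sketch, not the actual SimpleTuner code): if one code path floors the target size to the bucket step while another rounds to the nearest step, the same source image can land on two different latent widths:

def bucket_floor(px: float, step: int = 64) -> int:
    return int(px // step) * step          # always rounds down to the bucket step

def bucket_nearest(px: float, step: int = 64) -> int:
    return int(round(px / step)) * step    # rounds to the nearest bucket step

target_width = 1120.0                      # e.g. a width from an aspect-preserving resize
print(bucket_floor(target_width) // 8)     # 136 -> matches the cached latent width above
print(bucket_nearest(target_width) // 8)   # 144 -> matches the expected latent width above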

bghira commented 1 month ago

if it goes away then i'll know where to look, otherwise i have to assume this is a problem with just your setup and can't do anything about it. it's up to you how to proceed

bghira commented 1 month ago

there's no difference in rounding between cache generation and cache loading. the actual size of the cache element is checked against other cache entries in the batch.
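
As a simplified stand-in for that batch-level check (based only on the description above; the real check_latent_shapes in helpers/training/collate.py may differ in detail):

import torch

def check_latent_shapes_sketch(latents: list[torch.Tensor], filenames: list[str]) -> torch.Tensor:
    # every latent in the collated batch must share the shape of the first entry
    reference = latents[0].shape
    for latent, fname in zip(latents, filenames):
        if latent.shape != reference:
            raise ValueError(f"File {fname} latent shape mismatch: {latent.shape} != {reference}")
    return torch.stack(latents)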

if you really want to just keep training with all other settings the same, use a batch size of 1 with gradient accumulation steps > 1 to emulate larger batch sizes with the appropriate slowdown?
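
A sketch of that workaround as a config edit; the key names are assumptions based on the "--option": value style of config.json used in this thread, not verified flag names:

import json

with open("config.json") as f:
    config = json.load(f)

config["--train_batch_size"] = 1              # micro-batch of 1 avoids multi-image collation
config["--gradient_accumulation_steps"] = 4   # effective batch size = 1 * 4

with open("config.json", "w") as f:
    json.dump(config, f, indent=2)

With a micro-batch of 1 there is only one latent per collate call, so the shape comparison has nothing to disagree with.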

a-l-e-x-d-s-9 commented 1 month ago

Ok, I'm testing now with disable_bucket_pruning.

a-l-e-x-d-s-9 commented 1 month ago

I tested training with disable_bucket_pruning=false, and I used repeats=3 on the dataset. Training crashed after 227 steps with this error:

 Epoch 1/1, Steps:   1%|▏                                | 227/30000 [45:30<99:29:22, 12.03s/it, lr=1.67e-5, step_loss=0.0805]
 (id=all_dataset_512) File /workspace/input/dataset/npa/66546318_003_043d.jpg latent shape mismatch: torch.Size([16, 56, 80]) != torch.Size([16, 48, 80])
 Traceback (most recent call last):
   File "/workspace/SimpleTuner/train.py", line 49, in <module>
     trainer.train()
   File "/workspace/SimpleTuner/helpers/training/trainer.py", line 1612, in train
     batch = iterator_fn(step, *iterator_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 1269, in random_dataloader_iterator
     return next(chosen_iter)
            ^^^^^^^^^^^^^^^^^
   File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
     data = self._next_data()
            ^^^^^^^^^^^^^^^^^
   File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 673, in _next_data
     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
     return self.collate_fn(data)
            ^^^^^^^^^^^^^^^^^^^^^
   File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 882, in <lambda>
     collate_fn=lambda examples: collate_fn(examples),
                                 ^^^^^^^^^^^^^^^^^^^^
   File "/workspace/SimpleTuner/helpers/training/collate.py", line 413, in collate_fn
     latent_batch = check_latent_shapes(
                    ^^^^^^^^^^^^^^^^^^^^
   File "/workspace/SimpleTuner/helpers/training/collate.py", line 356, in check_latent_shapes
     raise ValueError(
 ValueError: (id=all_dataset_512) File /workspace/input/dataset/npa/66546318_003_043d.jpg latent shape mismatch: torch.Size([16, 56, 80]) != torch.Size([16, 48, 80])

The exact settings to reproduce this issue: tr_01. Dataset to reproduce the issue.

a-l-e-x-d-s-9 commented 1 month ago

The attached dataset and settings reproduce the issue very reliably within a few hundred steps. I don't think any "1 / 0 magic" is needed.

bghira commented 1 month ago

i still can't reproduce it at all on mac or linux, hence not having a solution yet. i set up your dataset in combination with 12 other datasets containing roughly 3800 to 13000 images in each, plus one with 576,000 images in it and there's no problems locally.

you will probably have to enable SIMPLETUNER_LOG_LEVEL=DEBUG in config.env, reproduce it so the details land in debug.log, and review the contents to determine why the sizes are going AWOL.

a-l-e-x-d-s-9 commented 1 month ago

Using additional datasets might hide the issue, because you will have more buckets with more images, and a smaller chance of hitting the buckets from my dataset that cause the crash. The dataset I provided always crashes - I have tested it on a lot of different servers, multiple times. I think that if you use my settings for multidatabackend + config + dataset - without extras - you will be able to reproduce the problem after a few hundred steps. I can run it again with debug enabled, but it might also be timing-related - I had another issue with crashes that simply disappeared when I enabled debug logging. I'm also not sure how helpful the debug log will actually be for you, but I will do it if it can help. I've wanted to train with this dataset for a while, and it crashes every time, with a different image each time - images that used to work with slightly different multidatabackend configurations and fewer images. And I don't have any idea for a workaround.

bghira commented 1 month ago

it is just a dataset of images. there are not two images with the same name and different sizes. having the one dataset didn't cause the error either. i do in fact train quite frequently with just one dataset, and your configuration here has two datasets.

a-l-e-x-d-s-9 commented 1 month ago

My configuration has two resolutions training on the same dataset, which I attached in the dataset file. Is there something wrong with my settings that doesn't work for you?

bghira commented 1 month ago

no, just pointing out that when you mention using additional datasets would somehow hide the issue, your config has the two already. not sure what you meant by hiding the issue with more buckets - a dataset is independent from the other datasets. i didn't run the accelerator, so i am able to get through 6000 steps quite quickly. it just validates shapes, and does not train. everything else is the same.
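
A rough sketch of that kind of shapes-only dry run (hypothetical; not a mode SimpleTuner is documented to expose in exactly this way): iterate the dataloaders so the collate-time shape check fires, without ever building the model or stepping the optimizer:

def validate_shapes_only(dataloaders, max_steps: int = 6000) -> None:
    step = 0
    for loader in dataloaders:
        for _batch in loader:   # collate_fn -> the latent shape check runs while fetching
            step += 1
            if step >= max_steps:
                return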

a-l-e-x-d-s-9 commented 1 month ago

You said "dataset in combination with 12 other datasets", so I assumed you had literally combined my dataset with 12 other datasets - which would lower the chance of my dataset showing the error. How can you turn off the accelerator and just test the rest? Is it possible it's related to cropping? I have it off, but maybe...

a-l-e-x-d-s-9 commented 1 month ago

I started a new training run with a smaller dataset and got this error:

Traceback (most recent call last):
  File "/workspace/SimpleTuner/train.py", line 49, in <module>
    trainer.train()
  File "/workspace/SimpleTuner/helpers/training/trainer.py", line 1863, in train
    batch = iterator_fn(step, *iterator_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 1279, in random_dataloader_iterator
    return next(chosen_iter)
           ^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 701, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 757, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 888, in <lambda>
    collate_fn=lambda examples: collate_fn(examples),
                                ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/training/collate.py", line 413, in collate_fn
    latent_batch = check_latent_shapes(
                   ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/training/collate.py", line 356, in check_latent_shapes
    raise ValueError(
ValueError: (id=all_dataset_1024) File /workspace/input/s01_sandwich_master/dataset/face/479ea11b-721e-4c8a-aad1-33ad3f56e1d2.jpeg latent shape mismatch: torch.Size([16, 168, 96]) != torch.Size([16, 144, 112])

Debug log: debug.log Settings: s01_config_01.json s01_multidatabackend.json
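
One detail worth noting about this particular error (an observation, not something stated in the thread): the two mismatched shapes, 168x96 and 144x112, cover exactly the same latent area, which is what you would expect if pixel_area bucketing assigned the image to two different aspect buckets of the same area at cache time versus load time:

print(168 * 96, 144 * 112)                     # 16128 16128 -> same area, different aspect
print((168 * 8, 96 * 8), (144 * 8, 112 * 8))   # (1344, 768) vs (1152, 896) in pixels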

a-l-e-x-d-s-9 commented 1 month ago

Here is a debug log - I deleted the cache first: debug2.log. It crashed after 10 steps this time.

a-l-e-x-d-s-9 commented 1 month ago

Error:

/dataset/s19/72a10a68-c6ad-45d6-a5d6-4ba1ae60bbb7.jpeg latent shape mismatch: torch.Size([16, 168, 96]) != torch.Size([16, 136, 120])
Traceback (most recent call last):
  File "/workspace/SimpleTuner/train.py", line 49, in <module>
    trainer.train()
  File "/workspace/SimpleTuner/helpers/training/trainer.py", line 1863, in train
    batch = iterator_fn(step, *iterator_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 1279, in random_dataloader_iterator
    return next(chosen_iter)
           ^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 701, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 757, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 888, in <lambda>
    collate_fn=lambda examples: collate_fn(examples),
                                ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/training/collate.py", line 413, in collate_fn
    latent_batch = check_latent_shapes(
                   ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/training/collate.py", line 356, in check_latent_shapes
    raise ValueError(
ValueError: (id=all_dataset_1024) File /workspace/input/s01_sandwich_master/dataset/s19/72a10a68-c6ad-45d6-a5d6-4ba1ae60bbb7.jpeg latent shape mismatch: torch.Size([16, 168, 96]) != torch.Size([16, 136, 120])

a-l-e-x-d-s-9 commented 1 month ago

A new run with "--aspect_bucket_rounding": 2. Error:

Epoch 1/13, Steps:   0%|                                    | 3/3500 [01:09<19:37:19, 20.20s/it, lr=0.00496, step_loss=0.488](id=all_dataset_768) File /workspace/input/s01_sandwich_master/dataset/s09/6e9279eb-dc17-4a4d-b20f-694479a5445e.jpeg latent shape mismatch: torch.Size([16, 128, 72]) != torch.Size([16, 104, 88])
Traceback (most recent call last):
  File "/workspace/SimpleTuner/train.py", line 49, in <module>
    trainer.train()
  File "/workspace/SimpleTuner/helpers/training/trainer.py", line 1863, in train
    batch = iterator_fn(step, *iterator_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 1279, in random_dataloader_iterator
    return next(chosen_iter)
           ^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 701, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 757, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 888, in <lambda>
    collate_fn=lambda examples: collate_fn(examples),
                                ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/training/collate.py", line 413, in collate_fn
    latent_batch = check_latent_shapes(
                   ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/training/collate.py", line 356, in check_latent_shapes
    raise ValueError(
ValueError: (id=all_dataset_768) File /workspace/input/s01_sandwich_master/dataset/s09/6e9279eb-dc17-4a4d-b20f-694479a5445e.jpeg latent shape mismatch: torch.Size([16, 128, 72]) != torch.Size([16, 104, 88])

Log: debug3.log

a-l-e-x-d-s-9 commented 1 month ago

I uploaded the dataset.

a-l-e-x-d-s-9 commented 1 month ago

I added "--debug_aspect_buckets": true. Error:

Epoch 1/13, Steps:   0%|                                                    | 0/3500 [00:04<?, ?it/s, lr=0, step_loss=0.0107](id=all_dataset_1024) File /workspace/input/s01_sandwich_master/dataset/s09/a33c672e-0a27-4528-a12b-9173f77d55ec.jpeg latent shape mismatch: torch.Size([16, 168, 96]) != torch.Size([16, 136, 120])
Traceback (most recent call last):
  File "/workspace/SimpleTuner/train.py", line 49, in <module>
    trainer.train()
  File "/workspace/SimpleTuner/helpers/training/trainer.py", line 1863, in train
    batch = iterator_fn(step, *iterator_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 1279, in random_dataloader_iterator
    return next(chosen_iter)
           ^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 701, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 757, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 888, in <lambda>
    collate_fn=lambda examples: collate_fn(examples),
                                ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/training/collate.py", line 413, in collate_fn
    latent_batch = check_latent_shapes(
                   ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/training/collate.py", line 356, in check_latent_shapes
    raise ValueError(
ValueError: (id=all_dataset_1024) File /workspace/input/s01_sandwich_master/dataset/s09/a33c672e-0a27-4528-a12b-9173f77d55ec.jpeg latent shape mismatch: torch.Size([16, 168, 96]) != torch.Size([16, 136, 120])

Log: debug4.log

a-l-e-x-d-s-9 commented 1 month ago

My settings: s01_config_01.json s01_multidatabackend.json

a-l-e-x-d-s-9 commented 1 month ago

I tried to run with crop enabled, and it crashed anyway with this error:

sandwiches/French Dipped Sandwiches.jpeg latent shape mismatch: torch.Size([16, 128, 96]) != torch.Size([16, 152, 80])
Traceback (most recent call last):
  File "/workspace/SimpleTuner/train.py", line 49, in <module>
    trainer.train()
  File "/workspace/SimpleTuner/helpers/training/trainer.py", line 1885, in train
    batch = iterator_fn(step, *iterator_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 1289, in random_dataloader_iterator
    return next(chosen_iter)
           ^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 701, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 757, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 898, in <lambda>
    collate_fn=lambda examples: collate_fn(examples),
                                ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/training/collate.py", line 413, in collate_fn
    latent_batch = check_latent_shapes(
                   ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/training/collate.py", line 356, in check_latent_shapes
    raise ValueError(
ValueError: (id=all_dataset_896) File /workspace/input/s01_sandwich_master/dataset/sandwiches/French Dipped Sandwiches.jpeg latent shape mismatch: torch.Size([16, 128, 96]) != torch.Size([16, 152, 80])

Here is the s01_multidatabackend.json I used. Unfortunately, enabling crop does not help with this issue. The problem reproduces very consistently within the first few steps of training with a small dataset.

a-l-e-x-d-s-9 commented 3 weeks ago

I updated and simplified the settings, converted to SDXL, and tried again with a reduced dataset of only 41 images. Here is the dataset (v4): dataset_v4.zip

I'm attaching all the files I used for training, with all the settings, including a full debug log for each run. Interestingly, I ran it 3 times with the same settings, and it crashed every single time, but on a different step and with a different file each time: steps 6, 7, and 14.

Here is the crash log 1 - all_settings_1.zip

Epoch 1/182, Steps:   0%|                                   | 6/4000 [00:23<4:13:02,  3.80s/it, lr=0.00482, step_loss=0.0243](id=all_dataset_512) File /home/alexds9/Documents/stable_diffusion/SimpleTunerLatentMismatch/Latent_mismatch_crash_02/dataset/full_body/s10_f2922403-2628-46ee-a993-25affcebe234.jpeg latent shape mismatch: torch.Size([4, 88, 48]) != torch.Size([4, 72, 56])
Traceback (most recent call last):
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/train.py", line 49, in <module>
    trainer.train()
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/training/trainer.py", line 2134, in train
    batch = iterator_fn(step, *iterator_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/data_backend/factory.py", line 1307, in random_dataloader_iterator
    return next(chosen_iter)
           ^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 673, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/data_backend/factory.py", line 908, in <lambda>
    collate_fn=lambda examples: collate_fn(examples),
                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/training/collate.py", line 471, in collate_fn
    latent_batch = check_latent_shapes(
                   ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/training/collate.py", line 413, in check_latent_shapes
    raise ValueError(
ValueError: (id=all_dataset_512) File /home/alexds9/Documents/stable_diffusion/SimpleTunerLatentMismatch/Latent_mismatch_crash_02/dataset/full_body/s10_f2922403-2628-46ee-a993-25affcebe234.jpeg latent shape mismatch: torch.Size([4, 88, 48]) != torch.Size([4, 72, 56])

Epoch 1/182, Steps:   0%|                                   | 6/4000 [00:23<4:22:54,  3.95s/it, lr=0.00482, step_loss=0.0243]

Here is the crash log 2 - all_settings_2.zip

Epoch 1/182, Steps:   0%|                                   | 7/4000 [00:26<4:05:07,  3.68s/it, lr=0.00476, step_loss=0.0177](id=all_dataset_512) File /home/alexds9/Documents/stable_diffusion/SimpleTunerLatentMismatch/Latent_mismatch_crash_02/dataset/s17/33caf417-bf48-4cb5-84fc-ec0c7df9401a.jpeg latent shape mismatch: torch.Size([4, 72, 56]) != torch.Size([4, 88, 48])
Traceback (most recent call last):
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/train.py", line 49, in <module>
    trainer.train()
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/training/trainer.py", line 2134, in train
    batch = iterator_fn(step, *iterator_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/data_backend/factory.py", line 1307, in random_dataloader_iterator
    return next(chosen_iter)
           ^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 673, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/data_backend/factory.py", line 908, in <lambda>
    collate_fn=lambda examples: collate_fn(examples),
                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/training/collate.py", line 471, in collate_fn
    latent_batch = check_latent_shapes(
                   ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/training/collate.py", line 413, in check_latent_shapes
    raise ValueError(
ValueError: (id=all_dataset_512) File /home/alexds9/Documents/stable_diffusion/SimpleTunerLatentMismatch/Latent_mismatch_crash_02/dataset/s17/33caf417-bf48-4cb5-84fc-ec0c7df9401a.jpeg latent shape mismatch: torch.Size([4, 72, 56]) != torch.Size([4, 88, 48])

Epoch 1/182, Steps:   0%|                                   | 7/4000 [00:26<4:16:20,  3.85s/it, lr=0.00476, step_loss=0.0177]

Here is the crash log 3 - all_settings_3.zip

Epoch 1/182, Steps:   0%|                                  | 14/4000 [00:53<4:10:38,  3.77s/it, lr=0.00409, step_loss=0.0438](id=all_dataset_768) File /home/alexds9/Documents/stable_diffusion/SimpleTunerLatentMismatch/Latent_mismatch_crash_02/dataset/full_body/s01_89659378-050c-459a-86f5-0bc5d2abb2d5.jpeg latent shape mismatch: torch.Size([4, 128, 72]) != torch.Size([4, 104, 88])
Traceback (most recent call last):
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/train.py", line 49, in <module>
    trainer.train()
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/training/trainer.py", line 2134, in train
    batch = iterator_fn(step, *iterator_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/data_backend/factory.py", line 1307, in random_dataloader_iterator
    return next(chosen_iter)
           ^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 673, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/data_backend/factory.py", line 908, in <lambda>
    collate_fn=lambda examples: collate_fn(examples),
                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/training/collate.py", line 471, in collate_fn
    latent_batch = check_latent_shapes(
                   ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/training/collate.py", line 413, in check_latent_shapes
    raise ValueError(
ValueError: (id=all_dataset_768) File /home/alexds9/Documents/stable_diffusion/SimpleTunerLatentMismatch/Latent_mismatch_crash_02/dataset/full_body/s01_89659378-050c-459a-86f5-0bc5d2abb2d5.jpeg latent shape mismatch: torch.Size([4, 128, 72]) != torch.Size([4, 104, 88])

Epoch 1/182, Steps:   0%|                                  | 14/4000 [00:53<4:12:18,  3.80s/it, lr=0.00409, step_loss=0.0438]

bghira commented 3 weeks ago

ok. i am in the mood for some pain i guess after dinner. i will proverbially dig in after i literally dig in

a-l-e-x-d-s-9 commented 3 weeks ago

Thank you! I think it should be easy peasy for you to reproduce 😊 - with the small dataset and all the settings, you just need to change the paths in the files and it should work.

bghira commented 3 weeks ago

fixed by #1076 locally here

a-l-e-x-d-s-9 commented 3 weeks ago

I tested the PR with the fix on a small and a medium dataset; both finished the first epoch without crashing, so I think the fix is working.