We believe this is fixed with the latest update (we suspect it was an issue in Torch). Please reopen if, after updating to today's master branch and running update.bat to update all dependencies, it reoccurs.
What happened?
When I fine-tune SDXL, I try to maximise the use of the 24GB of VRAM by finding the largest batch size that fits within the limit. What I've seen happening during training is that VRAM usage will jump from 22GB to 25GB or more, which causes the training speed to drop considerably. I can fix this by clicking "sample now": VRAM usage drops to around 13GB during sampling and returns to around 22GB afterwards. It can also drop on its own after a while if left alone.
I'm using Adafactor, a constant scheduler, text encoder training switched off, EMA set to GPU, AlignProp disabled, and masked training disabled. None of random crop, image variations, text variations, etc. are enabled; only aspect ratio bucketing and latent caching are enabled.
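For anyone reproducing this, it can help to check whether the extra VRAM is live tensors or just blocks held by PyTorch's caching allocator, since a sampling pass freeing memory suggests the latter. The sketch below uses the standard `torch.cuda` memory APIs; the `vram_report` helper itself is only an illustrative assumption, not part of the trainer:

```python
import torch

def vram_report(tag: str) -> dict:
    """Illustrative helper: report CUDA memory use in GiB.

    memory_allocated() counts live tensors; memory_reserved() also
    includes blocks the caching allocator keeps around for reuse.
    A large gap between the two can usually be released with
    empty_cache(), which may be why sampling drops VRAM here.
    """
    if not torch.cuda.is_available():
        # No GPU present: return zeros so the helper is always safe to call.
        return {"tag": tag, "allocated_gib": 0.0, "reserved_gib": 0.0}
    gib = 2 ** 30
    return {
        "tag": tag,
        "allocated_gib": torch.cuda.memory_allocated() / gib,
        "reserved_gib": torch.cuda.memory_reserved() / gib,
    }

# Example: compare before/after releasing the allocator's cached blocks.
before = vram_report("before empty_cache")
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # hands cached blocks back to the driver
after = vram_report("after empty_cache")
```

If `reserved_gib` is far above `allocated_gib` when the slowdown hits, the jump is allocator caching rather than a leak in the training loop.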
What did you expect would happen?
The amount of VRAM usage to stay constant during training.
Relevant log output
No response
Output of pip freeze
absl-py==2.1.0
accelerate==0.25.0
aiohttp==3.9.3
aiosignal==1.3.1
antlr4-python3-runtime==4.9.3
async-timeout==4.0.3
attrs==23.2.0
bitsandbytes==0.41.1
cachetools==5.3.3
certifi==2024.2.2
charset-normalizer==3.3.2
colorama==0.4.6
coloredlogs==15.0.1
customtkinter==5.2.1
dadaptation==3.2
darkdetect==0.8.0
-e git+https://github.com/kashif/diffusers.git@a3dc21385b7386beb3dab3a9845962ede6765887#egg=diffusers
filelock==3.13.1
flatbuffers==24.3.7
frozenlist==1.4.1
fsspec==2024.2.0
ftfy==6.1.3
google-auth==2.28.2
google-auth-oauthlib==1.2.0
grpcio==1.62.1
huggingface-hub==0.20.3
humanfriendly==10.0
idna==3.6
importlib_metadata==7.0.2
invisible-watermark==0.2.0
Jinja2==3.1.3
lightning-utilities==0.10.1
lion-pytorch==0.1.2
Markdown==3.5.2
MarkupSafe==2.1.5
-e git+https://github.com/Nerogar/mgds.git@5213539e33a7a650961e6f4d5faf51ce578af85a#egg=mgds
mpmath==1.3.0
multidict==6.0.5
networkx==3.2.1
numpy==1.26.2
oauthlib==3.2.2
omegaconf==2.3.0
onnxruntime==1.15.1
onnxruntime-gpu==1.16.3
open-clip-torch==2.23.0
opencv-python==4.8.1.78
packaging==23.2
pillow==10.2.0
platformdirs==4.2.0
pooch==1.8.0
prodigyopt==1.0
protobuf==4.23.4
psutil==5.9.8
pyasn1==0.5.1
pyasn1-modules==0.3.0
pyreadline3==3.4.1
pytorch-lightning==2.1.3
PyWavelets==1.5.0
PyYAML==6.0.1
regex==2023.12.25
requests==2.31.0
requests-oauthlib==1.3.1
rsa==4.9
safetensors==0.4.1
scipy==1.12.0
sentencepiece==0.2.0
six==1.16.0
sympy==1.12
tensorboard==2.15.1
tensorboard-data-server==0.7.2
timm==0.9.16
tokenizers==0.15.2
torch==2.1.2+cu118
torchmetrics==1.3.1
torchvision==0.16.2+cu118
tqdm==4.66.1
transformers==4.36.2
typing_extensions==4.10.0
urllib3==2.2.1
wcwidth==0.2.13
Werkzeug==3.0.1
xformers==0.0.23.post1+cu118
yarl==1.9.4
zipp==3.17.0