huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Train text encoder in train_text_to_image_sdxl.py #7536

Closed stpg06 closed 3 months ago

stpg06 commented 6 months ago

I have used this script successfully on an A100 40GB to train a full SDXL model. I would like to try training the text_encoder as well, but that is currently not an option.

I'm wondering if that option could be added, possibly with a separate learning rate argument for it. I know things are kept simple and basic for inexperienced users like myself, and I am highly appreciative of all the hard work you guys put in here. I just see no reason why this option is confined to the LoRA scripts while training the text_encoder isn't at least optional in the full script. I think it's been well documented that training the text encoder helps in most cases. It may crash the GPU, but it would still be nice to see if I could get it to run by tweaking other settings.

Again, thanks for all this team does, and if there is some other script that trains both the full model and text_encoder, please let me know.
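For reference, the separate-learning-rate part of the request is a small change in principle, since `torch.optim.AdamW` accepts per-parameter-group learning rates. A minimal sketch of the idea follows; the tiny `nn.Linear` modules stand in for the real SDXL UNet and text encoders, and the learning-rate values are placeholders, not recommendations:

```python
import torch
from torch import nn

# Tiny stand-in modules; in the real script these would be the SDXL UNet
# and the two CLIP text encoders loaded from the pretrained checkpoint.
unet = nn.Linear(8, 8)
text_encoder_one = nn.Linear(8, 8)
text_encoder_two = nn.Linear(8, 8)

# Separate param groups let the text encoders train at a lower learning
# rate than the UNet, which is what the feature request asks for.
optimizer = torch.optim.AdamW(
    [
        {"params": unet.parameters(), "lr": 1e-5},
        {"params": text_encoder_one.parameters(), "lr": 5e-6},
        {"params": text_encoder_two.parameters(), "lr": 5e-6},
    ]
)
print([group["lr"] for group in optimizer.param_groups])
```

The harder part, as the replies below note, is not the optimizer plumbing but the memory cost of holding gradients and optimizer states for the extra parameters.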

bghira commented 6 months ago

tuning a low-rank adaptation network applies to parts of the text encoder, and is much easier to train and apply at different strength levels.

this requires a pretty extreme quantity of VRAM and balloons the complexity of the scripts. the proper way to tune the text encoder is using the OpenCLIP codebase from mlfoundations. this requires a lot of training data - hundreds of thousands of samples as a starting point. contrastive loss requires a lot of compute.

stpg06 commented 6 months ago

> tuning a low-rank adaptation network applies to parts of the text encoder, and is much easier to train and apply at different strength levels.
>
> this requires a pretty extreme quantity of VRAM and balloons the complexity of the scripts. the proper way to tune the text encoder is using the OpenCLIP codebase from mlfoundations. this requires a lot of training data - hundreds of thousands of samples as a starting point. contrastive loss requires a lot of compute.

I see, so it's not just as simple as the Dreambooth method with SD 1-2? In those scripts you can train the text encoder easily. I know SDXL is a larger model, but the text encoders are much smaller than the UNet, so I wouldn't expect them to add that much to the process. You'd think it might still fit just under 40GB on an A100. Perhaps I'm just not understanding the difference between the Dreambooth scripts and the train_text_to_image scripts. Also, since the capability to train the SDXL text encoders when training LoRAs is already there, I'm confused as to why we can't have full model training with that capability. To be clear, I'm not talking about training a whole text encoder from scratch, only fine-tuning the SDXL ones.

bghira commented 6 months ago

trained weights have to be stored in fp32, which increases VRAM consumption a lot beyond what inference needs

stpg06 commented 6 months ago

> trained weights have to be stored in fp32, which increases VRAM consumption a lot beyond what inference needs

I have trained a lot of models on SD 1.4 and 1.5, and I've noticed that after fine-tuning the model and text encoder there is little to no difference in the size of the finished model. I can run basic inference with SDXL no problem and train the full model; depending on settings it uses anywhere from roughly 25-33GB of VRAM. So where exactly would training the text encoder put it, do you think? Well over 40? I wouldn't think it would go that high. You can train the whole SD 1.5 model, text encoder included, with the Dreambooth script on a 16GB GPU, or maybe even less using gradient checkpointing, high accumulation steps, etc. I know you can't do SDXL on anything under an A100, but I'm just wondering whether an A100 could do it...

bghira commented 6 months ago

it's more than 57GB of VRAM, even with gradient checkpointing.
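That figure is consistent with a back-of-envelope estimate. With full fp32 AdamW training, each trained parameter costs roughly 16 bytes: 4 for the fp32 weight, 4 for its gradient, and 8 for the two optimizer moments. The parameter counts below are approximate public figures for SDXL, not measured values:

```python
# Rough VRAM estimate for full fp32 AdamW fine-tuning of SDXL.
# Parameter counts are approximate public figures (assumption).
UNET_PARAMS = 2.6e9    # SDXL UNet
TE1_PARAMS = 0.123e9   # text encoder 1 (CLIP ViT-L)
TE2_PARAMS = 0.695e9   # text encoder 2 (OpenCLIP ViT-bigG)

BYTES_PER_PARAM = 16   # fp32 weights (4) + grads (4) + AdamW m, v (8)

def gb(params):
    return params * BYTES_PER_PARAM / 1e9

print(f"UNet only:       ~{gb(UNET_PARAMS):.1f} GB")
print(f"UNet + both TEs: ~{gb(UNET_PARAMS + TE1_PARAMS + TE2_PARAMS):.1f} GB")
```

This gives roughly 42 GB for the UNet alone and around 55 GB with both text encoders, before activations, the frozen VAE, and batch data are counted, which lines up with the figure quoted above even with gradient checkpointing enabled.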

stpg06 commented 6 months ago

Wow, that is high. It's too bad you can't do it in separate fine-tuning sessions; I guess then the text encoder wouldn't match the UNet, though. I accept that it's too hard and memory-consuming based on what you've said.

My issue with LoRA training is that it inherits too much of the style from my images, whereas I prefer training on a subject, so the LoRAs often come out useless to me: too much memory of the style and faces, not enough of the subject or its complexities. SDXL seems to have adequate parameters to train on more complex subjects, but without the ability to train the text encoder to bake in fine, important details, it usually makes a mess of anything I do with it. I guess the bigger and better these models get, the harder they are to work with. Soon we won't be able to run inference with them at all; big companies will control them and charge you a fortune.

github-actions[bot] commented 5 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

sayakpaul commented 3 months ago

Closing because of inactivity.