kohya-ss / sd-scripts

Apache License 2.0

Training the SDXL text encoder with sdxl_train.py adds a pink / purple color to output images #948

Open medialibraryapp opened 11 months ago

medialibraryapp commented 11 months ago

Hi there,

I'm using the recently fixed SDXL text encoder training support to train the text encoder with some new terms. This works, and I can successfully use the trained terms to generate matching images.

However, I'm finding that whenever I train the text encoder, my generated images (in A1111) have a strong pink / purple tint across the whole image. If I disable text encoder training and only train the UNet, the tint does not appear.

The issue is especially noticeable if I train only the text encoder, and not the UNet, with sdxl_train.py. If I also train the UNet, its learning seems to "undo" some of the purple effect, but the tint never goes away completely.

What kind of issue might cause this color tinting, and is there any way to work around it?

medialibraryapp commented 11 months ago

Hmm… going by the images in this excellent article, it looks like I've overtrained the text encoder:

https://followfoxai.substack.com/p/overtrained-text-encoder-vs-overtrained

v0xie commented 11 months ago

Does your dataset have many pink or purple images? I've noticed when training that a color will bleed out of its concepts and into everything else if the dataset is biased towards images of that color.

Try training with --debiased_estimation_loss on the dev branch. Using --color_aug works too, but it's much slower.
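For reference, an invocation might look something like the sketch below. The paths and hyperparameter values are placeholders, and --debiased_estimation_loss assumes you're on the dev branch where that flag exists; the flag at the end is the only part that matters here:

```bash
# Sketch only: placeholder paths and hyperparameters.
accelerate launch sdxl_train.py \
  --pretrained_model_name_or_path="sd_xl_base_1.0.safetensors" \
  --train_data_dir="./train_images" \
  --output_dir="./output" \
  --resolution="1024,1024" \
  --learning_rate=1e-6 \
  --max_train_steps=2000 \
  --debiased_estimation_loss
```

If you try --color_aug instead, note that it can't be combined with --cache_latents, since cached latents wouldn't reflect the per-step color changes.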

araleza commented 11 months ago

With the new fix, I've found that training the text encoder at a learning rate of 1e-8 works well. I think it's so much more sensitive than the unet because the unet has a huge number of pixels to play with, while only a handful of words go into a prompt, so any big change to the text encoder will destroy it.

And the thing is, I don't think the text encoder needs much training at all to learn a new phrase. As long as it outputs almost any unique-ish signal for that phrase, the unet can recognise the pattern and amplify it, which is why barely any training is needed.

My tip would be to train the text encoder at 1e-8 for a while, then switch off text encoder training entirely and continue with unet-only training, just like you did. As long as the text encoder is still being trained, even at 1e-8, I see tiny distortions appearing in the sample output. Those go away again when you train just the unet (assuming you haven't badly overtrained the text encoder, as you suspected you might have).
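To make that concrete, here's a rough sketch of the two-phase schedule. It assumes the per-text-encoder learning rate flags (--learning_rate_te1 / --learning_rate_te2) and --train_text_encoder that my copy of sdxl_train.py exposes, so check --help on your version; all paths, step counts, and rates are placeholders:

```bash
# Phase 1 (sketch): brief text encoder training at a tiny learning rate.
accelerate launch sdxl_train.py \
  --pretrained_model_name_or_path="sd_xl_base_1.0.safetensors" \
  --train_data_dir="./train_images" \
  --output_dir="./output_phase1" \
  --train_text_encoder \
  --learning_rate_te1=1e-8 \
  --learning_rate_te2=1e-8 \
  --learning_rate=1e-6 \
  --max_train_steps=500

# Phase 2 (sketch): resume from the phase 1 checkpoint and train the unet only
# (text encoder training switched off), letting it smooth out the distortions.
accelerate launch sdxl_train.py \
  --pretrained_model_name_or_path="./output_phase1/last.safetensors" \
  --train_data_dir="./train_images" \
  --output_dir="./output_phase2" \
  --learning_rate=1e-6 \
  --max_train_steps=2000
```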

Like you, I tried training the text encoder at higher rates, and found the unet couldn't fix the resulting distortions entirely, only reduce them to a degree.