bghira / SimpleTuner

A general fine-tuning kit geared toward diffusion models.
GNU Affero General Public License v3.0

class-preservation target loss for LoRA / LyCORIS #1031

Open · bghira opened this issue 5 days ago

bghira commented 5 days ago

the idea is based on this pastebin entry: https://pastebin.com/3eRwcAJD

snippet:

    # Regularisation branch: when the caption is the bare class token
    # ("woman"), use the frozen base model's prediction as the target
    # instead of the dataset target.
    if batch['prompt'][0] == "woman":
        with torch.no_grad():
            # Detach the LoRA hooks so the unmodified base model
            # produces the reference prediction.
            self.model.transformer_lora.remove_hook_from_module()
            regmodel_output_data = self.model_setup.predict(self.model, batch, self.config, train_progress)
            self.model.transformer_lora.hook_to_module()

        # Second forward pass, this time with the LoRA applied.
        model_output_data = self.model_setup.predict(self.model, batch, self.config, train_progress)
        # Swap the target for the base model's prediction.
        model_output_data['target'] = regmodel_output_data['predicted']
        loss = self.model_setup.calculate_loss(self.model, batch, model_output_data, self.config)
        loss *= 1.0  # placeholder weight for the regularisation term
        print("\nregmodel loss:", loss)
    else:
        # Ordinary training step for non-regularisation batches.
        model_output_data = self.model_setup.predict(self.model, batch, self.config, train_progress)
        loss = self.model_setup.calculate_loss(self.model, batch, model_output_data, self.config)

the idea is that we can set a flag inside the multidatabackend.json for a dataset that contains our regularisation data.

instead of training on this data as we currently do, we will use the base model's own prediction (with the LoRA hooks removed) as the target, as in the snippet above.

and instead of checking for "woman" in the first element's caption, the batch will carry a flag from multidatabackend.json (exact mechanism to be decided) that enables this behaviour.
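
roughly, the dataset entry might gain a flag along these lines — sketched here as a python dict so the hypothetical key can be annotated; `is_regularisation_data` is a placeholder name for illustration, not an existing option, and the other fields are just the usual sort of dataset-entry fields:

    # hypothetical shape of a multidatabackend.json entry for the reg set;
    # "is_regularisation_data" is a placeholder key name for this sketch.
    regularisation_dataset = {
        "id": "class-preservation-data",
        "type": "local",
        "instance_data_dir": "/path/to/regularisation-images",
        "caption_strategy": "textfile",
        "is_regularisation_data": True,  # hypothetical: routes these batches to the prior-preservation branch
    }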

this will indeed run more slowly, since it performs two forward passes for each batch drawn from the regularisation dataset, but it has the intended effect of maintaining the original model's outputs for the given inputs, which substantially helps prevent subject bleed.
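
in generic terms, the regularisation step is: predict once with the adapter detached (under no_grad), predict again with it attached, and take the loss between the two predictions. a minimal sketch of that pattern, assuming a `compute_prediction` helper that stands in for the trainer's usual forward pass and an adapter object exposing the hook/unhook methods used in the snippet:

    import torch
    import torch.nn.functional as F

    def prior_preservation_loss(model, lora_adapter, batch, compute_prediction):
        """Use the frozen base model's prediction as the target for this batch.

        `compute_prediction` stands in for the trainer's usual forward pass;
        `lora_adapter` is assumed to expose the hook/unhook methods above.
        """
        # 1) reference pass: base model only, adapter detached, no gradients
        with torch.no_grad():
            lora_adapter.remove_hook_from_module()
            reference_pred = compute_prediction(model, batch)
            lora_adapter.hook_to_module()

        # 2) training pass: adapter active, gradients flow as usual
        lora_pred = compute_prediction(model, batch)

        # 3) penalise drift of the adapted model away from the base model
        #    on the regularisation inputs
        return F.mse_loss(lora_pred, reference_pred)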

note: i don't know who wrote the pastebin snippet above, but i would love to give credit to whoever created it.

example that came with the snippet:

*(image attachment)*

requested by a user on the terminus research discord.

dxqbYD commented 4 days ago

I'm the author of this. I am not yet entirely convinced myself that this is a useful feature. It seems to somewhat limit training of the concept you do want to change ("ohwx woman" in this sample) by insisting that the concept "woman" remains exactly the same during training.

this was an experiment I first ran yesterday, so I have limited test data myself. Training TE or training additional embeddings might overcome the issue mentioned above by separating the concepts in TE space? I am currently trying embeddings.

Happy to help with your implementation of this!

AmericanPresidentJimmyCarter commented 4 days ago

TIPO with random seeds and temperatures can be used to generate random prompts for related concepts. It can do tags -> natural language prompt or short prompt -> long prompt.

https://huggingface.co/KBlueLeaf/TIPO-500M

*(screenshot attachment)*
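
For example, a rough sketch of sampling a few varied prompts with transformers, assuming TIPO-500M loads as a standard causal LM; the seed prompt is only illustrative and the exact input template it expects is described on the model card:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Assumed: TIPO-500M loads as an ordinary causal LM via transformers.
    model_id = "KBlueLeaf/TIPO-500M"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

    # Illustrative seed only; the real tag/prompt template is on the model card.
    seed_prompt = "1girl, woman, portrait"

    for seed in range(4):
        torch.manual_seed(seed)  # different seed -> different expansion
        inputs = tokenizer(seed_prompt, return_tensors="pt").to(model.device)
        output = model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=True,   # random sampling
            temperature=1.1,  # higher temperature -> more variety
        )
        print(tokenizer.decode(output[0], skip_special_tokens=True))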

AmericanPresidentJimmyCarter commented 4 days ago

> this was an experiment I first ran yesterday, so I have limited test data myself. Training TE or training additional embeddings might overcome the issue mentioned above by separating the concepts in TE space? I am currently trying embeddings.

There is no need to train the text encoder for flux models, as the model is partially a large text encoder aligned to image space.

dxqbYD commented 3 days ago

> as the model is partially a large text encoder aligned to image space.

source, more info?

bghira commented 3 days ago

MM-DiT is this.
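
to unpack that a bit: in an MM-DiT-style transformer, the text tokens keep their own projection weights through every block and attend jointly with the image tokens, so the text representation keeps being refined inside the diffusion model itself rather than only in an external text encoder. a very schematic, single-head sketch of the idea (not the actual SD3/Flux code):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class JointAttentionBlock(nn.Module):
        """Schematic MM-DiT-style block: separate per-modality projections,
        one joint attention over the concatenated text+image sequence."""

        def __init__(self, dim: int):
            super().__init__()
            self.qkv_img = nn.Linear(dim, dim * 3)
            self.qkv_txt = nn.Linear(dim, dim * 3)
            self.out_img = nn.Linear(dim, dim)
            self.out_txt = nn.Linear(dim, dim)

        def forward(self, img_tokens, txt_tokens):
            q_i, k_i, v_i = self.qkv_img(img_tokens).chunk(3, dim=-1)
            q_t, k_t, v_t = self.qkv_txt(txt_tokens).chunk(3, dim=-1)
            # joint attention over text and image tokens together
            q = torch.cat([q_t, q_i], dim=1)
            k = torch.cat([k_t, k_i], dim=1)
            v = torch.cat([v_t, v_i], dim=1)
            joint = F.scaled_dot_product_attention(q, k, v)
            n_txt = txt_tokens.shape[1]
            # each stream gets its own output projection
            return self.out_txt(joint[:, :n_txt]), self.out_img(joint[:, n_txt:])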

dxqbYD commented 4 hours ago

After running some more tests, I now think this is worth implementing. It even works well with an empty prompt and no external reg image set: just reuse the training data set and check `if batch['prompt'][0] == "":`

That makes this a feature which requires no extra data, captions, or configuration. Since no prompt is provided, it can potentially preserve multiple classes as well as whatever you train on.
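
One hypothetical way to wire that up with zero extra data is to blank the captions on a fraction of ordinary training batches so they fall into the empty-prompt branch; the probability knob below is purely illustrative, not an existing option:

    import random

    # Hypothetical sketch: reuse ordinary training batches for prior
    # preservation by blanking their captions with some probability, so they
    # hit the batch['prompt'][0] == "" branch. REG_PROBABILITY is illustrative.
    REG_PROBABILITY = 0.25

    def maybe_blank_captions(batch):
        """Blank every caption in the batch with probability REG_PROBABILITY."""
        if random.random() < REG_PROBABILITY:
            batch['prompt'] = ["" for _ in batch['prompt']]
        return batch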

dxqbYD commented 4 hours ago

Branch here for anyone who wants to try: https://github.com/dxqbYD/OneTrainer/tree/prior_reg (it's the same code as above).