rinongal / textual_inversion

MIT License

Any way to force training to ignore the background and focus just on the subject? #49

Closed 1blackbar closed 1 year ago

1blackbar commented 2 years ago

I can only get a good likeness with heavy overfitting, but this also heavily learns the background. Is there any way to avoid it? A white background? Black? Or a piece of code?

ExponentialML commented 2 years ago

Other than pre-processing your inputs and removing the background, you can try adjusting the coarse_class_text and embedding_reg_weight parameters, either together or one at a time.

This should help condition your prompts toward what you want. While this may not be the recommended way, it is an option that can help. The parameter can be added (it's currently not in the config) under data.train.params.coarse_class_text in the config file used for fine-tuning.
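As a rough illustration of what coarse_class_text changes, here is a minimal sketch, assuming the option simply prepends a coarse class word to the placeholder before the placeholder is formatted into a caption template. The templates and the function name `build_caption` are hypothetical stand-ins, not the repository's code.

```python
import random

# Hypothetical caption templates; a real dataset would use a longer list.
TEMPLATES = ["a photo of a {}", "a rendering of a {}", "a close-up photo of a {}"]


def build_caption(placeholder="*", coarse_class_text=None, rng=random):
    """Return a training caption for one image.

    If coarse_class_text is set (e.g. "toy"), it is placed before the
    placeholder, so the caption reads "a photo of a toy *" instead of
    "a photo of a *". The class word then explains part of the image,
    leaving less for the learned embedding to absorb.
    """
    token = placeholder
    if coarse_class_text:
        token = f"{coarse_class_text} {placeholder}"
    return rng.choice(TEMPLATES).format(token)


print(build_caption(coarse_class_text="toy", rng=random.Random(0)))
```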

rinongal commented 2 years ago

If your image set is small and you are not training for long, you could also try enabling per_image_tokens in the config (note that this flag appears more than once). This will also assign a second token to each image, and train with sentences of the form "a photo of with *". If your images have very different backgrounds, this encourages the network to capture the background information in the unique token for the image.
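A minimal sketch of how per-image tokens change the captions, assuming captions are drawn from two template lists and that a `mixing_prob` fraction of them use a two-token form naming both the shared placeholder and the image's own token. The templates, the `<imgN>` token names and `caption_for_image` are hypothetical; the repository's dataset code may differ.

```python
import random

# Hypothetical templates: single-token captions, and "dual" captions that
# also mention a token unique to each image.
SINGLE_TEMPLATES = ["a photo of a {}", "a cropped photo of the {}"]
DUAL_TEMPLATES = ["a photo of a {} with {}", "a cropped photo of the {} with {}"]


def caption_for_image(index, per_image_tokens, placeholder="*",
                      mixing_prob=0.25, rng=random):
    """Build a caption for image `index`.

    With per_image_tokens enabled, a fraction `mixing_prob` of captions also
    names a token unique to this image, giving image-specific details (such
    as the background) somewhere to go other than the shared placeholder.
    """
    if per_image_tokens and rng.random() < mixing_prob:
        image_token = f"<img{index}>"  # hypothetical per-image token name
        return rng.choice(DUAL_TEMPLATES).format(placeholder, image_token)
    return rng.choice(SINGLE_TEMPLATES).format(placeholder)
```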

However, as with pretty much all of our experiments, we only tested this with LDM, and we used fairly short training runs with fewer vectors than you're using. I can't guarantee that this will carry over to your setup.

oppie85 commented 2 years ago
  • embedding_reg_weight will bring your prompts closer to the init_word, but should be used sparingly. Results can vary, but arbitrarily low values (like 0.00001) will show you exactly how this works. Negative values push you further away from the init_word, and positive values bring you closer. A value that is too high or too low will push your data outside the scope of what you're trying to invert.
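A minimal sketch of the kind of regularizer described here, assuming the penalty is the squared L2 distance between the learned embedding and the init_word's embedding, scaled by embedding_reg_weight. Plain lists stand in for tensors and the function names are hypothetical; whether a negative weight is actually honored depends on how the training code applies it.

```python
def squared_distance(a, b):
    """Squared L2 distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def regularized_loss(diffusion_loss, embedding, init_embedding, reg_weight):
    """Total loss = reconstruction loss + reg_weight * ||e - e_init||^2."""
    return diffusion_loss + reg_weight * squared_distance(embedding, init_embedding)


def reg_gradient(embedding, init_embedding, reg_weight):
    # d/de of reg_weight * ||e - e_init||^2 is 2 * reg_weight * (e - e_init).
    return [2 * reg_weight * (x - y) for x, y in zip(embedding, init_embedding)]


def sgd_step(embedding, init_embedding, reg_weight, lr=0.1):
    """One gradient-descent step using only the regularization term.

    A positive reg_weight moves the embedding toward the init word;
    a negative one moves it away.
    """
    grad = reg_gradient(embedding, init_embedding, reg_weight)
    return [x - lr * g for x, g in zip(embedding, grad)]
```

With an init embedding at the origin, one step with weight 1.0 moves [1.0, 0.0] to [0.8, 0.0], while weight -1.0 moves it to [1.2, 0.0].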

I've been experimenting with ridiculously high values (up to 60) for "embedding_reg_weight", not entirely without success, although at some point it becomes really unclear how much each setting is contributing to the final result. The idea of negative values is pretty interesting though; I had been thinking about using 'exclude' tokens, where you can provide a list of tokens you don't want to match. For example, if we've only got photos in a park, add a negative value for "park" or "outdoors".
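The 'exclude' token idea could be sketched as an extra penalty term, assuming you have embedding vectors for the excluded words (e.g. "park", "outdoors") and penalize positive cosine similarity between them and the learned embedding. `exclusion_penalty` is a hypothetical helper, not part of the repository; it would be added to the training loss with some weight.

```python
import math


def cosine(a, b):
    """Cosine similarity between two vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0


def exclusion_penalty(embedding, excluded_embeddings, weight=1.0):
    """Hypothetical 'exclude-token' term.

    Penalizes positive cosine similarity between the learned embedding and
    each excluded token; similarity at or below zero costs nothing.
    """
    return weight * sum(max(0.0, cosine(embedding, ex)) for ex in excluded_embeddings)
```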

I'm looking into expanding the sample logging step with some kind of "style transfer test", where it outputs images for prompts that aren't used for conditioning but are useful for judging how training affects style transfer; for example, "a painting of *", "* wearing medieval armor" or "* in dim lighting", each of which could be used to judge at which point in training the overfitting occurs.
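A minimal sketch of such a held-out "style transfer test", assuming samples are logged every 500 steps; only the prompt selection is shown, not the image sampling itself. `PROBE_PROMPTS` and `probe_prompts` are hypothetical names.

```python
# Held-out prompts that are never used as training captions. Rendering them
# periodically shows when the embedding stops responding to the rest of the prompt.
PROBE_PROMPTS = [
    "a painting of {}",
    "{} wearing medieval armor",
    "{} in dim lighting",
]


def probe_prompts(step, placeholder="*", every=500):
    """Return the probe prompts to sample at this step, or [] if none are due."""
    if step == 0 or step % every != 0:
        return []
    return [p.format(placeholder) for p in PROBE_PROMPTS]
```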

ExponentialML commented 2 years ago

each of which could be used to judge where in the training process the overfitting occurs.

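In the DreamBooth paper, one of the key features for preventing overfitting is a class-specific prior-preservation loss: images of the subject's class are generated by the frozen model with an ancestral sampler (e.g. k_ancestral), and those generated images are used for training alongside your own.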
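A minimal sketch of a prior-preservation objective, assuming plain Python lists stand in for the noise predictions and targets; `mse` and `prior_preservation_loss` are hypothetical helper names. In DreamBooth the whole model is fine-tuned, so applying the same idea to textual inversion would be an adaptation rather than something the paper tested.

```python
def mse(pred, target):
    """Mean squared error between two equal-length vectors."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def prior_preservation_loss(subj_pred, subj_target, class_pred, class_target,
                            prior_weight=1.0):
    """Fit the subject images while also fitting images of the generic class
    that the frozen model generated beforehand.

    The second term keeps the model's notion of the class (e.g. "dog") from
    collapsing onto the specific subject.
    """
    return mse(subj_pred, subj_target) + prior_weight * mse(class_pred, class_target)
```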

At a high level, you would start with 5 images of your choice, then after x epochs replace those images with ones generated via a sampler (any k-sampler, DDIM, etc.).

I'm not exactly sure how to implement the loss function in code, but I can definitely give this part a try and see if it helps.

oppie85 commented 2 years ago

I feel like I could come up with better strategies if I knew more about how the training worked. My current understanding is as follows:

Before training, we create a dataset with a few captions that include the desired token. During training, the model generates an encoding for the token, uses that in the conditioning prompt and checks this against the encoding of the image; the closer the trained encoding is to the image, the lower the "loss". The training algorithm tries to minimize loss, so the encodings increasingly match the image. Is this correct?
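For comparison with this description, the optimization objective in the Textual Inversion paper can be written as follows, where $\mathcal{E}$ is the image encoder, $z_t$ the noised latent at timestep $t$, $\epsilon_\theta$ the frozen denoising network and $c_\theta(y)$ the text encoding of a caption $y$ containing the placeholder. Only the placeholder's embedding $v$ is optimized, and the comparison is between the noise the model predicts and the noise that was actually added, rather than directly between a text encoding and an image encoding.

$$v_* = \arg\min_{v}\; \mathbb{E}_{z \sim \mathcal{E}(x),\, y,\, \epsilon \sim \mathcal{N}(0, 1),\, t}\Big[\big\lVert \epsilon - \epsilon_\theta\big(z_t, t, c_\theta(y)\big) \big\rVert_2^2\Big]$$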

The issue is that the encodings of the original image contain all kinds of information we don't necessarily want the embedding to learn, but the algorithm checks against them anyway. In theory, using a conditioning caption like "a * in a field" would mean that the "in a field" part of the image is already provided by the training caption and thus wouldn't end up in the encoding. On the other hand, it means the entire caption is far more likely to match the image (for example, "a duck in a field" still matches the "in a field" part, keeping the loss on the lower end), which would introduce more randomness into the encoding generated for the token.

So in theory, when calculating the loss, you'd have to check that the generated encoding matches the image, but also that it still roughly matches the init_word. That's what embedding_reg_weight already does, but I think you'd also have to check the opposite: that the generated encoding doesn't contain the additional content from the conditioning prompt. So if we're trying to train the likeness of a specific person, using "* standing in a field" as a conditioning prompt and "man" as the init word, we'd have to check that * still trends towards "man", but also that * doesn't trend towards "in a field".

In theory, using complex conditioning prompts (especially image-specific ones) should yield more targeted results, but people have reported that using a plain "{}" as the conditioning prompt generates a better encoding, which makes sense if "a photo of *" also leaks "a photo" into the encoding.

Perhaps the reason why this process worked better with LDM is that maybe that one had more cleanly defined encodings per vector, where for SD, one vector can contain style information, lighting information and object definitions all at the same time. Then again I'm very new to machine learning, so take any of my speculation with a large grain of salt.

ExponentialML commented 2 years ago

You can use coarse_class_text to better condition the model. I'm currently experimenting with this by using two generated images of a class, then using mixing_prob to randomly choose a generated image during conditioning, with a caption such as "toy" or any other word that fits the class.

Every 500 epochs or so (when the model is evaluated), the generated images change, and this continues until the training process is completed. There is only one template that consists of a {} [class]. It's a very light implementation, but it allows for supervision of training and to possibly reduce overfitting.
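A minimal sketch of this setup, assuming a `generate(n)` callable that returns n freshly sampled class images and the single template "a {} [class]" described above. `ClassImageMixer` and its parameters are hypothetical; this is not the code ExponentialML used.

```python
import random


class ClassImageMixer:
    """Mix generated class images into training and refresh them periodically."""

    def __init__(self, generate, class_word, pool_size=2, mixing_prob=0.25,
                 refresh_every=500, placeholder="*", rng=None):
        self.generate = generate          # callable: n -> list of n images
        self.class_word = class_word      # e.g. "toy"
        self.pool_size = pool_size
        self.mixing_prob = mixing_prob
        self.refresh_every = refresh_every
        self.placeholder = placeholder
        self.rng = rng or random.Random()
        self.pool = generate(pool_size)
        self.last_refresh = 0

    def sample(self, step, image, caption):
        """Return the (image, caption) pair to train on at this step."""
        # Regenerate the class images once every `refresh_every` steps.
        if step - self.last_refresh >= self.refresh_every:
            self.pool = self.generate(self.pool_size)
            self.last_refresh = step
        if self.rng.random() < self.mixing_prob:
            # Single template "a {} <class>", e.g. "a * toy".
            return self.rng.choice(self.pool), f"a {self.placeholder} {self.class_word}"
        return image, caption
```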

rinongal commented 2 years ago

@oppie85 LDM's latent space is actually larger, and likely more expressive both due to the increased dimension and due to the fact that its encoder was jointly trained with the generator (and thus needs to be able to encode 'finer' details).

The CLIP encoder has an issue where specific words can dominate the output vector (see e.g. No Token Left Behind, or the iPod typographic attack against CLIP), which seems to be a major issue here, regardless of the structure of the input prompts.

We've had some luck with prompt weighting, but this isn't very robust and requires a lot of experimentation in the inference stage.
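For context, one common prompt-weighting scheme (used in some Stable Diffusion front ends) multiplies each token's conditioning vector by its weight and then rescales the whole sequence so its overall mean is unchanged. The sketch below assumes that scheme, with lists standing in for tensors; it is not necessarily the weighting used in these experiments, and `weight_prompt_embeddings` is a hypothetical name.

```python
def weight_prompt_embeddings(token_embeddings, weights):
    """Scale each token's vector by its weight, then rescale everything so the
    overall mean matches the unweighted conditioning."""
    flat = [x for vec in token_embeddings for x in vec]
    original_mean = sum(flat) / len(flat)
    scaled = [[x * w for x in vec] for vec, w in zip(token_embeddings, weights)]
    new_flat = [x for vec in scaled for x in vec]
    new_mean = sum(new_flat) / len(new_flat)
    # Restore the original mean so the overall magnitude stays comparable.
    factor = original_mean / new_mean if new_mean else 1.0
    return [[x * factor for x in vec] for vec in scaled]
```

Weighting one token by 2.0 doubles its vector relative to the others, while the mean-restoring step keeps the conditioning from simply growing in magnitude.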

oppie85 commented 2 years ago

The CLIP encoder has an issue where specific words can dominate the output vector (see e.g. No Token Left Behind, or the iPod typographic attack against CLIP), which seems to be a major issue here, regardless of the structure of the input prompts.

Interesting; so you'd say it is actually entirely expected that the training process seeks out those encodings that overwhelm the output vector the most? That would make a lot of sense with how the process (especially with higher amounts of vectors per token) seems to gravitate to a point where it just ignores the entire prompt and why the conditioning prompt seems to have no effect half of the time. Am I correct in assuming that you can't really detect whether or not such a 'dominant' encoding is present until you actually decode the embedding to an image and see it for yourself?

It also explains why I've had some limited success with overcoming the overfitting in some very specific prompt scenarios (something like "concept art of *" or "inspired by *" has worked a few times). Prompt weighting also works, although it always seems to diminish the likeness of the person/thing I tried to textually invert, whereas overcoming the overfitting retains the likeness but is much more hit-or-miss.

rinongal commented 2 years ago

I wouldn't say expected, no. LDM does avoid this problem successfully, so it's not some universal property of text encoders. There is almost certainly a way to detect and avoid this issue during training (for example, by looking at attention maps). However, if I had a good universal solution on hand, I'd have already released it :)
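One way the attention-map idea might look, assuming you can extract cross-attention probabilities (one distribution over prompt tokens per spatial position) from the denoiser. The 0.8 threshold and both function names are hypothetical; this is a sketch of a possible diagnostic, not a known working solution.

```python
def token_attention_share(attention, token_index):
    """Average fraction of cross-attention that spatial positions give one token.

    `attention` is a list of rows, one per spatial position, each a
    probability distribution over prompt tokens.
    """
    return sum(row[token_index] for row in attention) / len(attention)


def placeholder_dominates(attention, token_index, threshold=0.8):
    """Flag a placeholder that absorbs most of the attention everywhere,
    a possible symptom of it overriding the rest of the prompt."""
    return token_attention_share(attention, token_index) > threshold
```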

I'm currently sidelined with other aspects of this project, but I'll likely get back to exploring this type of solution towards the end of the month.

oppie85 commented 2 years ago

I feel like I'm close to something; I've been able to train an embedding that does successful style transfer without having to overwhelm the prompt, and it seems to ignore background information a lot better. I'm not able to reproduce it consistently, but maybe someone can help develop this idea further (or explain to me why it will never work).

For every step, I make an additional call to apply_model, but with a prompt that doesn't contain the placeholder (for example, just "a man"). I then calculate the loss between the original model output and the 'baseline' output with get_loss(model_output, base_model_output, mean=False).mean([1, 2, 3]). What I intend with this is to ensure that any generated encoding is still seen as a variation of the encoding for the prompt "a man". Next, I combine the original loss and the "base loss": loss = (loss * 0.6) + base_loss. I weight the original loss slightly lower than base_loss because this puts more pressure on the training process to not include anything that doesn't describe "a man". It basically does what "embedding_reg_weight" does, but over multiple vectors.

This often works: in my training samples, I can see that it starts out with random people who very slowly (extremely slowly, even) start to look more like my subject (myself). Other times it doesn't work; especially when I tried bumping the loss weight from 0.6 to higher numbers, so much noise was introduced into the training process that it degenerated into nightmare fuel.

I'd consider my few successes a fluke but I think there's something here; I cobbled together this solution with limited Python/machine learning experience (I basically only started looking into this stuff when SD was released) and I imagine most of this could be improved. I feel like the process of slowly nudging a multi-vector prompt towards a proper likeness in the way that "embedding_reg_weight" does has potential.

oppie85 commented 2 years ago

I've made further improvements to my proposed solution above; here is a somewhat refined version, which I'm calling "base reinforcement":

The diagram below summarizes what I've added to the loss calculation:

[Diagram: BaseReinforcement]

In essence, I've added additional condition prompts that continuously check whether or not "a painting of *" is close enough to "a painting of a man" and feeds that back into the loss calculation.

To achieve this, I've added the following to ddpm.py in the p_losses function:

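```python
# Inside p_losses (ddpm.py), after the standard per-sample loss has been computed.
# For example, c_base is the conditioning for "a painting of a man" and
# c_modified is the conditioning for "a painting of *"; both are prepared
# elsewhere in the code (not shown here).
if c_base is not None and len(c_base) > 0:
    base_weight = 0.3  # weight applied to the original reconstruction loss
    # Run both prompts on the same noisy latent and timestep so the two
    # outputs differ only in their conditioning.
    modified_model_output = self.apply_model(x_noisy, t, c_modified)
    with torch.no_grad():  # c_base has no trainable token, so no gradient is needed
        base_model_output = self.apply_model(x_noisy, t, c_base)
    base_loss_simple = self.get_loss(modified_model_output, base_model_output, mean=False).mean([1, 2, 3])
    base_loss = base_loss_simple / torch.exp(logvar_t) + logvar_t
    loss = (loss.mean() * base_weight) + base_loss.mean()

loss = self.l_simple_weight * loss.mean()
```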

where c_base and c_modified are the token representations of the additional prompts, which are processed elsewhere in the code (I've modified the original source code too much for other experiments and it's a mess).

To make sure that the "base reinforcement" is applied strongly so that style transfer (and other modifications) are still possible after training, the original loss is weighted by 0.3, although better results might be achieved with even lower numbers. I am not exactly sure why this is yet, but I've had easily style-able embeddings even at 40 vectors, although training for too long still leads to overfitting. I've had great results at around 2000 steps and 4 batches.

EDIT: in my latest tests on a fresh repo I reversed the weighting loss + (base_weight * .5) and the results seem better; I'll have to experiment with this a little - but these tests seem to be confirming that this process works. Sadly I won't be able to work on this further over the next week but I'll be updating my own fork at https://github.com/oppie85/textual_inversion soon.

TaleirOfDeynai commented 1 year ago

This capability could probably help me.

I've got a problem where one token, "furry", is both a great descriptor of a character's form but also an art style that happens to match a lot of the dataset; it inevitably gets added to the embedding which poisons it, making it so the embedding will always generate a low-skill scribble (which is what most furry art happens to be). This poison overpowers all other styles like "photo" or "3D render", and it takes heavy prompt engineering, like invoking the adjacent style of "Disney Pixar", to actually break it out of it but this grossly constrains what styles I have access to when using the embedding.

My end goal with my own exploration into this tool is to try to get an embedding that encapsulates a general character description that can then be applied to most any style. While the embedding is poisoned like this, I can't really achieve that.

However, perhaps this idea could help. Since it poisons it with a style, I would hope that this additional reinforcement would make it realize, "hey, this token makes it into a drawing and not a photo," and so causes it to broaden the search toward tokens that do not poison the embedding with a style.

So far, I was making some progress with:

I know of the per-image token feature, but I haven't tried it yet. Would be nice if the per-image token list could be made into an effectively infinite generator to accommodate a dataset of any size, but I guess that depends on how CLIP behaves toward unfamiliar token mappings. Can it treat *1 and *2 as separate and individual tokens? :thinking:
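Generating an unbounded list of placeholder strings is easy; the open question is the tokenizer. CLIP's tokenizer splits punctuation and individual digits into separate pieces, so a string like `*1` would likely not map to a single token unless each placeholder is registered as its own token with its own trainable vector. `placeholder_tokens` is a hypothetical helper.

```python
from itertools import count, islice


def placeholder_tokens(prefix="*"):
    """Yield an unbounded sequence of placeholder strings: *1, *2, *3, ...

    Each string would still need its own trainable vector, and whether the
    text encoder sees it as one token depends on how it is registered.
    """
    for i in count(1):
        yield f"{prefix}{i}"


first_three = list(islice(placeholder_tokens(), 3))
```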

rinongal commented 1 year ago

@TaleirOfDeynai The per_image_tokens approach wasn't helpful in any of our experiments, so I wouldn't count on it too much. Tuning the specific prompts to your set is a good way to go about it.

If you want to take it a step further, you could try to first learn a style word that reflects the styles of the images you're working with, and then use that new word in your prompts (so you'll have a 'a drawing of * in the style of @' for your prompts). This of course assumes those images share a distinct style, and that you can get a bunch of images in that style but with different characters.
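A sketch of the second stage of this two-stage approach, assuming a style token `@` has already been learned from images that share the style. The templates and `stage_two_captions` are hypothetical examples built around the prompt pattern suggested above.

```python
# Hypothetical caption templates for the second stage. The shared art style is
# attributed to the style token, so the subject token has less reason to absorb it.
STYLE_AWARE_TEMPLATES = [
    "a drawing of {subject} in the style of {style}",
    "a sketch of {subject} in the style of {style}",
]


def stage_two_captions(subject="*", style="@"):
    """Return training captions that name both the subject and the learned style."""
    return [t.format(subject=subject, style=style) for t in STYLE_AWARE_TEMPLATES]
```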

rinongal commented 1 year ago

I'm closing this due to lack of activity. Feel free to reopen if you need further help.