CompVis / stable-diffusion

A latent text-to-image diffusion model
https://ommer-lab.com/research/latent-diffusion-models/

CLIPTextEmbedder #806

Closed: m-violet-s closed this issue 10 months ago

m-violet-s commented 11 months ago

Hi, I am having the following problem when running

python scripts/txt2img.py
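
For reference, a typical invocation from the repo README looks like this (exact flags may vary between versions):

python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms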

I understand from the cond_stage_config in the v1-inference.yaml file that the text encoder is loaded from Hugging Face, and I can run txt2img.py normally with that encoder. But if I change cond_stage_config so that the text encoder comes from the official OpenAI clip package instead, the resulting image has nothing to do with the prompt. Concretely, in v1-inference.yaml I changed

cond_stage_config:
  target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

to

cond_stage_config:
  target: ldm.modules.encoders.modules.FrozenCLIPTextEmbedder
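
For reference, the snippet below comes from ldm/modules/encoders/modules.py; it relies on roughly these imports (AbstractEncoder is defined in the same file, and the exact import list in the repo may differ):

import torch
import torch.nn as nn
import clip
from einops import repeat
from transformers import CLIPTokenizer, CLIPTextModel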

class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)            

        z = outputs.last_hidden_state 
        return z

    def encode(self, text):
        return self(text)

class FrozenCLIPTextEmbedder(nn.Module):
    """
    Uses the CLIP transformer encoder for text.
    """
    def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
        super().__init__()
        self.model, _ = clip.load(version, jit=False, device="cpu")
        self.device = device
        self.max_length = max_length
        self.n_repeat = n_repeat
        self.normalize = normalize

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = clip.tokenize(text).to(self.device)
        z = self.model.encode_text(tokens)
        if self.normalize:
            z = z / torch.linalg.norm(z, dim=1, keepdim=True)
        return z

    def encode(self, text):
        z = self(text)
        if z.ndim==2:
            z = z[:, None, :]
        z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
        return z     

Is this because Stable Diffusion was trained with the Hugging Face text encoder (FrozenCLIPEmbedder), so the prompt has no effect when FrozenCLIPTextEmbedder is used as the encoder?

m-violet-s commented 11 months ago

If anyone knows, could you please let me know?

Shelsin commented 10 months ago

Hey bro, I met the same problem, but I also don't know how to solve it. If my memory serves me right, the FrozenCLIPTextEmbedder should work in theory, because I used to fine-tune the text encoder for Stable Diffusion.

m-violet-s commented 10 months ago

Thank you. I found that the shapes of the text features produced by these two classes are different, and after looking at the code inside self.model.encode_text(tokens) I changed the encode function inside the FrozenCLIPTextEmbedder class to:

def encode(self, text):
        tokens = clip.tokenize(text).to(self.device)
        x = self.model.token_embedding(tokens).type(self.model.dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding.type(self.model.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        z = self.model.ln_final(x).type(self.model.dtype)  # [batch, 77, 768]
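        # Note: the final pooling step from encode_text (taking the EOT token's embedding and
        # applying text_projection) is deliberately skipped here so per-token features are kept.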
        return z

I found that the image generated this way matches the text. I wonder if you did the same thing before?
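
A minimal shape check, assuming the patched encode() above has been applied to the local FrozenCLIPTextEmbedder (a sketch, not the exact code used here):

from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenCLIPTextEmbedder

prompts = ["a painting of a cat"]

# Hugging Face encoder selected by the default v1-inference.yaml
z_hf = FrozenCLIPEmbedder(device="cpu").encode(prompts)

# OpenAI CLIP encoder with the patched per-token encode()
z_clip = FrozenCLIPTextEmbedder(version="ViT-L/14", device="cpu").encode(prompts)

print(z_hf.shape, z_clip.shape)  # both should be torch.Size([1, 77, 768])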

Shelsin commented 10 months ago

Thanks for your solution. Yes, mine was similar: I set self.normalize to False in FrozenCLIPTextEmbedder and removed the line x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection in CLIP's model.py. You are right, the sequence length of the conditioning (i.e. the encoded prompt) differs between FrozenCLIPTextEmbedder, which gives (B, 1, 768), and FrozenCLIPEmbedder, which gives (B, 77, 768). The feature from FrozenCLIPTextEmbedder represents the whole sentence (hence length 1), while the feature from FrozenCLIPEmbedder represents each token in the sentence (hence length 77).
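
To make the shape difference concrete, here is a small standalone sketch comparing the two underlying libraries (assuming the openai/clip-vit-large-patch14 weights and the ViT-L/14 checkpoint can be downloaded): clip's encode_text applies exactly the EOT pooling and text_projection step mentioned above, so it returns one vector per prompt, while the Hugging Face CLIPTextModel's last_hidden_state keeps all 77 token positions.

import clip
import torch
from transformers import CLIPTokenizer, CLIPTextModel

prompts = ["a photograph of an astronaut riding a horse"]

# OpenAI CLIP: encode_text pools over tokens (EOT embedding @ text_projection) -> one vector per prompt
model, _ = clip.load("ViT-L/14", device="cpu", jit=False)
with torch.no_grad():
    pooled = model.encode_text(clip.tokenize(prompts))  # shape (1, 768)

# Hugging Face CLIPTextModel: last_hidden_state keeps per-token features
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
transformer = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
batch = tokenizer(prompts, truncation=True, max_length=77, padding="max_length", return_tensors="pt")
with torch.no_grad():
    per_token = transformer(input_ids=batch["input_ids"]).last_hidden_state  # shape (1, 77, 768)

print(pooled.shape, per_token.shape)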