lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch
MIT License

OpenAIVae implementation #74

Open CDitzel opened 3 years ago

CDitzel commented 3 years ago

do I see it correctly that the code fragments provided by OpenAI, and the way you bound them into the vae.py file, mean that there is no actual codebook in the form of an explicit nn.Parameter or nn.Embedding, but that the very first layer of the decoder serves as the vocabulary?

(decoder): Decoder(
    (blocks): Sequential(
      (input): Conv2d(n_in=8192, n_out=128, kw=1, use_float16=False, device=device(type='cpu'), requires_grad=False)

this would explain why I couldn't find any oO
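
For what it's worth, that reading can be sanity-checked with something like the following (just a sketch; `dec` stands for the loaded decoder from openai/DALL-E, and the token tensor is made up):

import torch
import torch.nn.functional as F

# hypothetical (B, 32, 32) grid of token ids from a vocabulary of 8192
tokens = torch.randint(0, 8192, (1, 32, 32))

# the decoder expects a one-hot (B, 8192, 32, 32) input ...
one_hot = F.one_hot(tokens, num_classes=8192).permute(0, 3, 1, 2).float()

# ... so its first Conv2d(8192, 128, kernel_size=1) acts exactly like an embedding
# lookup: each one-hot vector selects one 128-dim column of the conv weight.
# x_rec = dec(one_hot)   # with `dec` being the loaded OpenAI decoder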

lucidrains commented 3 years ago

@CDitzel that's exactly my understanding Carsten :)

CDitzel commented 3 years ago

all right. So then there is no Gumbel-Softmax or reparameterization trick at work at all.

Also no explicit KL loss to consider as there are no constraints placed on the "codebook" to begin with.

That would explain why the first author of the paper struggled to explain the loss formula he used in his own paper when Andrej and Justin asked him about it on Clubhouse oO

lucidrains commented 3 years ago

@CDitzel hmm, I think gumbel softmax was still used. can you point me to where in the clubhouse conversation they discuss the kl loss?

CDitzel commented 3 years ago

of course Phil. Happy to be of service. It's timestamped to the correct position.

Hm, you are right, they do mention Gumbel-Softmax in the paper, but I am wondering how, when, and where.

I tried to fine-tune the given OpenAI models after downloading them with your script, but I didn't manage to set "requires_grad" to True. Do you happen to know if it is possible to prohibit that, and if so, whether OpenAI did that to their model?

lucidrains commented 3 years ago

the gumbel softmax would be used on the encoder output, just before it is sent back to the decoder (whereas the hard version uses a one-hot)
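
Roughly, for a toy encoder output (made-up shapes, not the OpenAI training code):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 8192, 32, 32)  # toy encoder output: one logit per codebook entry and position

soft = F.gumbel_softmax(logits, tau=1.0, hard=False, dim=1)  # relaxed, fully differentiable samples
hard = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=1)   # one-hot forward pass, soft (straight-through) gradients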

nope, you can't prohibit that! once you have the weights, the sky is the limit
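
e.g. something along these lines should flip the flag back on after loading the downloaded checkpoints (just a sketch; `enc` / `dec` stand for the loaded OpenAI encoder/decoder):

import torch.nn as nn

def unfreeze(module: nn.Module) -> None:
    """Re-enable gradients for all parameters of a loaded (frozen) module."""
    for p in module.parameters():
        p.requires_grad_(True)

# usage (hypothetical): unfreeze(enc); unfreeze(dec)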

CDitzel commented 3 years ago

but it's hard trying to fine-tune the provided encoder and decoder when I still don't completely understand either their data mapping procedure or the specifics of their loss function, particularly their averaging procedure and their mean/std output.

Trying a simple L1 loss and backpropping through their networks yields NaN pretty soon...

CDitzel commented 3 years ago

I found that the Gumbel-Softmax significantly impairs the reconstruction quality. Leaving it out and just doing the tensor contraction defaults to OpenAI's provided implementation. However, with their code I can't seem to get satisfying reconstructions from tokens via the one_hot encoding scheme. Crazy...

fractaldna22 commented 3 years ago

@CDitzel

> I found that the Gumbel-Softmax significantly impairs the reconstruction quality. Leaving it out and just doing the tensor contraction defaults to OpenAI's provided implementation. However, with their code I can't seem to get satisfying reconstructions from tokens via the one_hot encoding scheme. Crazy...


I find that this notebook solves all the problems and makes it much easier to get a grasp on it. I feel that it's a waste of time trying to train a not-very-deep "discrete" VAE which is like 90 MB, when it took a tech giant like 3 weeks on a million TPUs to get a pretrained model which they didn't even release. The VAE they did release is practically just an encoder and a decoder -- there are plenty of those.

The REAL magic of DALL-E is not the VAE but CLIP itself. Behind CLIP is the Visual Transformer, ViT. I don't see ViT mentioned in this notebook at all -- it focuses exclusively on the VAE, which is literally just a middle man between CLIP and the model. CLIP is only mentioned in passing, and nothing is said about what model you should use or what the architecture for that is. Just my feeling. I spent days trying to get this to do anything and I feel like those hours were in VAEn. Is this all just to get a graph? Where are the pics lol. Where's the hedgehog violin?

Dall-E = A normalized latent space to start, text, tokenizer, mapper + ViT

Correct me if I'm wrong, as this is from my very amateur understanding, but having already generated hundreds of DALL-E-like images in practice, here are my two cents:

The central player in this implementation is clip.load('ViT-B/32', jit=True) -- the Visual Transformer inside it, ViT-B-32.pt (343 MB), is far larger than the 'discrete' 'tinyVAE' and far more robust; the pretrained ViT model is the real performer. It's trained on thousands of ImageNet images - this is where the magic happens. As far as encoding and decoding of text to tokens, and pixels to tensor numbers, any language model will do. You can use simple.tokenizer, you can use BERT, or GPT-2. You can use a VAE. It's a fairly linear translation from a very small dictionary that frankly could be bigger afaic. Convert text to tokens, then send those tokens to a visual transformer model so it knows what categories it should force-hallucinate into existence from the latent space -- reward the mapper for increasing the similarity to those categories, and penalize the mapper if the similarity falls, until the mapper gets its p*xels together.
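
In code, that loop boils down to something like this (a bare-bones sketch using the openai/CLIP package, not the linked notebook's actual code; the prompt, learning rate, and step count are made up, and CLIP's input preprocessing is skipped for brevity):

import torch
import clip  # the openai/CLIP package

device = "cuda" if torch.cuda.is_available() else "cpu"
perceptor, _ = clip.load("ViT-B/32", device=device, jit=False)
perceptor = perceptor.float().eval()

text = clip.tokenize(["a hedgehog playing a violin"]).to(device)
with torch.no_grad():
    text_feat = perceptor.encode_text(text).float()
    text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)

# the "mapper" here is just a raw 224x224 pixel tensor optimized directly
# (real notebooks optimize a latent instead and apply CLIP's normalization)
image = torch.rand(1, 3, 224, 224, device=device, requires_grad=True)
opt = torch.optim.Adam([image], lr=0.1)

for step in range(500):
    img_feat = perceptor.encode_image(image.clamp(0, 1)).float()
    img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
    loss = -(img_feat * text_feat).sum()  # reward similarity to the prompt, penalize drift
    opt.zero_grad()
    loss.backward()
    opt.step()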

The latest version of this notebook even dropped the dall_e encoder completely, and it didn't even change the result, except that it uses less memory. The encoder is super inefficient in compute. The perceptor (the CLIP ViT model) is just as good at it, knows better than it, and is pretrained on a lot of the images you'll see in the Multimodal Neurons distill.pub article; the VAE encoder seems redundant when you already have simple.tokenizer in the CLIP model doing the same thing. What you don't have is a visual perceiver that tells the pixel mapper how well it's doing, until it learns that moving the latent-space pixels in certain directions achieves less loss and more similarity to the text_input than moving the pixels in incorrect directions.

Using the pretrained ViT, with the proper setup of the temperature/tau, lr (0.1 works fine), zero_grad, ncols, mean, std, and clamping of min and max, it achieves similar image examples of the type "hedgehogs made out of Legos" and whatever else you want. This is DALL-E: it's simply CLIP, within which is embedded Google's Visual Transformer model (or a similar model trained on a HUGE dataset), plus a simple tokenizer and decoder.

The way forward is to train a Visual Transformer model using CLIP -- with or without the involvement of the tiny VAE (I would think a very deep VAE would be better) -- on as many images as possible. CLIP handles the labeling of them, that's what it's for! CLIP CAN see why kids love Cinnamon Toast Crunch.

The ViT itself has a built-in encoder and decoder, attention, etc., and maps the pixels. This attention is all you need. If you have a pretrained ViT, and the ViT knows what it's looking at, and CLIP knows how concepts relate and relays the right amount of reward and penalty to get the mapper to correct itself -- then you get the image you want to see. If you only have a VAE that only knows how to encode and decode from text to token and token to array, it's still not good for much if it doesn't even know what it's looking at or ever see the results.

So far based on this paper - https://www.kaggle.com/abhinand05/vision-transformer-vit-tutorial-baseline# - it looks like ViT-H-14 and ViT-L-16 are the best , achieving up to 99.74% accuracy between image and text predictions or vice versa.

The notebook I linked possibly builds the 'vae' architecture into the 'perceptor' state_dict, but I'm not sure about that. Take a look at it if you haven't seen it. I would really love to try loading ViT-L-16 or ViT-H-14 into perceptor/clip, but so far I can't figure out how to configure the .pt file with the correct keys like "version", so it rejects my attempts at loading in custom models.

fractaldna22 commented 3 years ago

ViT > VaE lol

alexisrozhkov commented 3 years ago

@CDitzel I'm not an expert on the subject, but my understanding is that the DALL-E paper authors didn't use a "vanilla" VQ-VAE (which requires the explicit codebook that you were looking for).

Everywhere in the paper they refer to "dVAE", for instance:

> We train a discrete variational autoencoder (dVAE) to compress each 256×256 RGB image into a 32×32 grid of image tokens, each element of which can assume 8192 possible values.

It makes me think that they have 32x32x8192 logits as encoder output that is sampled/argmax'ed to obtain 32x32x1 token tensor - this works with the code and weights for VAE that they've shared publicly.
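
With the publicly released dVAE that would look roughly like this (a sketch; `enc` is assumed to be OpenAI's loaded encoder, and training would sample via Gumbel-Softmax rather than argmax):

def image_to_tokens(enc, x):
    """Sketch: x is a (B, 3, 256, 256) preprocessed image, `enc` the released dVAE encoder."""
    logits = enc(x)              # (B, 8192, 32, 32): one logit per vocabulary entry and position
    return logits.argmax(dim=1)  # (B, 32, 32) grid of token ids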

> Also no explicit KL loss to consider as there are no constraints placed on the "codebook" to begin with.

I don't see any issue in applying KL loss to the logits I've mentioned above.
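
Concretely, KL against a uniform prior over the K = 8192 codes reduces to log K minus the entropy of the predicted distribution, e.g. (a sketch, assuming (B, K, H, W) logits):

import math
import torch
import torch.nn.functional as F

def kl_to_uniform(logits: torch.Tensor) -> torch.Tensor:
    """KL(q || Uniform(K)) = log K - H(q), averaged over batch and positions."""
    log_q = F.log_softmax(logits, dim=1)
    return (log_q.exp() * (log_q + math.log(logits.shape[1]))).sum(dim=1).mean()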

> So then there is no Gumbel-Softmax or reparameterization trick at work at all.

How exactly did you arrive at this conclusion?

CDitzel commented 3 years ago

Well, my thoughts are playing tricks on me, something that happens when I concern myself with one topic exclusively for too long xD

I got a working Gumbel-Softmax dVAE, but I still have trouble training it on a larger data set, i.e. 50k samples.

would you be interested in talking about the specific architecture in further detail?

best regards from Germany to SPb (my girlfriend is from there ;))

alexisrozhkov commented 3 years ago

Hahah thanks :)

I'm currently working on a small-scale reimplementation of the paper for fun; I'll probably share the results as soon as I'm sure they are correct (surprisingly, I haven't seen anything that is close enough to the paper yet)

Until then I'm open to discussing the details - this might help to "fill in the blanks"

fractaldna22 commented 3 years ago

vqvae > tiny-vae (big memory), aka dall-e, the model that never was

fractaldna22 commented 3 years ago

I claim that Dall-E has no working model or else they wouldn't be able to stop making more and more examples. How could they just set it down and not update the site for months? How!!!

CDitzel commented 3 years ago

I have to agree with you on this one. OpenAI's policy of reproducible research and being honest/open about their work has been questionable, to say the least. I am going to post the VAE Gumbel-Softmax snippet here soon as a reference and so that we can talk about it. First I have to get coffee and breakfast though xD

CDitzel commented 3 years ago

all right, so this is what I have come up with so far. It closely resembles Lucid's implementation but parameterizes the Gumbel-Softmax with the distance of the encoder output to the codebook vectors (described in this paper), akin to VQ-VAEs, in contrast to Lucid's implementation which uses the logits directly as input to the Gumbel. Phil's (and Karpathy's) implementation never worked for me when I rightfully included the KL loss, i.e. a KL weight > 0. With this implementation the KL loss can be included as it should be, with a uniform prior. However, the results on a larger data set are still underwhelming and not really satisfying in terms of reconstruction quality. Maybe someone can take a look at it and assess the correctness of this implementation?

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import einsum


class SoftDiscretizer(nn.Module):
    def __init__(
        self,
        nTokens,
        dTokens,
        temperature,
        kl_weight,
        **kwargs,
    ):
        super().__init__()
        self.nTokens = nTokens
        self.dTokens = dTokens
        self.temperature = temperature
        self.kl_weight = kl_weight

        self.embedding = nn.Embedding(nTokens, dTokens)

    def forward(self, z):
        B, C, H, W = z.size()
        N, D = self.embedding.weight.shape

        # squared euclidean distance of every spatial encoder output to every codebook vector
        z_flat = rearrange(z, "b c h w -> (b h w) c")
        distances = (
            torch.sum(self.embedding.weight ** 2, dim=1)
            + torch.sum(z_flat ** 2, dim=1, keepdim=True)
            - 2 * torch.matmul(z_flat, self.embedding.weight.t())
        )
        distances = rearrange(distances, "(b h w) n -> b h w n", h=H, w=W)

        # minus so that closer codebook vectors have higher probability?
        samples = F.gumbel_softmax(-distances, tau=self.temperature, hard=False, dim=-1)

        if not self.training:
            tokens = samples.argmax(dim=-1)
            return tokens.flatten(start_dim=1)

        # soft quantization: convex combination of codebook vectors
        z_q = einsum("b h w n, n d -> b d h w", samples, self.embedding.weight)

        # KL divergence to a uniform prior over the codebook: log(nTokens) - H(q),
        # summed over spatial positions, averaged over the batch
        logits = F.log_softmax(-distances, dim=-1)
        probs = torch.exp(logits)  # numerically more stable than softmax alone
        neg_entropy = torch.sum(probs * (logits + math.log(self.nTokens)), dim=(1, 2, 3))
        kl_loss = self.kl_weight * torch.mean(neg_entropy)

        return z_q, kl_loss

afiaka87 commented 3 years ago

> vqvae > tiny-vae (big memory), aka dall-e, the model that never was

@fractaldna22 @CDitzel

Do you have a method for implementing the "image mask" feature? I've done extensive scraping of that blog post in particular (check the discussions tab for a scrape of all 1.1 million image text pairs) and they aren't always clear about when they're using that feature.

It's possible they're using a mask containing white pixels at the sides and a transparent square cropped in the center.

fractaldna22 commented 3 years ago

Yes totally. And I don't see the words "white background" in the prompt


sidml commented 3 years ago

@CDitzel You may want to check out this issue. The gumbel softmax function expects logits as inputs, so passing distances may not be appropriate. The scale of the distances will be quite different from the scale of the Gumbel distribution (which is between 0 and 1).

enhuiz commented 3 years ago

> @CDitzel You may want to check out this issue. The gumbel softmax function expects logits as inputs, so passing distances may not be appropriate. The scale of the distances will be quite different from the scale of the Gumbel distribution (which is between 0 and 1).

But logits are not between 0 and 1 either; could you please elaborate on "quite different from the scale of the gumbel distribution"?

sidml commented 3 years ago

@enhuiz The standard Gumbel distribution, i.e. G(0, 1), looks like this: [plot of the standard Gumbel density]

When we call F.gumbel_softmax in PyTorch, it adds the provided logits to a sample from the G(0, 1) distribution. If the logits are at some weird scale, then it may not make sense to add them to the sample from the Gumbel distribution.
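
Roughly what happens under the hood (a sketch ignoring PyTorch's internal eps handling):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 8192)
tau = 1.0

gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits)))  # samples from G(0, 1)
y_soft = F.softmax((logits + gumbel_noise) / tau, dim=-1)       # the "soft" gumbel-softmax sample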

More details are here

CDitzel commented 3 years ago

but just passing the encoder outputs/logits doesn't make any guarantees on the value ranges either, does it?

enhuiz commented 3 years ago

@sidml Thanks for the reply. I agree that the logits should have a similar range to g_i (i.e., roughly from -2 to 6), but since the latent space is learnable, I guess the scale of the distances can be learned to suit the scale of g_i.

CDitzel commented 3 years ago

have a look at this notebook https://github.com/shaabhishek/gumbel-softmax-pytorch/blob/master/Gumbel-softmax%20visualization.ipynb

here they normalize the logits before passing them to Gumbel-Softmax. Apparently that is not necessary for the built-in Pytorch implementation.
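
For the distance-based variant above, that normalization could look like this (just a sketch with dummy distances, using log-probabilities as the logits):

import torch
import torch.nn.functional as F

distances = torch.rand(1, 32, 32, 8192)        # dummy (B, H, W, nTokens) squared distances
log_probs = F.log_softmax(-distances, dim=-1)  # bounded, consistently scaled "logits"
samples = F.gumbel_softmax(log_probs, tau=0.5, hard=False, dim=-1)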

On another note, can someone explain to me the top row of Figure 1 in the original paper?

It says expectation, but how is that calculated, and how does it even depend on the temperature parameter?

sidml commented 3 years ago

@CDitzel In the PyTorch implementation, they seem to directly add the logits to the sample from the Gumbel distribution.

I believe they divide the logits by the temperature before sampling from the categorical distribution in Figure 1 of the paper.

daydreamer2023 commented 3 years ago

> all right, so this is what I have come up with so far. It closely resembles Lucid's implementation but parameterizes the Gumbel-Softmax with the distance of the encoder output to the codebook vectors [...] Maybe someone can take a look at it and assess the correctness of this implementation?

For the test procedure, might it need some changes here?

        if not self.training:
            # tokens = samples.argmax(dim=-1)
            tokens = distances.argmin(dim=-1)
            return tokens.flatten(start_dim=1)

By the way, if you want a weighted sum, why not directly use softmax rather than gumbel_softmax?

# z_q = einsum("b h w n, n d -> b d h w", samples, self.embedding.weight)
z_q = einsum("b h w n, n d -> b d h w", F.softmax(-distances, dim=-1), self.embedding.weight)
thuangb commented 2 years ago

I also tried all the existing Gumbel quantizers, but they all perform worse than the normal quantizer. However, I've recently noticed from OpenAI's code for DALL-E that the encoder outputs the probabilities and the decoder takes those directly as input, without using any codebook embedding. This led me to the following dummy Gumbel quantizer, and I notice that with some hyper-parameter tweaking it even works better than, or on par with, the previous Gumbel quantizer implementations:

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class DummyDiscretizer(nn.Module):
    def __init__(
        self,
        z_dim,
        nTokens,
        temperature,
        **kwargs,
    ):
        super().__init__()
        self.z_dim = z_dim
        self.nTokens = nTokens
        self.temp = temperature

        # project encoder features to per-token logits and back, no codebook embedding
        self.proj_in = nn.Conv2d(z_dim, nTokens, kernel_size=1)
        self.proj_out = nn.Conv2d(nTokens, z_dim, kernel_size=1)

    def forward(self, z):
        B, C, H, W = z.size()

        logits = rearrange(self.proj_in(z).exp(), "b c h w -> b (h w) c")
        probs = F.gumbel_softmax(logits, tau=self.temp, hard=False, dim=-1)

        probs = rearrange(probs, "b (h w) c -> b c h w", w=W)
        z_q = self.proj_out(probs)

        # dummy loss, could improve more with OpenAI's mysterious KL loss =)))
        kl_loss = torch.tensor(0.0, device=z.device)

        return z_q, kl_loss

I think this means that something is wrong with the KL loss, because it seems to be almost useless.