LTH14 / rcg

PyTorch implementation of RCG https://arxiv.org/abs/2312.03701
MIT License

Some question about loss and SSL Encoder #30

Closed mapengsen closed 4 months ago

mapengsen commented 4 months ago

Hello, when training RCG, I think the training is divided into two parts:

  1. RDM training: freeze the SSL encoder, then compute the MSE loss between the representation produced by the encoder and the representation after RDM denoising.
  2. Pixel-generation training: freeze the SSL encoder, then compute the vb_loss and the MSE loss between the real and predicted distributions.

Does RCG use any other losses? Is the SSL encoder always frozen?

Looking forward to your reply, thanks!

LTH14 commented 4 months ago

Does RCG use other losses: the RDM only uses the MSE loss. Depending on the specific pixel generator, we use a cross-entropy loss (MAGE), a diffusion loss on image latents (DiT and LDM), or a diffusion loss on pixels (ADM).

Is the SSL encoder always frozen: yes, the SSL encoder is pre-trained and always frozen during RCG's training.
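
For illustration, a minimal sketch of that MSE (noise-prediction) objective on frozen-encoder representations could look like the following; the rdm and ssl_encoder names and the toy noise schedule are placeholders, not the repo's actual code.

import torch
import torch.nn.functional as F

# Sketch only: DDPM-style noise-prediction loss on representations.
@torch.no_grad()
def extract_rep(ssl_encoder, images):
    # frozen pre-trained encoder; no gradients flow back into it
    return ssl_encoder(images)  # (B, rep_dim)

def rdm_loss(rdm, ssl_encoder, images, num_timesteps=1000):
    rep = extract_rep(ssl_encoder, images)
    t = torch.randint(0, num_timesteps, (rep.size(0),), device=rep.device)
    noise = torch.randn_like(rep)
    # toy linear schedule purely for illustration
    alpha_bar = (1.0 - (t.float() + 1) / num_timesteps).unsqueeze(-1)
    noisy_rep = alpha_bar.sqrt() * rep + (1 - alpha_bar).sqrt() * noise
    pred_noise = rdm(noisy_rep, t)        # the RDM predicts the added noise
    return F.mse_loss(pred_noise, noise)  # the only loss used for the RDM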

mapengsen commented 4 months ago

When you use the SSL encoder, you replace the last-layer head with the "build_mlp()" module. Doesn't this "build_mlp()" module require parameter updates?

def mocov3_vit_small(proj_dim, **kwargs):
    model = moco_vits.vit_small(**kwargs)
    hidden_dim = model.head.weight.shape[1]
    del model.head  # remove original head layer

    # projectors
    model.head = build_mlp(3, hidden_dim, 4096, proj_dim)
    return model

I feel that if this "build_mlp()" module is not trained, its output will be meaningless, because its parameters would just be random.

LTH14 commented 4 months ago

The head in moco_vits.vit_small is inherited from timm's VisionTransformer, which is a single Linear layer. However, the projection head of the pre-trained MoCo v3 model is an MLP (module.base_encoder.head). I don't want to change the original MoCo code and re-train that model, so instead I replace the Linear head with an MLP head so that the pre-trained weights can be loaded. We always set requires_grad=False for the entire pre-trained encoder, so it is not trained.
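
Roughly, loading the pre-trained MoCo v3 weights into that replaced MLP head and then freezing everything looks like the sketch below (the checkpoint path, proj_dim value, and key prefix are illustrative assumptions, not the repo's exact code):

import torch

# Sketch: build the encoder with the MLP head (see mocov3_vit_small above),
# load the pre-trained MoCo v3 base-encoder weights, then freeze everything.
model = mocov3_vit_small(proj_dim=256)
ckpt = torch.load('mocov3_vit_small.pth', map_location='cpu')
state_dict = {k[len('module.base_encoder.'):]: v
              for k, v in ckpt['state_dict'].items()
              if k.startswith('module.base_encoder.')}
model.load_state_dict(state_dict, strict=False)
for p in model.parameters():   # the whole encoder, MLP head included, stays frozen
    p.requires_grad = False
model.eval()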

mapengsen commented 4 months ago

For my own dataset, I use ResNet18 as the SSL encoder. The rep dimension of the final output before the FC layer is (bsz, 512, 1, 1), so I apply 1) rep = rep.squeeze() --> (bsz, 512) and then 2) rep = nn.Linear(rep_dim*2, hidden_size) --> (bsz, 384):

if rep is not None:
    rep = rep.squeeze()                    # (bsz, 512, 1, 1) -> (bsz, 512)
    rep = self.rep_embedder(rep)           # project to the hidden size (bsz, 384)
    c = t + rep                            # condition on the representation
else:
    y = self.y_embedder(y, self.training)  # (N, D)
    c = t + y                              # condition on the class label

To match the dimension of the timestep embedding t (bsz, 384), I need to retrain the last FC layer of the SSL encoder, but the image generation I have tried is very poor. Do you have any suggestions?

LTH14 commented 4 months ago

If your own dataset is too different from ImageNet, then the SSL encoder, RDM and pixel generator all need to be trained by yourself. In this case, you don't need to stick with 384-dimension timestep embedding -- you can change it to any dimension that fits your SSL encoder.

To figure out which part of your pipeline has a problem, I suggest you do the following: after you train a pixel generator with the ground-truth representation, test its generation quality with representations extracted from your own images and see whether the performance is good or not. If the performance is bad, then the pixel generator has some problem. If the performance is good, then try to use the generated representation (generated by the RDM) and see whether it is still good.
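
Something like the sketch below; all module names and the .sample interfaces are placeholders for whatever your own code exposes:

import torch

@torch.no_grad()
def diagnose(ssl_encoder, rdm, pixel_generator, real_images):
    # Step 1: condition the trained pixel generator on ground-truth representations
    # from real images. If these samples already look bad, the pixel generator
    # (or its tokenizer) is the problem.
    rep_real = ssl_encoder(real_images)
    samples_from_real_rep = pixel_generator.sample(rep_real)

    # Step 2: only if step 1 looks good, swap in RDM-generated representations.
    # If the quality drops here, the RDM is the problem.
    rep_generated = rdm.sample(num_samples=real_images.size(0))
    samples_from_gen_rep = pixel_generator.sample(rep_generated)

    return samples_from_real_rep, samples_from_gen_rep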

mapengsen commented 4 months ago

Thank you very much for your suggestions. I tried the above method to train on my own dataset, but the training results are still not good (the generated images are too blurry); in particular, I cannot generate images that match the representations of different input images (Figure 6 of the RCG paper).

I wonder if I should add the SSL encoder after the pixel generator, pass the generated image through it, and compute a loss between its representation and the pre-defined representation. This could ensure that the generated image matches the pre-defined image rep.

I would also not freeze the SSL encoder during training, and instead let it continue to be trained.

Will this improve the result? I'm confused now.

LTH14 commented 4 months ago

Do you use DiT? If the images are too blurry, I would suggest you take a look at the tokenizer -- it might not be trained for your dataset. Just simply encode and decode an image with the tokenizer, and check the visual quality.
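
The round-trip check can be as simple as the sketch below, where vae stands in for your tokenizer and its encode/decode calls are assumed interfaces you may need to adapt:

import torch

def tokenizer_roundtrip(vae, image):
    # Encode then decode one image; if the reconstruction is already blurry,
    # the tokenizer (not the diffusion model) is the bottleneck.
    vae.eval()
    with torch.no_grad():
        latent = vae.encode(image)  # image -> latent
        recon = vae.decode(latent)  # latent -> reconstructed image
    return recon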

mapengsen commented 4 months ago

When I don't add rep as an additional condition to my diffusion generation model, the generated results are not great. After using RCG, the results become even worse. I have studied this for many days but still don't know where the problem lies.

mapengsen commented 4 months ago

> Do you use DiT? If the images are too blurry, I would suggest you take a look at the tokenizer -- it might not be trained for your dataset. Just simply encode and decode an image with the tokenizer, and check the visual quality.

I use MDTv2: "Masked Diffusion Transformer is a Strong Image Synthesizer" (https://arxiv.org/abs/2303.14389).

mapengsen commented 4 months ago

> Do you use DiT? If the images are too blurry, I would suggest you take a look at the tokenizer -- it might not be trained for your dataset. Just simply encode and decode an image with the tokenizer, and check the visual quality.

I used an LDM encoder and decoder, and I've fine-tuned it on my dataset.

LTH14 commented 4 months ago

This is quite weird -- adding the representation definitely won't make the performance worse. The model can learn to at least not condition on the representation. Have you tried conditioning on real representations from real images? Is the performance still not good?

mapengsen commented 4 months ago

Yes, I have tried that. I used the real representations from real images to test my trained model, but the results are not great.

I think maybe the SSL encoder is not good; I wonder if I should retrain it while training the pixel generator.

The following is my implementation:

def forward(self, x, t, y, enable_mask=False, rep=None):
    """
    Forward pass of MDT.
    x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
    t: (N,) tensor of diffusion timesteps
    y: (N,) tensor of class labels
    enable_mask: Use mask latent modeling
    rep: (N, rep_dim) tensor of image representations (optional condition)
    """
    x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
    t = self.t_embedder(t)  # (N, D)

    # ============================================= rep cond =============================================

    if rep is not None:
        # during training: randomly drop the rep and replace it with the learned fake latent
        if self.training:
            rep = rep.squeeze()
            drop_rep_mask = torch.rand(x.size(0)) < self.rep_dropout_prob
            drop_rep_mask = drop_rep_mask.unsqueeze(-1).cuda().float()
            rep = drop_rep_mask * self.fake_latent + (1 - drop_rep_mask) * rep.squeeze()
        # during inference: use the rep as-is
        else:
            rep = rep.squeeze()

        # project the rep to the hidden dimension
        rep = self.rep_embedder(rep)
        # add the rep to the timestep embedding t to form the conditioning input
        c = t + rep
    else:
        y = self.y_embedder(y, self.training)  # (N, D)
        c = t + y

LTH14 commented 4 months ago

I notice the mean of your dataset is quite large (close to 1). In this case, the pre-trained MoCo encoder on ImageNet is likely to be useless. You might need to re-train MoCo v3 on your dataset. But be careful: MoCo-v3 uses strong augmentations on color (color distortion and random grayscale). You might need to remove them, or adjust them according to your own data.

LTH14 commented 4 months ago

I would suggest not to train the SSL encoder together with the pixel generator. The SSL encoder should be trained with contrastive loss. Training it together might lead to instability.

mapengsen commented 4 months ago

The mean and std are computed from my own dataset:

mean = th.Tensor([0.978, 0.977, 0.979]).cuda().unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
std = th.Tensor([0.131, 0.136, 0.129]).cuda().unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

Should I train the SSL encoder separately, using the same mean and std as at inference?

LTH14 commented 4 months ago

If this is the mean of your dataset, then your images are likely to be quite bright. In MoCo v3, they use ImageNet normalization: https://github.com/facebookresearch/moco-v3/blob/main/main_moco.py#L258-L259. If you train your own MoCo v3 on your own dataset, you could consider changing this normalization to the statistics of your own dataset. Besides, they also use color jittering and grayscale https://github.com/facebookresearch/moco-v3/blob/main/main_moco.py#L264-L267 in the augmentations. Especially for color distortion, you might need to configure the jittering scale (or remove the jittering) according to your own dataset.
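
For illustration only (not MoCo v3's exact recipe), a toned-down augmentation pipeline using your dataset statistics could look like this; the jitter strengths are placeholder values to tune:

import torchvision.transforms as transforms

# dataset statistics from above, instead of ImageNet normalization
mean = [0.978, 0.977, 0.979]
std = [0.131, 0.136, 0.129]

augmentation = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    # weakened color jitter; remove it entirely if color carries class information
    transforms.RandomApply([transforms.ColorJitter(0.1, 0.1, 0.1, 0.0)], p=0.4),
    # random grayscale removed on purpose for near-white images
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])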

mapengsen commented 4 months ago

Thank you very much! I will try it.

Thank you for your hard work. :smiley::smiley::smiley: