ControlNet / MARLIN

[CVPR] MARLIN: Masked Autoencoder for facial video Representation LearnINg
https://openaccess.thecvf.com/content/CVPR2023/html/Cai_MARLIN_Masked_Autoencoder_for_Facial_Video_Representation_LearnINg_CVPR_2023_paper

How to use the features for downstream tasks? #1

Closed ZJ-CAI closed 1 year ago

ZJ-CAI commented 1 year ago

This is really fantastic work. I'd like to use it to improve the quality of the video generated by Wav2Lip. Could you please show me how to utilize the features extracted by your model?

ControlNet commented 1 year ago

OK. The brief idea is to use the pretrained MARLIN encoder to replace the Wav2Lip facial encoder. To achieve that, we adjust the temporal frame window from 5 to 16 to fit the MARLIN encoder's input shape, and we retrain the SyncNet with 16 frames as well.
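
For reference, a minimal sketch of loading the pretrained encoder and extracting features from a 16-frame clip, assuming the marlin-pytorch package's Marlin.from_online and extract_features entry points as shown in the repo README (verify the names against your installed version):

import torch
from marlin_pytorch import Marlin

# Load the pretrained MARLIN ViT-Base encoder (weights downloaded on first use).
encoder = Marlin.from_online("marlin_vit_base_ytf")

# Wav2Lip originally consumes 5-frame windows; MARLIN expects 16-frame clips
# of 224x224 face crops, hence the window change from 5 to 16.
clip = torch.randn(1, 3, 16, 224, 224)  # (B, C, T, H, W)

with torch.no_grad():
    features = encoder.extract_features(clip)  # 768-dim MARLIN features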

I will upload the code for reproducing the experiments in the paper in the future.

ZJ-CAI commented 1 year ago

I have just checked Wav2Lip's facial encoder: its output size is 512, whilst the output of MARLIN is 768. Is there anything wrong with my process?

Thank you

ControlNet commented 1 year ago

Yes, the decoder is also modified to fit the dimension. I also resized the input image and concatenated it with the decoder feature maps, as a replacement for the original U-Net skip connections.

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

# Conv2d and Conv2dTranspose are the residual conv blocks from the Wav2Lip
# codebase (models/conv.py).
from models.conv import Conv2d, Conv2dTranspose


class Wav2LipDecoder(nn.Module):

    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([
            nn.Sequential(  # 14, 14
                Conv2d(774, 774, kernel_size=1, stride=1, padding=0, residual=True),
                Conv2d(774, 774, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2dTranspose(774, 378, kernel_size=4, stride=2, padding=1),  # 28, 28
            ),
            nn.Sequential(
                Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2dTranspose(384, 250, kernel_size=4, stride=2, padding=1),  # 56, 56
            ),
            nn.Sequential(
                Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2dTranspose(256, 122, kernel_size=4, stride=2, padding=1),  # 112, 112
            ),
            nn.Sequential(
                Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2dTranspose(128, 58, kernel_size=4, stride=2, padding=1),  # 224, 224
            ),
            nn.Sequential(
                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),  # 224, 224
            )
        ])

        self.output_block = nn.Sequential(
            Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid(),
        )

        # Spatial sizes at which the resized input image is concatenated
        # before each block, replacing the original U-Net skip connections.
        self.sizes = [(14, 14), (28, 28), (56, 56), (112, 112), (224, 224)]

    def forward(self, x, img):
        # Flatten the temporal axis so each frame is decoded independently.
        img = rearrange(img, "b c t h w -> (b t) c h w")

        for i, block in enumerate(self.blocks):
            # Concatenate the resized input image (6 channels) onto the
            # current feature maps before each upsampling block.
            x = torch.cat([x, F.interpolate(img, self.sizes[i])], dim=1)
            x = block(x)
        x = self.output_block(x)
        return x
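
For shape orientation, a usage sketch under assumed shapes. The only hard constraints implied by the code above are that x has B*T entries matching img's T frames after the rearrange, and that img carries 6 channels (768 + 6 = 774 at the first concatenation); whether those 6 channels are the masked target plus a reference frame as in Wav2Lip is my reading, not confirmed in this thread.

# Assumed shapes; only the B*T batching and the 6-channel img are implied by the code.
B, T = 2, 16
x = torch.randn(B * T, 768, 14, 14)    # per-frame MARLIN feature maps
img = torch.randn(B, 6, T, 224, 224)   # 6 = 3 + 3 channels, e.g. masked target + reference

decoder = Wav2LipDecoder()
out = decoder(x, img)
print(out.shape)  # torch.Size([32, 3, 224, 224])
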
ZJ-CAI commented 1 year ago

Many thanks!

ashok-arjun commented 1 year ago

Hi @ControlNet - thank you for your work and for releasing the code for the paper. May I know approximately when the code and fine-tuned checkpoints of the fine-tuning experiments will be released?

I'm looking forward to reproducing the results and building on top of your work! Thank you!

ControlNet commented 1 year ago

@ashok-arjun Hi, I'm working on it and will release the pretraining, downstream adaptation, and evaluation code step by step.

ControlNet commented 1 year ago

@ashok-arjun Training code is ready.

ashok-arjun commented 1 year ago

Thank you for the prompt upload, @ControlNet, I'll work with this for now.

Will look forward to the adaptation and evaluation code.

Thanks a lot, @ControlNet.

Ap1075 commented 10 months ago

@ZJ-CAI were you able to do this?

ControlNet commented 10 months ago

@ashok-arjun Now the evaluation code is ready.

rainbowoldhorse commented 9 months ago

Hello @ControlNet, I want to use your MARLIN model with Wav2Lip. After reviewing the decoder code you provided earlier, I have a few questions I would like to ask you:

  1. I understand that x is the output of the MARLIN encoder. Taking the first concatenation as an example, it would be (B, 8, 768, 14, 14), and img would consist of 16 images of the lower half of the face, stacked along the channel dimension, so img's shape would be (B, 3 * 2, 8, 14, 14). I infer this because 768 has to be joined with 6 extra channels at the concatenation, and I cannot think of a better explanation. Is my understanding correct? If not, what are the intended meanings and shapes of x and img?

  2. In your decoder code, I cannot see where the audio features are fused with the MARLIN features. How should I integrate the audio?

I sincerely seek your advice on these two questions and hope you can clear up my confusion. Thanks.

akmalmasud96 commented 7 months ago

Hi @ControlNet, when will the code for adaptation to the downstream tasks be released?

ControlNet commented 7 months ago

@akmalmasud96 Currently, the downstream tasks for CelebV-HQ action and attribute prediction are ready.

akmalmasud96 commented 7 months ago

@ControlNet Thank you for the quick update! Could you please provide information on the status of the deepfake component? Is it ready, or do you have an estimated timeline? Additionally, is the raw code for this part available?

akmalmasud96 commented 7 months ago

Hi @ControlNet, for the DFD (DeepfakeDetection) task, would the same classification architecture be used?

https://github.com/ControlNet/MARLIN/blob/23491494bb722432dcdf19f67f267aa0ccdaa48c/model/classifier.py#L14

ControlNet commented 7 months ago

@akmalmasud96 Yeah, just reuse it for binary classification and it should be fine.
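
A minimal sketch of that reuse, as a stand-in head over pooled 768-dim MARLIN features (the class name and pooling below are illustrative assumptions, not the repo's exact Classifier signature):

import torch
from torch import nn

class BinaryDeepfakeHead(nn.Module):  # hypothetical stand-in for the repo's Classifier
    def __init__(self, feat_dim: int = 768):
        super().__init__()
        self.fc = nn.Linear(feat_dim, 1)  # single logit: real vs. fake

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        # feats: (B, N, 768) MARLIN token features -> mean-pool, then classify
        return self.fc(feats.mean(dim=1)).squeeze(-1)

# Train with binary cross-entropy on the logit.
head = BinaryDeepfakeHead()
loss_fn = nn.BCEWithLogitsLoss()
logits = head(torch.randn(4, 1568, 768))
loss = loss_fn(logits, torch.randint(0, 2, (4,)).float())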