ControlNet / MARLIN

[CVPR] MARLIN: Masked Autoencoder for facial video Representation LearnINg
https://openaccess.thecvf.com/content/CVPR2023/html/Cai_MARLIN_Masked_Autoencoder_for_Facial_Video_Representation_LearnINg_CVPR_2023_paper

The "marlin_vit_base_ytf" checkpoint seems to be damaged or not working. #18

Closed: wolverine28 closed this issue 7 months ago

wolverine28 commented 7 months ago

I tested all three models (small, base, large) on a reconstruction task using the code below (sorry for the messy code, it's just for quick testing) :)


import cv2
import numpy as np
import torch
import torchvision
from einops import rearrange, repeat
from marlin_pytorch import Marlin

# Load MARLIN model from GitHub Release
model = Marlin.from_online("marlin_vit_base_ytf", full_model=True)
model = model.cuda()
model.as_feature_extractor = False

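# Build a 16-frame clip by repeating a single 224x224 RGB face image
im = cv2.imread('test.png')
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im = cv2.resize(im, (224, 224))
im = rearrange(im, 'h w c -> () c h w')/255
im = torch.tensor(repeat(im, 'b c h w -> b c t h w', t=16)).float().cuda()
# Random boolean mask over the 1568 patch tokens (about 10% of entries are True)
mask = torch.rand(1, 1568).cuda()>0.9
y = model(im,mask)
print()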

# Alternative patchify via unfold (result is not used below; quick sanity check only)
im_patch = im.unfold(2, 2, 2) \
            .unfold(3, 16, 16) \
            .unfold(4, 16, 16)
im_patch = rearrange(im_patch, "b c nt nh nw pt ph pw -> b (nt nh nw) (c pt ph pw)")
model.decoder.unpatch_to_img(im_patch)[:,:,0]

# Patchify the clip into tokens of 2x16x16 pixels, channel-major within each token
x = rearrange(im, "b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c",
    p0=2, p1=16, p2=16)
x = rearrange(x, "b n p c -> b n (c p)")
# Reconstruction: overwrite the ~mask tokens with the model output
x_rec = x.clone()
x_rec[~mask] = y.view(-1, 16 * 16 * 2 * 3)

# Masked input: zero out the same ~mask tokens for comparison
x_mask = x.clone()
x_mask[~mask] = 0

# Place reconstruction and masked input next to each other, frame by frame, and save
grid_img = torchvision.utils.make_grid(torch.cat([
    rearrange(model.decoder.unpatch_to_img(x_rec), 'b c t h w -> (b t) c h w'),
    rearrange(model.decoder.unpatch_to_img(x_mask), 'b c t h w -> (b t) c h w')
],2), nrow=5)
iiii = (grid_img.permute(1, 2, 0).cpu().detach().numpy()*255).astype(np.uint8)
iiii = cv2.cvtColor(iiii, cv2.COLOR_RGB2BGR)
cv2.imwrite('large.png',iiii)
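
For readers following the snippet, the constants `1568` (mask width) and `16 * 16 * 2 * 3` (values per token) can be derived from the clip and patch sizes used above. The helper below is only an illustrative sketch; `token_layout` is a hypothetical name and not part of `marlin_pytorch`.

```python
def token_layout(frames=16, size=224, tubelet=2, patch=16, channels=3):
    """Return (number of tokens, values per token) for a clip that is cut
    into tubelet x patch x patch cubes, as in the rearrange pattern above."""
    if frames % tubelet or size % patch:
        raise ValueError("clip dimensions must be divisible by the patch sizes")
    # 16 frames / 2 = 8 temporal slots, 224 / 16 = 14 patches per side
    n_tokens = (frames // tubelet) * (size // patch) ** 2
    # Each token holds tubelet * patch * patch pixels for every channel
    token_dim = tubelet * patch * patch * channels
    return n_tokens, token_dim


n_tokens, token_dim = token_layout()
print(n_tokens, token_dim)  # 1568 1536
```

With `torch.rand(1, 1568) > 0.9`, roughly 10% of those 1568 entries (about 157) come out `True`; the snippet fills the remaining `~mask` positions from the model output.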

Only the base model fails to reconstruct the original face. Any ideas?

Large

[image: 2 large]

Base

[image: 1 base]

Small

[image: 0 small]

ControlNet commented 7 months ago

Thanks for the feedback. It is fixed now. Please clear the model cache and then rerun your test.

Marlin.clean_cache()
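
Combined with the loading call from the original report, the clear-and-reload step might look like the sketch below. It uses only `Marlin.clean_cache()` and `Marlin.from_online(...)` as they appear in this thread; the wrapper name `reload_checkpoint` is hypothetical, and the sketch assumes that `from_online` downloads the checkpoint again once the cache is empty.

```python
def reload_checkpoint(name="marlin_vit_base_ytf"):
    """Clear the local MARLIN model cache, then load `name` again."""
    # Imported here so that defining the helper does not require the package.
    from marlin_pytorch import Marlin

    Marlin.clean_cache()  # remove previously cached model files
    # Assumption: with an empty cache, this fetches the fixed checkpoint.
    return Marlin.from_online(name, full_model=True)
```

Calling `model = reload_checkpoint()` needs network access; after that, the reconstruction test can be rerun unchanged.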
wolverine28 commented 7 months ago

👍