Seanseattle / StyleSwap

StyleSwap: Style-Based Generator Empowers Robust Face Swapping (ECCV 2022)
Apache License 2.0

How to get the expression error #11

Open shiwk20 opened 1 year ago

shiwk20 commented 1 year ago

Thanks for your wonderful work! Recently I have been working on another face-swap network, but when I tried to reproduce the expression error from the paper, I ran into difficulties. I use the paper "A Compact Embedding for Facial Expression Similarity", referenced in your paper, to extract expression embeddings. The code I used is AmirSh15/FECNet, which is the only PyTorch implementation I could find. But when I compute the L2 distance between the 16-dim embeddings of the target face and the swapped face, the result is always around 0.5 (I tested DeepFakes, FaceShifter, etc.), which differs greatly from the results reported in your paper. Could you please give me some details about how you compute the expression error? I would appreciate it if you could share your expression-error code at your convenience.
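A quick way to localize a constant ~0.5 result like this is to sanity-check the embedding pipeline on its own, independent of the face-swap data. Below is a minimal sketch, assuming FECNet from AmirSh15/FECNet and facenet_pytorch's MTCNN; the FECNet import path and the image paths are placeholders to adapt to your setup. Identical inputs must give a distance of exactly 0, and two clearly different expressions should score well above near-duplicates.

import torch
from PIL import Image
from facenet_pytorch import MTCNN
from models.FECNet import FECNet  # placeholder path; adjust to the repo layout

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
mtcnn = MTCNN(image_size=224)
model = FECNet(pretrained=True).to(device).eval()

def embed(path):
    face = mtcnn(Image.open(path).convert('RGB'))  # 3 x 224 x 224 crop (None if no face found)
    return model(face.unsqueeze(0).to(device))     # 1 x 16 expression embedding

with torch.no_grad():
    e = embed('face_a.jpg')                                    # placeholder image
    print(torch.dist(e, embed('face_a.jpg')).item())           # identical input: expect 0.0
    print(torch.dist(e, embed('face_b_smiling.jpg')).item())   # different expression: expect a clearly larger value

If the second distance is not clearly larger than distances between near-identical expressions, the problem is in the model or preprocessing rather than in the swapped faces themselves.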

shiwk20 commented 1 year ago

The code I used to compute the expression error is as follows:

import os
import json
import pickle

import torch
from PIL import Image
from tqdm import tqdm
from facenet_pytorch import MTCNN
from models.FECNet import FECNet  # from AmirSh15/FECNet; adjust to your checkout

def get_expr(tgt_imgs, gen_imgs, model):
    '''
    input tensors: b x 3 x h x w
    returns the per-image L2 distance between the two b x 16 embeddings
    '''
    tgt_out = model(tgt_imgs)  # b x 16 expression embeddings
    gen_out = model(gen_imgs)

    return torch.sqrt(torch.sum((tgt_out - gen_out) ** 2, dim=1))

def test_deepfakes(model):
    print('test_deepfakes')
    df_data_root = 'data/MyFF++_no_rotation/DeepFakes/images256'
    ori_data_root = 'data/MyFF++_no_rotation/samples/images256'
    landmarks = pickle.load(open('data/MyFF++_no_rotation/landmark/landmarks256.pkl', 'rb'))  # loaded but unused below

    df_video_list = os.listdir(df_data_root)
    df_video_list.sort()
    ori_video_list = os.listdir(ori_data_root)
    ori_video_list.sort()
    samples = json.load(open('data/MyFF++/samples/samples.json'))
    mtcnn = MTCNN(image_size=224)
    # buffers assume exactly 10 sampled frames per video
    gen_imgs = torch.zeros((10, 3, 224, 224))
    tgt_imgs = torch.zeros((10, 3, 224, 224))

    ave_expr_error = 0
    count = 0
    with torch.no_grad():
        for video_idx, video in enumerate(tqdm(ori_video_list)):
            df_video = df_video_list[video_idx]
            for img_idx, img in enumerate(samples[video]):
                df_image = Image.open(os.path.join(df_data_root, df_video, img)).convert('RGB')
                ori_image = Image.open(os.path.join(ori_data_root, video, img)).convert('RGB')

                # mtcnn returns None when no face is detected; that case is not handled here
                gen_imgs[img_idx] = mtcnn(df_image, return_prob=False)
                tgt_imgs[img_idx] = mtcnn(ori_image, return_prob=False)
                count += 1

            tmp = get_expr(tgt_imgs.to(device), gen_imgs.to(device), model)
            ave_expr_error += tmp.sum()
            print(f'{video_idx + 1}/{len(ori_video_list)}', ave_expr_error / count, 'all:', tmp, 'count:', count)
    print('final', ave_expr_error / count)

if __name__ == '__main__':
    device = torch.device('cuda:0')
    model = FECNet(pretrained=True)
    model.to(device)
    model.eval()  # inference mode

    print('start test')
    test_deepfakes(model)
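One thing worth ruling out here, as an assumption rather than something confirmed in this thread: facenet_pytorch's MTCNN post-processes its output by default with fixed image standardization (roughly mapping pixels to [-1, 1]), and if the FECNet weights were trained on a different input range, the embeddings could be systematically off. A hedged sketch of taking control of the normalization follows; the to_fecnet_input scaling is a placeholder to replace with whatever FECNet's own data loader applies.

from facenet_pytorch import MTCNN

# post_process=False returns raw float pixels in [0, 255] instead of the
# default standardized tensor ((x - 127.5) / 128.0).
mtcnn = MTCNN(image_size=224, post_process=False)

def to_fecnet_input(face):
    # Placeholder: match this to FECNet's training-time preprocessing.
    return face / 255.0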

chentting commented 1 year ago

Hello, I am running into the same situation. Have you solved this problem?

zhouzikang commented 8 months ago

Hello, I have a similar situation. Have you found a solution to this problem? @chentting @shiwk20