AmirSh15 / FECNet

Facial Expression Feature Extractor
https://github.com/AmirSh15/FECNet.git
67 stars 11 forks source link

What are the data-processing steps for producing valid features for clustering? #3

Open RiweiChen opened 4 years ago

RiweiChen commented 4 years ago

After detecting a face with a bounding box and landmarks, how should I align the face data to extract facial-expression features?

RiweiChen commented 4 years ago

What I am testing now is:

# Detect/align the face with MTCNN, then embed it with FECNet.
# test_path, filename and model come from the surrounding (unshown) script.
full_path = os.path.join(test_path, filename)
img = Image.open(full_path)
mtcnn = MTCNN(image_size=224)
face, prob = mtcnn(img, return_prob=True)
# NOTE(review): MTCNN presumably returns None when no face is found (the
# original np.array(face).any() was acting as that check) -- confirm against
# models/mtcnn.py. Test for it explicitly instead of relying on truthiness.
if face is not None:
    face = torch.as_tensor(face, dtype=torch.float32).view(1, 3, 224, 224)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        Embedding = model(face.cuda()).cpu()
    print(Embedding[0, :].numpy())
AmirSh15 commented 4 years ago

The alignment is done automatically by the MTCNN library.

AmirSh15 commented 4 years ago

Here, "test" refers to the test samples used to evaluate the network.

RiweiChen commented 4 years ago

[image] [image] [image] I tested the similarity search, but the results do not look as good as expected (the top image is the query, and the following six faces are the most similar faces found).

RiweiChen commented 4 years ago

Here is how I extract the features from the model:

import os
import numpy as np
import torch
from models.FECNet import FECNet
from models.mtcnn import MTCNN
from PIL import Image
import cv2

def test(test_path="/data00/chenriwei/PublicData/FaceData/RAF/RAF",
         output_path="feature2.txt"):
    """Write a FECNet embedding for every 3-channel image in a directory.

    Each directory entry is opened with PIL. Entries that are not readable
    images, or that do not have exactly 3 bands, are skipped. The image is
    passed through MTCNN, and any detected face is embedded with FECNet. One
    line per embedded image is written to *output_path*:
    ``<filename>\\t<list of floats>\\n``.

    Args:
        test_path: Directory whose entries are treated as images.
        output_path: Text file the embeddings are written to (overwritten).
    """
    model = FECNet()
    model.load_state_dict(torch.load('data/FECNet.pt'))
    model.eval()
    # Build the detector once. Constructing it per image repeats the setup on
    # every iteration, and a failing constructor should not be silently
    # swallowed by the per-image error handling below.
    mtcnn = MTCNN(image_size=224)
    with torch.no_grad():
        with open(output_path, 'w') as fout:
            for filename in os.listdir(test_path):
                full_path = os.path.join(test_path, filename)
                try:
                    # split() loads the pixel data, so the file can be closed.
                    with Image.open(full_path) as src:
                        channels = src.split()
                    if len(channels) != 3:
                        continue
                    r, g, b = channels
                    # NOTE(review): this reorders the bands to (b, g, r), i.e.
                    # feeds BGR data labelled "RGB" into MTCNN, whereas the
                    # earlier snippet passes the PIL image unchanged. Confirm
                    # this matches the colour order FECNet was trained on.
                    img = Image.merge("RGB", (b, g, r))
                except (OSError, ValueError):
                    # Not a readable image (or a directory): skip it.
                    continue
                face, _prob = mtcnn(img, return_prob=True)
                if face is None:
                    # No face detected in this image.
                    continue
                face = torch.as_tensor(face, dtype=torch.float32).view(1, 3, 224, 224)
                # NOTE(review): assumes FECNet lives on the GPU; the model is
                # never moved explicitly here.
                embedding = model(face.cuda()).cpu()
                # tolist() yields plain Python floats, so the line is an
                # ordinary list literal regardless of the numpy version.
                fout.write("{}\t{}\n".format(filename, embedding[0, :].tolist()))

# Run the feature extraction when this file is executed as a script.
if __name__ == "__main__":
    test()
RiweiChen commented 4 years ago

I am using FLANN for the nearest-neighbour search. The full code is here:

from pyflann import *
import numpy as np
import flask
import web_util
import random
import logging
import cv2
import math
# Flask application serving the nearest-neighbour visualisation page.
app = flask.Flask(__name__)

# Image directory; filenames from the feature file are joined onto this path
# when the page is rendered.
test_path = "/data00/chenriwei/PublicData/FaceData/RAF/RAF"

def norm_feature(features):
    """Return *features* scaled to unit Euclidean (L2) length.

    Args:
        features: A sequence of numbers (it is iterated more than once).

    Returns:
        A new list with one float per input element. An all-zero input has no
        direction to normalize, so it is returned as a list of zeros instead of
        raising ZeroDivisionError. An empty input gives an empty list.
    """
    # fsum avoids accumulated rounding error on long embedding vectors.
    norm = math.sqrt(math.fsum(f * f for f in features))
    if norm == 0.0:
        return [0.0 for _ in features]
    return [f / norm for f in features]

flann = FLANN()

# Load the embeddings written by the extraction script, one
# "<filename>\t<feature list>" record per line. Records whose filename
# contains "test" become queries; all others form the searchable dataset.
dataset = []
dataset_names = []
testset = []
testset_names = []
cnt = 0
with open("feature2.txt") as f:
    for line in f:
        cnt += 1
        if cnt > 1000000:  # hard cap on the number of lines read
            break
        if not line.strip():
            # Skip blank lines instead of failing the tuple unpack below.
            continue
        # Split on the LAST tab: the feature list never contains one, so a
        # filename that happens to contain a tab stays intact.
        filename, feature = line.rstrip("\n").rsplit("\t", 1)
        # NOTE(review): eval() executes arbitrary code from the file. It is
        # tolerable only because feature2.txt is produced locally by the
        # extraction script. It also accepts numpy reprs such as
        # np.float32(...). Use ast.literal_eval/json if the file could
        # come from anywhere else.
        feature = eval(feature)
        if "test" in filename:
            testset.append(norm_feature(feature))
            testset_names.append(filename)
        else:
            dataset.append(norm_feature(feature))
            dataset_names.append(filename)

dataset = np.array(dataset)
params = flann.build_index(dataset, algorithm="kmeans", branching=32, iterations=7, checks=16)
testset = np.array(testset)
# Look up k=6 neighbours from the dataset index for every query row.
result, dists = flann.nn_index(testset, 6, checks=params["checks"])

@app.route('/')
def main():
    """Render a random query image followed by its nearest neighbours.

    Picks a random query row of the module-level ``result`` matrix. It then
    builds the image list ``empty.jpg``, the query, ``empty.jpg``, followed by
    the dataset images listed in that row. Files cv2 cannot read are skipped,
    and the rest are rendered through web_util.
    """
    # os is otherwise only reachable via `from pyflann import *`.
    import os

    # randint() includes its upper bound, so randint(0, len(result)) could
    # return len(result) and raise IndexError; randrange() excludes it.
    select_idx = random.randrange(len(result))
    filenames = ["empty.jpg", testset_names[select_idx], "empty.jpg"]
    filenames.extend(dataset_names[idx] for idx in result[select_idx])
    image_htmls = []
    for file_ in filenames:
        image = cv2.imread(os.path.join(test_path, file_))
        if image is None:
            continue
        image_htmls.append(web_util.embeding_image_2_string(image))
    return web_util.make_image_html(image_htmls, 500)

if __name__ == "__main__":
    # sys is otherwise only reachable via `from pyflann import *`.
    import sys

    logging.getLogger().setLevel(logging.INFO)
    # Port comes from the first CLI argument. Fall back to Flask's default
    # (5000) instead of raising IndexError when no argument is provided.
    port = int(sys.argv[1]) if len(sys.argv) > 1 else 5000
    app.run(host='0.0.0.0', port=port)
RiweiChen commented 4 years ago

The complete commit can be found here:

commit