face3d0725 / FaceExtraction

MIT License
36 stars 7 forks source link

How to run inference? #3

Open · fisakhan opened this issue 2 years ago

fisakhan commented 2 years ago

Can you please provide code to run the pre-trained model on a face image and display the extracted face region?

aesanchezgh commented 1 year ago

Yes, I would like to see this too.

junhwanjang commented 3 months ago

I got the correct result with the code below. It is a little different from the standard post-processing approach.

download

# main.py
# Run the pretrained FaceExtraction U-Net on one image and overlay the
# predicted face mask in red.
#
# NOTE(review): these statements rely on names they neither import nor define:
# torch, cv2, np (numpy), plt (matplotlib.pyplot), smp
# (segmentation_models_pytorch), OrderedDict (collections), IMG_PATH, and the
# helpers prepare_image / pad_to_square_with_info / remove_padding that appear
# further down the post. All of them must be in scope before this runs -- when
# pasting into one file, put the imports and helper definitions above here.
pretrained_weight_path = "./FaceExtraction/pretrained_model/epoch_16_best.ckpt"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# encoder_weights=None: every parameter is overwritten by the strict
# load_state_dict below, so downloading ImageNet weights first is wasted work.
model = smp.Unet(
    encoder_name="resnet18",
    encoder_weights=None,
    classes=1,
    activation=None,  # the model outputs raw logits
)

# NOTE(review): torch.load unpickles the file; only load checkpoints you trust
# (recent PyTorch versions accept weights_only=True to restrict unpickling).
weights = torch.load(pretrained_weight_path, map_location=device)

# Drop the first dotted component of every key (presumably a wrapper prefix
# such as "model." -- TODO confirm against the checkpoint) so the keys match
# the bare smp.Unet state dict.
new_weights = OrderedDict(
    (key.partition(".")[2], value) for key, value in weights.items()
)

model.load_state_dict(new_weights)
# The input batch is moved to `device` below, so the model must live there too;
# otherwise a CUDA machine fails with a CPU/GPU tensor type mismatch.
model.to(device)
model.eval()

bgr_image = cv2.imread(IMG_PATH)
if bgr_image is None:
    # cv2.imread reports failure by returning None instead of raising.
    raise FileNotFoundError(f"Could not read image: {IMG_PATH}")
image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

# Pad to square first so resizing to the network's 256x256 input does not
# distort the aspect ratio.
padded_image, padding_info = pad_to_square_with_info(image)
resized_image = cv2.resize(padded_image, (256, 256))
image_batch = prepare_image(resized_image).to(device)

with torch.no_grad():
    output = model(image_batch)[0]  # first (only) item of the batch

# With activation=None the output is logits, and sigmoid(x) >= 0.5 exactly
# when x >= 0, so threshold the logits at 0. Producing uint8 (rather than the
# int64 np.where would give) keeps the mask usable by cv2.resize.
predicted_mask = (output.squeeze(0) >= 0).to(torch.uint8).cpu().numpy()

# cv2.resize takes dsize as (width, height) -- the reverse of numpy's shape.
padded_height, padded_width = padded_image.shape[:2]
original_mask = cv2.resize(
    predicted_mask,
    (padded_width, padded_height),
    interpolation=cv2.INTER_NEAREST,
)
original_mask = remove_padding(original_mask, padding_info)

# Red overlay: only the R channel is non-zero (0 or 255, no uint8 overflow).
segmentation_mask = np.zeros_like(image)
segmentation_mask[:, :, 0] = original_mask * 255

# addWeighted returns a new array, so the source image needs no defensive copy.
blended_image = cv2.addWeighted(image, 0.6, segmentation_mask, 0.4, 0)

plt.imshow(blended_image)
plt.show()  # imshow alone does not open a window when run as a script
def prepare_image(image):
    """Turn an image array into a single-item tensor batch for the model.

    The conversion is delegated to torchvision's ``ToTensor`` (per the
    torchvision docs it moves channels first and scales uint8 input to
    [0, 1]); no mean/std normalization is applied here. A leading batch axis
    of size 1 is then prepended.
    """
    tensor = transforms.ToTensor()(image)
    # Indexing with None inserts a new leading axis, same as unsqueeze(0).
    return tensor[None]

def pad_to_square_with_info(image):
    """Zero-pad an image to a square canvas, centring the original content.

    Args:
        image: array whose first two axes are (height, width). Any trailing
            axes (e.g. colour channels) are left unpadded, and 2-D grayscale
            input is accepted.

    Returns:
        A ``(padded_image, padding_info)`` tuple. ``padded_image`` has side
        length ``max(height, width)`` and the same dtype as ``image``.
        ``padding_info`` maps 'top', 'bottom', 'left' and 'right' to the
        number of rows/columns added on that side (``remove_padding`` uses it
        to undo the padding). When the total padding is odd, the extra
        row/column goes to the bottom/right.
    """
    import numpy as np

    # Only the spatial size matters; unpacking a channel count here would
    # crash on 2-D input and the value was never used.
    height, width = image.shape[:2]
    side = max(height, width)

    top = (side - height) // 2
    bottom = side - height - top
    left = (side - width) // 2
    right = side - width - left

    padding_info = {
        'top': top,
        'bottom': bottom,
        'left': left,
        'right': right
    }

    # Constant zero border -- the same result as cv2.copyMakeBorder with
    # BORDER_CONSTANT and value 0, but for any number of trailing axes/dtype.
    pad_width = ((top, bottom), (left, right)) + ((0, 0),) * (image.ndim - 2)
    padded_image = np.pad(image, pad_width, mode='constant', constant_values=0)

    return padded_image, padding_info

def remove_padding(segmentation_map, padding_info):
    """Crop the recorded border off each edge of a padded map.

    Args:
        segmentation_map: array whose first two axes are rows and columns.
        padding_info: dict giving integer amounts under 'top', 'bottom',
            'left' and 'right' to strip from the corresponding edges.

    Returns:
        The cropped slice of ``segmentation_map`` (a view for NumPy arrays).
    """
    top, bottom, left, right = (
        padding_info[edge] for edge in ('top', 'bottom', 'left', 'right')
    )

    n_rows, n_cols = segmentation_map.shape[:2]
    row_window = slice(top, n_rows - bottom)
    col_window = slice(left, n_cols - right)
    return segmentation_map[row_window, col_window]