protossw512 / AdaptiveWingLoss

[ICCV 2019] Adaptive Wing Loss for Robust Face Alignment via Heatmap Regression - Official Implementation
Apache License 2.0
394 stars 88 forks source link

Training code implementation #14

Open HassanAbbas92 opened 4 years ago

HassanAbbas92 commented 4 years ago

`
import matplotlib.pyplot as plt import cv2 import sys import os from PIL import Image, ImageDraw from utils.utils import fan_NME, show_landmarks, get_preds_fromhm import numpy as np from skimage import io import shutil from torch.autograd import Variable import time import copy from torch import nn import torch import math import matplotlib matplotlib.use('Agg')

# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

class AdaptiveWingLoss(nn.Module):
    """Adaptive Wing loss for heatmap regression.

    Loss from "Adaptive Wing Loss for Robust Face Alignment via Heatmap
    Regression" (ICCV 2019). For every element, with ``d = |y - y_hat|``::

        d <  theta: omega * ln(1 + (d / epsilon) ** (alpha - y))
        d >= theta: A * d - C

    ``A`` and ``C`` depend only on the target ``y``. They are chosen so both
    pieces have the same value and slope at ``d == theta``. Each element's
    loss is multiplied by the matching entry of ``weight_map``, and the
    weighted losses are averaged over all elements.
    """

    def __init__(self, omega=14, theta=0.5, epsilon=1, alpha=2.1):
        super(AdaptiveWingLoss, self).__init__()
        self.omega = omega
        self.theta = theta
        self.epsilon = epsilon
        self.alpha = alpha

    def forward(self, pred, weight_map, target):
        """Return the weighted mean Adaptive Wing loss.

        Args:
            pred: predicted heatmaps.
            weight_map: per-element weights, same shape as ``pred``.
            target: ground-truth heatmaps, same shape as ``pred``
                (presumably values in [0, 1] -- confirm against the dataset).

        Returns:
            0-dim tensor: sum of the weighted element losses divided by the
            number of elements that fell into either branch.
        """
        y = target
        delta_y = (y - pred).abs()
        # Boolean masks split the elements into the two branches. Indexing
        # same-shaped tensors with one mask keeps their elements aligned, so
        # weight_map[mask] pairs each weight with its own element. It also
        # keeps gradients from flowing through the branch that was not
        # selected.
        small = delta_y < self.theta
        large = delta_y >= self.theta
        delta_y1 = delta_y[small]
        delta_y2 = delta_y[large]
        y1 = y[small]
        y2 = y[large]
        # Divide by epsilon (not omega): this is what makes this branch meet
        # A * d - C at d == theta, since A and C are built from theta/epsilon.
        loss1 = self.omega * torch.log(1 + torch.pow(
            delta_y1 / self.epsilon, self.alpha - y1)) * weight_map[small]
        ratio = self.theta / self.epsilon
        A = self.omega * (1 / (1 + torch.pow(ratio, self.alpha - y2))) * (self.alpha - y2) * (
            torch.pow(ratio, self.alpha - y2 - 1)) * (1 / self.epsilon)
        C = self.theta * A - self.omega * \
            torch.log(1 + torch.pow(ratio, self.alpha - y2))
        loss2 = (A * delta_y2 - C) * weight_map[large]
        return (loss1.sum() + loss2.sum()) / (loss1.numel() + loss2.numel())

def train_model(model, dataloaders, dataset_sizes, use_gpu=True, epoches=5,
                save_path='./', num_landmarks=68, start_epoch=0,
                lr=0.0000001):
    """Train ``model`` with the Adaptive Wing loss and checkpoint on val NME.

    Each epoch runs a ``'train'`` pass (forward, loss, backward, optimizer
    step) and then a ``'val'`` pass (forward only). It prints each phase's
    NME. A checkpoint is written whenever the validation NME beats the best
    value seen so far.

    Args:
        model: network whose forward returns ``(outputs, boundary_channels)``.
            Both are indexed with ``[-1]`` below (NOTE(review): presumably
            one entry per stacked module -- confirm).
        dataloaders: mapping ``phase -> iterable`` of dict batches with keys
            ``'image'``, ``'heatmap'``, ``'boundary'``, ``'landmarks'`` and
            ``'weight_map'``.
        dataset_sizes: mapping ``phase -> number of samples``. Each phase's
            summed batch NME is divided by it.
        use_gpu: when true, move inputs, targets and weight maps to ``device``.
        epoches: number of epochs to run.
        save_path: string prefix prepended to ``'{epoch:02d}.pth'``.
        num_landmarks: forwarded to ``fan_NME``.
        start_epoch: index of the first epoch (useful when resuming).
        lr: RMSprop learning rate (default keeps the original 1e-7).

    Returns:
        The same ``model`` object, after training.
    """
    best_nme = 100
    optimizer = torch.optim.RMSprop(
        model.parameters(), lr=lr, weight_decay=0)
    loss_AW = AdaptiveWingLoss()
    for epoch in range(start_epoch, epoches + start_epoch):
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            # Reset per phase so the val NME does not include train batches.
            total_nme = 0
            for data in dataloaders[phase]:
                optimizer.zero_grad()
                inputs = data['image'].type(torch.FloatTensor)
                labels_heatmap = data['heatmap'].type(torch.FloatTensor)
                labels_boundary = data['boundary'].type(torch.FloatTensor)
                gt_landmarks = data['landmarks'].type(torch.FloatTensor)
                loss_weight_map = data['weight_map'].type(torch.FloatTensor)
                # Landmarks stay on the CPU: they are only used by fan_NME,
                # which receives CPU predictions below.
                if use_gpu:
                    inputs = inputs.to(device)
                    labels_heatmap = labels_heatmap.to(device)
                    labels_boundary = labels_boundary.to(device)
                    loss_weight_map = loss_weight_map.to(device)
                labels = torch.cat((labels_heatmap, labels_boundary), 1)
                with torch.set_grad_enabled(phase == 'train'):
                    outputs, boundary_channels = model(inputs)
                    # Last entry of each output list, each with its final
                    # channel dropped, concatenated along the channel axis.
                    # NOTE(review): the issue thread questions whether
                    # outputs[-1] alone should be used here -- confirm.
                    pred_labels = torch.cat(
                        (outputs[-1][:, :-1, :, :], boundary_channels[-1][:, :-1, :, :]), 1)
                    loss_total = loss_AW(
                        pred_labels, loss_weight_map * 10 + 1, labels)
                    if phase == 'train':
                        loss_total.backward()
                        optimizer.step()
                # Assumes fan_NME returns the NME summed over the batch, since
                # the total is divided by the dataset size below -- confirm.
                batch_nme = fan_NME(
                    outputs[-1][:, :-1, :, :].detach().cpu(), gt_landmarks, num_landmarks)
                total_nme += batch_nme
            epoch_nme = total_nme / dataset_sizes[phase]
            print(phase + ' NME: {:.6f}'.format(epoch_nme))
            if phase == 'val' and epoch_nme < best_nme:
                # Track the best value so only improvements are checkpointed.
                best_nme = epoch_nme
                state = {
                    'next_epoch': epoch + 1,
                    'epoch_total_nme': epoch_nme,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                torch.save(state, save_path + '{:02d}'.format(epoch) + '.pth')
    return model

` @protossw512 could you please check whether my training implementation is correct?

mustangchavez commented 4 years ago

@HassanAbbas92 Thanks for posting this; I was looking for something to get started with training. I found a couple of fairly small issues:

And thanks @protossw512 for sharing this great project.

HassanAbbas92 commented 4 years ago

@mustangchavez Thanks for the correction. Did you train from scratch and achieve the same results as in the paper on the WFLW dataset?

mustangchavez commented 4 years ago

Sorry, I haven't tried to reproduce the results as I am actually working on a different problem domain.

I found what seems to be another correction: I believe the correct boundary map to use is already part of the first output, not the "boundary_channels" output. So

pred_labels = torch.cat(
                        (outputs[-1][:, :-1, :, :], boundary_channels[-1][:, :-1, :, :]), 1)

should be

pred_labels = outputs[-1]
HassanAbbas92 commented 4 years ago

@mustangchavez Are you sure? I thought @protossw512 talked about this in https://github.com/protossw512/AdaptiveWingLoss/issues/12

mustangchavez commented 4 years ago

I'm not sure, no! Haha.

But I think that each "boundary_channels" entry is intended to be passed into the next HourGlass module, whereas the "outputs" list contains the final boundary heatmap prediction. That's why the last tensor in the "outputs" list has the shape [batch_size, num_landmarks + 1, heatmap_size, heatmap_size].

HassanAbbas92 commented 4 years ago

You are correct, @mustangchavez. I will also try to visualize both of them to see the difference.

vuthede commented 4 years ago

Hello @HassanAbbas92, thanks for sharing your training code implementation. How is it going? I am reading the paper, have only run inference so far, and am thinking about training from scratch. Have you managed to reproduce the results?

switch626 commented 3 years ago

` import matplotlib.pyplot as plt import cv2 import sys import os from PIL import Image, ImageDraw from utils.utils import fan_NME, show_landmarks, get_preds_fromhm import numpy as np from skimage import io import shutil from torch.autograd import Variable import time import copy from torch import nn import torch import math import matplotlib matplotlib.use('Agg')

# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

class AdaptiveWingLoss(nn.Module):
    """Adaptive Wing loss for heatmap regression.

    Loss from "Adaptive Wing Loss for Robust Face Alignment via Heatmap
    Regression" (ICCV 2019). For every element, with ``d = |y - y_hat|``::

        d <  theta: omega * ln(1 + (d / epsilon) ** (alpha - y))
        d >= theta: A * d - C

    ``A`` and ``C`` depend only on the target ``y``. They are chosen so both
    pieces have the same value and slope at ``d == theta``. Each element's
    loss is multiplied by the matching entry of ``weight_map``, and the
    weighted losses are averaged over all elements.
    """

    def __init__(self, omega=14, theta=0.5, epsilon=1, alpha=2.1):
        super(AdaptiveWingLoss, self).__init__()
        self.omega = omega
        self.theta = theta
        self.epsilon = epsilon
        self.alpha = alpha

    def forward(self, pred, weight_map, target):
        """Return the weighted mean Adaptive Wing loss.

        Args:
            pred: predicted heatmaps.
            weight_map: per-element weights, same shape as ``pred``.
            target: ground-truth heatmaps, same shape as ``pred``
                (presumably values in [0, 1] -- confirm against the dataset).

        Returns:
            0-dim tensor: sum of the weighted element losses divided by the
            number of elements that fell into either branch.
        """
        y = target
        delta_y = (y - pred).abs()
        # Boolean masks split the elements into the two branches. Indexing
        # same-shaped tensors with one mask keeps their elements aligned, so
        # weight_map[mask] pairs each weight with its own element. It also
        # keeps gradients from flowing through the branch that was not
        # selected.
        small = delta_y < self.theta
        large = delta_y >= self.theta
        delta_y1 = delta_y[small]
        delta_y2 = delta_y[large]
        y1 = y[small]
        y2 = y[large]
        # Divide by epsilon (not omega): this is what makes this branch meet
        # A * d - C at d == theta, since A and C are built from theta/epsilon.
        loss1 = self.omega * torch.log(1 + torch.pow(
            delta_y1 / self.epsilon, self.alpha - y1)) * weight_map[small]
        ratio = self.theta / self.epsilon
        A = self.omega * (1 / (1 + torch.pow(ratio, self.alpha - y2))) * (self.alpha - y2) * (
            torch.pow(ratio, self.alpha - y2 - 1)) * (1 / self.epsilon)
        C = self.theta * A - self.omega * \
            torch.log(1 + torch.pow(ratio, self.alpha - y2))
        loss2 = (A * delta_y2 - C) * weight_map[large]
        return (loss1.sum() + loss2.sum()) / (loss1.numel() + loss2.numel())

def train_model(model, dataloaders, dataset_sizes, use_gpu=True, epoches=5,
                save_path='./', num_landmarks=68, start_epoch=0,
                lr=0.0000001):
    """Train ``model`` with the Adaptive Wing loss and checkpoint on val NME.

    Each epoch runs a ``'train'`` pass (forward, loss, backward, optimizer
    step) and then a ``'val'`` pass (forward only). It prints each phase's
    NME. A checkpoint is written whenever the validation NME beats the best
    value seen so far.

    Args:
        model: network whose forward returns ``(outputs, boundary_channels)``.
            Both are indexed with ``[-1]`` below (NOTE(review): presumably
            one entry per stacked module -- confirm).
        dataloaders: mapping ``phase -> iterable`` of dict batches with keys
            ``'image'``, ``'heatmap'``, ``'boundary'``, ``'landmarks'`` and
            ``'weight_map'``.
        dataset_sizes: mapping ``phase -> number of samples``. Each phase's
            summed batch NME is divided by it.
        use_gpu: when true, move inputs, targets and weight maps to ``device``.
        epoches: number of epochs to run.
        save_path: string prefix prepended to ``'{epoch:02d}.pth'``.
        num_landmarks: forwarded to ``fan_NME``.
        start_epoch: index of the first epoch (useful when resuming).
        lr: RMSprop learning rate (default keeps the original 1e-7).

    Returns:
        The same ``model`` object, after training.
    """
    best_nme = 100
    optimizer = torch.optim.RMSprop(
        model.parameters(), lr=lr, weight_decay=0)
    loss_AW = AdaptiveWingLoss()
    for epoch in range(start_epoch, epoches + start_epoch):
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            # Reset per phase so the val NME does not include train batches.
            total_nme = 0
            for data in dataloaders[phase]:
                optimizer.zero_grad()
                inputs = data['image'].type(torch.FloatTensor)
                labels_heatmap = data['heatmap'].type(torch.FloatTensor)
                labels_boundary = data['boundary'].type(torch.FloatTensor)
                gt_landmarks = data['landmarks'].type(torch.FloatTensor)
                loss_weight_map = data['weight_map'].type(torch.FloatTensor)
                # Landmarks stay on the CPU: they are only used by fan_NME,
                # which receives CPU predictions below.
                if use_gpu:
                    inputs = inputs.to(device)
                    labels_heatmap = labels_heatmap.to(device)
                    labels_boundary = labels_boundary.to(device)
                    loss_weight_map = loss_weight_map.to(device)
                labels = torch.cat((labels_heatmap, labels_boundary), 1)
                with torch.set_grad_enabled(phase == 'train'):
                    outputs, boundary_channels = model(inputs)
                    # Last entry of each output list, each with its final
                    # channel dropped, concatenated along the channel axis.
                    # NOTE(review): the issue thread questions whether
                    # outputs[-1] alone should be used here -- confirm.
                    pred_labels = torch.cat(
                        (outputs[-1][:, :-1, :, :], boundary_channels[-1][:, :-1, :, :]), 1)
                    loss_total = loss_AW(
                        pred_labels, loss_weight_map * 10 + 1, labels)
                    if phase == 'train':
                        loss_total.backward()
                        optimizer.step()
                # Assumes fan_NME returns the NME summed over the batch, since
                # the total is divided by the dataset size below -- confirm.
                batch_nme = fan_NME(
                    outputs[-1][:, :-1, :, :].detach().cpu(), gt_landmarks, num_landmarks)
                total_nme += batch_nme
            epoch_nme = total_nme / dataset_sizes[phase]
            print(phase + ' NME: {:.6f}'.format(epoch_nme))
            if phase == 'val' and epoch_nme < best_nme:
                # Track the best value so only improvements are checkpointed.
                best_nme = epoch_nme
                state = {
                    'next_epoch': epoch + 1,
                    'epoch_total_nme': epoch_nme,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                torch.save(state, save_path + '{:02d}'.format(epoch) + '.pth')
    return model

` @protossw512 could you please check whether my training implementation is correct?

Hi. Thanks for your code and your idea for the loss function. I have a question: is the calculation of loss1 and loss2 correct? More specifically, is multiplying by `weight_map[delta_y < self.theta]` correct? Thanks. Wouldn't `torch.where` be better?