lizhengwei1992 / Semantic_Human_Matting

Semantic Human Matting

Semantic_Human_Matting

This project is my reimplementation of the paper Semantic Human Matting (SHM) from Alibaba. The paper proposes a new end-to-end scheme to predict a human alpha matte from an image; according to its authors, SHM is the first algorithm that learns to jointly fit both semantic information and high-quality details with deep networks.

One of the main contributions of the paper is a large-scale, high-quality human matting dataset containing 35,513 unique human images with corresponding alpha mattes. However, the dataset is not available.

I collected 6k+ images as the dataset for this project. Note that my network, built with MobileNet and a shallow encoder-decoder, is a lightweight version compared to the original implementation.

Update 2019/04/08

The company 爱分割 (AiSegment) has recently shared their dataset!

Requirements

Usage

Directory structure of the project:

Semantic_Human_Matting
│   README.md
│   train.py
│   train.sh
│   test_camera.py
│   test_camera.sh
└───model
│   │   M_Net.py
│   │   T_Net.py
│   │   network.py
└───data
    │   dataset.py
    │   gen_trimap.py
    │   gen_trimap.sh
    │   knn_matting.py
    │   knn_matting.sh
    └───image
    └───mask
    └───trimap
    └───alpha

Step 1: prepare dataset

./data/train.txt lists the names of the 6k+ images (./data/image) and their corresponding masks (./data/mask).
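The exact format of train.txt is not specified. A minimal sketch, assuming one file name per line and that each mask in ./data/mask shares its image's file name (both assumptions; the helper names are hypothetical), of turning the list into (image, mask) path pairs:

```python
import os


def pairs_from_lines(lines, data_dir="./data"):
    """Hypothetical helper: build (image_path, mask_path) pairs from list lines.

    Assumes one file name per line and identical file names for image and mask.
    """
    pairs = []
    for line in lines:
        name = line.strip()
        if not name:
            continue  # skip blank lines
        pairs.append((os.path.join(data_dir, "image", name),
                      os.path.join(data_dir, "mask", name)))
    return pairs


def load_pairs(list_file="./data/train.txt", data_dir="./data"):
    # Hypothetical wrapper: read the list file and build the pairs.
    with open(list_file) as f:
        return pairs_from_lines(f, data_dir)
```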

Use ./data/gen_trimap.sh to generate trimaps from the masks.
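The exact method in gen_trimap.py is not described in this README. A common way to derive a trimap from a binary mask is to erode it to get sure foreground, dilate it to find the sure background, and mark the band in between as unknown. A minimal NumPy sketch under that assumption (the kernel size k and the 0/128/255 encoding are illustrative choices, not necessarily what gen_trimap.py uses):

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20


def gen_trimap(mask, k=5):
    """Sketch: trimap from a binary mask by morphological dilation/erosion.

    mask: H x W array, foreground where value > 127 (e.g. 0/255 masks).
    k:    odd side length of the square structuring element (illustrative default).
    Returns a uint8 trimap: 0 = background, 128 = unknown, 255 = foreground.
    """
    if k % 2 != 1:
        raise ValueError("k must be odd")
    fg = np.asarray(mask) > 127
    pad = k // 2
    padded = np.pad(fg, pad, mode="edge")
    windows = sliding_window_view(padded, (k, k))  # shape H x W x k x k
    dilated = windows.any(axis=(-2, -1))  # any foreground in the k x k window
    eroded = windows.all(axis=(-2, -1))   # whole window is foreground
    trimap = np.full(fg.shape, 128, dtype=np.uint8)
    trimap[eroded] = 255
    trimap[~dilated] = 0
    return trimap
```

The width of the unknown band is set by k: a larger kernel leaves more of the boundary for the matting step to resolve.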

Use ./data/knn_matting.sh to compute the alpha mattes (this takes a long time...).
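To illustrate the idea behind this step, here is a toy dense version of KNN matting (Chen et al.), not the code in knn_matting.py. Each pixel gets a feature vector (here simply RGB plus normalized position, a simplification), each pixel is linked to its k nearest neighbours in feature space, and alpha is found by solving the sparse-in-practice system (L + λD)α = λv, where L is the graph Laplacian, D marks the known trimap pixels and v marks the known foreground. Because this sketch builds dense n x n matrices, it only suits tiny images; k and lam are illustrative defaults.

```python
import numpy as np


def knn_matting_dense(img, trimap, k=10, lam=100.0):
    """Toy dense KNN matting sketch; O(n^2) memory, so tiny images only.

    img:    H x W x 3 float array in [0, 1].
    trimap: H x W uint8, 0 = background, 255 = foreground, anything else unknown.
    """
    h, w, _ = img.shape
    n = h * w
    ys, xs = np.mgrid[0:h, 0:w]
    # Feature per pixel: RGB plus normalized (y, x) position, all in [0, 1].
    feats = np.concatenate(
        [img.reshape(n, 3), np.stack([ys.ravel() / h, xs.ravel() / w], axis=1)],
        axis=1,
    )
    dist = np.linalg.norm(feats[:, None, :] - feats[None, :, :], axis=2)
    # Positions are unique, so column 0 of the argsort is the pixel itself.
    nbrs = np.argsort(dist, axis=1)[:, 1:k + 1]
    rows = np.repeat(np.arange(n), k)
    cols = nbrs.ravel()
    # Affinity 1 - ||X_i - X_j|| / C, with C = sqrt(#features) bounding the distance.
    c = np.sqrt(feats.shape[1])
    A = np.zeros((n, n))
    A[rows, cols] = 1.0 - dist[rows, cols] / c
    A = (A + A.T) / 2.0  # symmetrize the kNN graph
    L = np.diag(A.sum(axis=1)) - A  # graph Laplacian
    t = trimap.ravel()
    known = ((t == 0) | (t == 255)).astype(float)
    v = (t == 255).astype(float)
    # Solve (L + lam * D) alpha = lam * v with D = diag(known).
    alpha = np.linalg.solve(L + lam * np.diag(known), lam * v)
    return np.clip(alpha, 0.0, 1.0).reshape(h, w)
```

A practical implementation would use sparse matrices and a fast neighbour search, which is also why this step is slow on full-size images.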

Step 2: build network

[Figure: SHM network architecture]
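Per the SHM paper, T-Net predicts a 3-class trimap (background, unknown, foreground), M-Net takes the image concatenated with T-Net's output and predicts a raw alpha α_r, and a fusion module combines them as α_p = F_s + U_s α_r, where F_s and U_s are the softmax probabilities of foreground and unknown. A minimal NumPy sketch of that fusion step only, assuming the channel order (bg, unknown, fg); model/network.py may order channels differently or organize the code otherwise:

```python
import numpy as np


def softmax(x, axis):
    # Numerically stable softmax along `axis`.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def fuse(trimap_logits, alpha_r):
    """Sketch of SHM's fusion: alpha_p = F_s + U_s * alpha_r.

    trimap_logits: N x 3 x H x W T-Net scores, channel order (bg, unknown, fg) assumed.
    alpha_r:       N x 1 x H x W raw alpha from M-Net, in [0, 1].
    """
    probs = softmax(trimap_logits, axis=1)
    unknown_s = probs[:, 1:2]  # U_s, kept 4-D for broadcasting
    fg_s = probs[:, 2:3]       # F_s
    return fg_s + unknown_s * alpha_r
```

Because F_s + U_s ≤ 1, α_p stays in [0, 1] whenever α_r does. Where T-Net is confident about foreground or background, the result is driven by T-Net; M-Net only matters where T-Net predicts "unknown".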

Step 3: build loss

The overall prediction loss for the predicted alpha α_p at each pixel is

$$L_p = \gamma \lVert \alpha_p - \alpha_g \rVert_1 + (1 - \gamma) \lVert c_p - c_g \rVert_1$$

where α_g is the ground-truth alpha, c_p and c_g are the RGB values composited with the predicted and ground-truth alpha, and γ balances the two terms (γ = 0.5 in the code below).
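In the loss code of this section, each L1 term is computed as √(x² + ε) with ε = 10⁻⁶ rather than as |x|. This smooth approximation keeps the gradient defined at x = 0 and never differs from the absolute value by more than √ε:

```latex
|x| \le \sqrt{x^2 + \varepsilon} \le |x| + \sqrt{\varepsilon},
\qquad \varepsilon = 10^{-6} \;\Rightarrow\; \sqrt{\varepsilon} = 10^{-3}
```

So each per-pixel term carries an approximation error of at most 0.001.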

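The total loss is

$$L = L_p + \lambda L_t$$

where L_t is the trimap classification loss of T-Net and λ weights the two terms.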

Read the paper for more details. Here is my code for the two loss functions:

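    import torch
    import torch.nn as nn

    # -------------------------------------
    # classification loss L_t
    # trimap_pre: T-Net logits, N x 3 x H x W
    # trimap_gt:  trimap class indices, N x 1 x H x W
    # ------------------------
    criterion = nn.CrossEntropyLoss()
    L_t = criterion(trimap_pre, trimap_gt[:, 0, :, :].long())

    # -------------------------------------
    # prediction loss L_p
    # alpha_pre, alpha_gt: N x 1 x H x W; img: N x 3 x H x W
    # sqrt(x^2 + eps) is a smooth, differentiable stand-in for |x|
    # ------------------------
    eps = 1e-6
    # L_alpha: alpha prediction loss
    L_alpha = torch.sqrt(torch.pow(alpha_pre - alpha_gt, 2.) + eps).mean()

    # L_composition: compare the image composited with the ground-truth
    # and the predicted alpha (alpha repeated over the 3 color channels)
    fg = torch.cat((alpha_gt, alpha_gt, alpha_gt), 1) * img
    fg_pre = torch.cat((alpha_pre, alpha_pre, alpha_pre), 1) * img
    L_composition = torch.sqrt(torch.pow(fg - fg_pre, 2.) + eps).mean()
    # gamma = 0.5
    L_p = 0.5 * L_alpha + 0.5 * L_composition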
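As a framework-free sketch of how the pieces fit together, the NumPy code below recomputes L_t and L_p with the same tensor layout as the PyTorch snippet above and combines them as L = L_p + λL_t, the form used in the SHM paper. The weight λ = 0.01 is an assumed placeholder, not necessarily the value this repo's train.py uses.

```python
import numpy as np


def charbonnier(x, eps=1e-6):
    # sqrt(x^2 + eps): the smooth |x| used in the PyTorch snippet above
    return np.sqrt(x ** 2 + eps)


def cross_entropy(logits, labels):
    # logits: N x C x H x W scores; labels: N x H x W integer classes.
    z = logits - logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    picked = np.take_along_axis(log_probs, labels[:, None], axis=1)
    return -picked.mean()  # mean over all pixels, like nn.CrossEntropyLoss


def total_loss(trimap_logits, trimap_labels, alpha_pre, alpha_gt, img,
               gamma=0.5, lam=0.01):
    """L = L_p + lam * L_t; lam = 0.01 is an assumed placeholder weight."""
    L_t = cross_entropy(trimap_logits, trimap_labels)
    L_alpha = charbonnier(alpha_pre - alpha_gt).mean()
    # N x 1 x H x W alpha broadcasts over the 3 channels of img (N x 3 x H x W)
    L_comp = charbonnier(alpha_gt * img - alpha_pre * img).mean()
    L_p = gamma * L_alpha + (1.0 - gamma) * L_comp
    return L_p + lam * L_t
```

Since λ scales only the trimap term, a small λ lets the alpha and composition losses dominate once T-Net has been pre-trained.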

Step 4: train

First, pre-train T-Net by running ./train.sh with:

    python3 train.py \
        --dataDir='./data' \
        --saveDir='./ckpt' \
        --trainData='human_matting_data' \
        --trainList='./data/train.txt' \
        --load='human_matting' \
        --nThreads=4 \
        --patch_size=320 \
        --train_batch=8 \
        --lr=1e-3 \
        --lrdecayType='keep' \
        --nEpochs=1000 \
        --save_epoch=1 \
        --train_phase='pre_train_t_net'

Then, train the whole network end to end by running ./train.sh with:

    python3 train.py \
        --dataDir='./data' \
        --saveDir='./ckpt' \
        --trainData='human_matting_data' \
        --trainList='./data/train.txt' \
        --load='human_matting' \
        --nThreads=4 \
        --patch_size=320 \
        --train_batch=8 \
        --lr=1e-4 \
        --lrdecayType='keep' \
        --nEpochs=2000 \
        --save_epoch=1 \
        --finetuning \
        --train_phase='end_to_end'

Test

Run ./test_camera.sh to test the model on camera input.
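What test_camera.sh does with each predicted matte is not specified here. A typical use of an alpha matte is compositing the person onto a new background with the matting equation I = αF + (1 − α)B. A minimal sketch, assuming the camera frame itself is used as the foreground colour F (a common shortcut that can leave a faint trace of the old background at soft edges); the function name is hypothetical:

```python
import numpy as np


def composite(frame, alpha, background):
    """I = alpha * F + (1 - alpha) * B, using the camera frame as F (assumption).

    frame, background: H x W x 3 uint8; alpha: H x W floats in [0, 1].
    """
    a = alpha.astype(np.float32)[..., None]  # H x W x 1, broadcasts over RGB
    out = a * frame.astype(np.float32) + (1.0 - a) * background.astype(np.float32)
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)
```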