TianzhongSong / Person-Segmentation-Keras

Person segmentation with Keras (SegNet, Unet, etc.)
Apache License 2.0
160 stars 34 forks source link

reshape to get sharp masks #4

Open ldenoue opened 6 years ago

ldenoue commented 6 years ago

I noticed that when I change the reshape from https://github.com/TianzhongSong/Person-Segmentation-Keras/blob/master/models/unet.py#L71

into: conv11 = Reshape((input_height * input_width, nClasses))(conv11)

the UNet seems to remove the line artifacts seen in the output masks.

ldenoue commented 6 years ago

I confirm removing the Permute and changing the Reshape as I suggest generates much better masks for me:

    #conv11 = Reshape((nClasses, input_height * input_width))(conv11)
    #conv11 = Permute((2, 1))(conv11)
    conv11 = Reshape((input_height * input_width,nClasses))(conv11)
nayanleo commented 5 years ago

can you please provide the link for the unet_weights.h5 file

minushuang commented 2 years ago

I noticed that when I change the reshape from https://github.com/TianzhongSong/Person-Segmentation-Keras/blob/master/models/unet.py#L71

into: conv11 = Reshape((input_height * input_width, nClasses))(conv11)

the UNet seems to remove the line artifacts seen in the output masks.

Hi,I train a model with your tips, and get a results about acc=95% on the test set, but the visual results generated by predict.py seems to be incorrect, it's very strange, and I can't find the problem point, could you give me some advice?

many appreciate for u

#unet.py
def create_unet(nClasses, input_height=256, input_width=256, nChannels=3):
    s = Input(shape=(input_height, input_width, nChannels))
    c1 = Conv2D(8, 3, activation='relu', padding='same') (s)
    c1 = Conv2D(8, 3, activation='relu', padding='same') (c1)
    p1 = MaxPooling2D() (c1)
    c2 = Conv2D(16, 3, activation='relu', padding='same') (p1)
    c2 = Conv2D(16, 3, activation='relu', padding='same') (c2)
    p2 = MaxPooling2D() (c2)
    c3 = Conv2D(32, 3, activation='relu', padding='same') (p2)
    c3 = Conv2D(32, 3, activation='relu', padding='same') (c3)
    p3 = MaxPooling2D() (c3)
    c4 = Conv2D(64, 3, activation='relu', padding='same') (p3)
    c4 = Conv2D(64, 3, activation='relu', padding='same') (c4)
    p4 = MaxPooling2D() (c4)
    c5 = Conv2D(128, 3, activation='relu', padding='same') (p4)
    c5 = Conv2D(128, 3, activation='relu', padding='same') (c5)
    u6 = Conv2DTranspose(64, 2, strides=(2, 2), padding='same') (c5)
    u6 = concatenate([u6, c4], axis=3)
    c6 = Conv2D(64, 3, activation='relu', padding='same') (u6)
    c6 = Conv2D(64, 3, activation='relu', padding='same') (c6)
    u7 = Conv2DTranspose(32, 2, strides=(2, 2), padding='same') (c6)
    u7 = concatenate([u7, c3], axis=3)
    c7 = Conv2D(32, 3, activation='relu', padding='same') (u7)
    c7 = Conv2D(32, 3, activation='relu', padding='same') (c7)
    u8 = Conv2DTranspose(16, 2, strides=(2, 2), padding='same') (c7)
    u8 = concatenate([u8, c2], axis=3)
    c8 = Conv2D(16, 3, activation='relu', padding='same') (u8)
    c8 = Conv2D(16, 3, activation='relu', padding='same') (c8)
    u9 = Conv2DTranspose(8, 2, strides=(2, 2), padding='same') (c8)
    u9 = concatenate([u9, c1], axis=3)
    c9 = Conv2D(8, 3, activation='relu', padding='same') (u9)
    c9 = Conv2D(8, 3, activation='relu', padding='same') (c9)

    conv11 = Conv2D(nClasses, (1, 1), padding='same', activation='relu',
                   kernel_initializer=he_normal(), kernel_regularizer=l2(0.005))(c9)
    conv11 = Reshape((input_height * input_width, nClasses))(conv11)
    # conv11 = Permute((2, 1))(conv11)
    conv11 = Activation('softmax')(conv11)
    model = Model(inputs=[s], outputs=[conv11])

    # outputs = Conv2D(1, 1, activation='sigmoid') (c9)
    # model = Model(inputs=[s], outputs=[outputs])
    return model 
#predict.py
from models import unet, segnet
import cv2
import time
import numpy as np
import argparse
from utils.segdata_generator import generator
from matplotlib import pyplot as plt
from tqdm import tqdm

def predict_segmentation():
    n_classes = 2
    images_path = 'dataset/'
    val_file = './data/seg_test.txt'
    input_height = 256
    input_width = 256

    if args.model == 'unet':
        m = unet.create_unet(n_classes, input_height=input_height, input_width=input_width)
    elif args.model == 'segnet':
        m = segnet.SegNet(n_classes, input_height=input_height, input_width=input_width)
    else:
        raise ValueError('Do not support {}'.format(args.model))

    print(m.summary())
    # m.load_weights("./pretrained_model/persondata_unet.h5")
    m.load_weights("./weights/unet_seg_weights.h5")
    m.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

    colors = np.array([[0, 0, 0], [255, 255, 255]])
    i = 0
    t0 = time.time()
    for x, y in tqdm(generator(images_path, val_file, 1, n_classes, input_height, input_width, train=False)):
        pr = m.predict(x)[0]
        i += 1
        if i == 100:
            break
        print(pr.shape)
        pr = pr.reshape((input_height, input_width, n_classes)).argmax(axis=2)
        seg_img = np.zeros((input_height, input_width, 3))
        for c in range(n_classes):
            seg_img[:, :, 0] += ((pr[:, :] == c) * (colors[c][0])).astype('uint8')
            seg_img[:, :, 1] += ((pr[:, :] == c) * (colors[c][1])).astype('uint8')
            seg_img[:, :, 2] += ((pr[:, :] == c) * (colors[c][2])).astype('uint8')

        plt.subplot(1,2,1)
        plt.imshow(x[0])
        plt.axis('off')

        plt.subplot(1,2,2)
        plt.imshow(pr)
        plt.axis('off')

        plt.show()
        # cv2.imshow('test', seg_img)
        # cv2.imwrite('.dataset/humanparsing/output/{}.jpg'.format(i), seg_img)
        # i += 1
        # cv2.waitKey(30)
    t1 = time.time()
    print(t1 - t0)

if __name__ == '__main__':
    parse = argparse.ArgumentParser(description='command for training segmentation models with keras')
    parse.add_argument('--model', type=str, default='unet', help='support unet, segnet')
    args = parse.parse_args()
    predict_segmentation()

the predict visual results: image