Open ldenoue opened 6 years ago
I confirm removing the Permute and changing the Reshape as I suggest generates much better masks for me:
#conv11 = Reshape((nClasses, input_height * input_width))(conv11)
#conv11 = Permute((2, 1))(conv11)
conv11 = Reshape((input_height * input_width,nClasses))(conv11)
Can you please provide a link to the unet_weights.h5 file?
I noticed that when I change the reshape from https://github.com/TianzhongSong/Person-Segmentation-Keras/blob/master/models/unet.py#L71
into: conv11 = Reshape((input_height * input_width, nClasses))(conv11)
the UNet seems to remove the line artifacts seen in the output masks.
Hi, I trained a model following your tips and got about 95% accuracy on the test set, but the visual results generated by predict.py
seem to be incorrect. It's very strange and I can't find where the problem is. Could you give me some advice?
Many thanks.
#unet.py
def create_unet(nClasses, input_height=256, input_width=256, nChannels=3):
    """Build a small U-Net producing a per-pixel softmax over ``nClasses``.

    The network has four down-sampling stages (8, 16, 32, 64 filters), a
    128-filter bottleneck and four up-sampling stages that concatenate the
    matching encoder feature map on axis 3.

    Args:
        nClasses: Number of output classes per pixel.
        input_height: Height of the input images.
        input_width: Width of the input images.
        nChannels: Number of channels of the input images.

    Returns:
        A Keras ``Model`` whose output is reshaped to
        ``(input_height * input_width, nClasses)`` per sample and passed
        through a softmax activation.

    Note:
        The skip concatenations use ``axis=3``, so this presumably expects a
        channels-last backend configuration -- TODO confirm.  With four
        2x poolings the input size most likely has to be a multiple of 16
        for the skip connections to line up.
    """

    def double_conv(tensor, filters):
        # Two stacked 3x3, same-padded, ReLU-activated convolutions.
        for _ in range(2):
            tensor = Conv2D(filters, 3, activation='relu', padding='same')(tensor)
        return tensor

    inputs = Input(shape=(input_height, input_width, nChannels))

    # Contracting path: remember each stage's output for the skip connections.
    skips = []
    features = inputs
    for filters in (8, 16, 32, 64):
        features = double_conv(features, filters)
        skips.append(features)
        features = MaxPooling2D()(features)

    # Bottleneck.
    features = double_conv(features, 128)

    # Expanding path: upsample, merge with the mirrored encoder stage, refine.
    for filters, skip in zip((64, 32, 16, 8), reversed(skips)):
        upsampled = Conv2DTranspose(filters, 2, strides=(2, 2), padding='same')(features)
        merged = concatenate([upsampled, skip], axis=3)
        features = double_conv(merged, filters)

    # NOTE(review): the classification conv applies ReLU before the softmax,
    # which clamps the class scores at zero -- confirm this is intended.
    scores = Conv2D(nClasses, (1, 1), padding='same', activation='relu',
                    kernel_initializer=he_normal(),
                    kernel_regularizer=l2(0.005))(features)

    # Flatten the spatial grid directly to (pixels, classes); no Permute is
    # applied, so each row keeps the class scores of a single pixel.
    flattened = Reshape((input_height * input_width, nClasses))(scores)
    probabilities = Activation('softmax')(flattened)

    return Model(inputs=[inputs], outputs=[probabilities])
#predict.py
from models import unet, segnet
import cv2
import time
import numpy as np
import argparse
from utils.segdata_generator import generator
from matplotlib import pyplot as plt
from tqdm import tqdm
def predict_segmentation():
    """Predict masks for the test list and show each one next to its input.

    The model type is read from the module-level ``args.model`` ('unet' or
    'segnet').  Weights are loaded from ``./weights/unet_seg_weights.h5``,
    then up to 99 samples from the test generator are predicted and shown
    with matplotlib (input on the left, predicted class map on the right).
    The total elapsed time of the loop is printed at the end.

    Raises:
        ValueError: If ``args.model`` is neither 'unet' nor 'segnet'.
    """
    n_classes = 2
    images_path = 'dataset/'
    val_file = './data/seg_test.txt'
    input_height = 256
    input_width = 256

    if args.model == 'unet':
        m = unet.create_unet(n_classes, input_height=input_height, input_width=input_width)
    elif args.model == 'segnet':
        m = segnet.SegNet(n_classes, input_height=input_height, input_width=input_width)
    else:
        raise ValueError('Do not support {}'.format(args.model))

    # summary() prints the table itself and returns None, so wrapping it in
    # print() only added a stray "None" line.
    m.summary()

    # NOTE(review): the same U-Net weight file is loaded even when
    # --model segnet is selected -- confirm the intended path for SegNet.
    # Alternative weights: "./pretrained_model/persondata_unet.h5"
    m.load_weights("./weights/unet_seg_weights.h5")
    m.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

    # Palette indexed by class id: class 0 -> black, class 1 -> white.
    colors = np.array([[0, 0, 0], [255, 255, 255]])

    t0 = time.time()
    samples = generator(images_path, val_file, 1, n_classes,
                        input_height, input_width, train=False)
    for i, (x, _) in enumerate(tqdm(samples), start=1):
        # Stop before predicting the 100th sample; previously it was
        # predicted and then thrown away by the break.
        if i == 100:
            break

        pr = m.predict(x)[0]
        print(pr.shape)

        # Restore the spatial grid and pick the most likely class per pixel.
        pr = pr.reshape((input_height, input_width, n_classes)).argmax(axis=2)

        # Colour mask via a palette lookup instead of a per-class,
        # per-channel accumulation loop.  Only needed for the optional
        # OpenCV path, e.g. cv2.imshow('test', seg_img) / cv2.imwrite(...).
        seg_img = colors[pr].astype('uint8')

        plt.subplot(1, 2, 1)
        plt.imshow(x[0])
        plt.axis('off')
        plt.subplot(1, 2, 2)
        plt.imshow(pr)
        plt.axis('off')
        plt.show()

    t1 = time.time()
    print(t1 - t0)
if __name__ == '__main__':
    # Parse the command line into the module-level ``args`` that
    # predict_segmentation() reads, then run the prediction loop.
    parse = argparse.ArgumentParser(
        description='command for training segmentation models with keras',
    )
    parse.add_argument(
        '--model',
        default='unet',
        type=str,
        help='support unet, segnet',
    )
    args = parse.parse_args()
    predict_segmentation()
The visual results produced by predict.py:
I noticed that when I change the reshape from https://github.com/TianzhongSong/Person-Segmentation-Keras/blob/master/models/unet.py#L71
into: conv11 = Reshape((input_height * input_width, nClasses))(conv11)
the UNet seems to remove the line artifacts seen in the output masks.