jakeret / tf_unet

Generic U-Net Tensorflow implementation for image segmentation
GNU General Public License v3.0

Training with 3 classes #251

Closed: I-CANT-CODE closed this issue 5 years ago

I-CANT-CODE commented 5 years ago

I'm trying to train a classifier with 3 output labels: one for the border of the objects, one for the background, and one for the body of the objects. My code works on this dataset when I isolate only two of the three classes and make each pixel either 1 or 0 (as in the original paper), but it doesn't seem to work when I have 3 separate channels. Training doesn't converge and the output is completely black. The data in TRAIN_blocked_colors are just RGB images in which the red, green, and blue channels are my pixel class labels.

Any ideas where I went wrong?

```python
from tf_unet import image_util, unet

data_provider = image_util.ImageDataProvider("TRAIN_blocked_colors/*.png", data_suffix=".png", mask_suffix="_LABEL.png")
data_provider.n_class = 3
net = unet.Unet(layers=3, features_root=64, channels=1, n_class=3)

if TRAIN:  # TRAIN is a flag set elsewhere in the script
    trainer = unet.Trainer(net, optimizer="adam", opt_kwargs=dict(learning_rate=.0001))
    path = trainer.train(data_provider, output_path="checkpoints", training_iters=32, epochs=100)
```

I-CANT-CODE commented 5 years ago

I was also previously using the code below with the 2-class version to test batches of images, but I'm not sure it gives the correct output for a 3-channel prediction. I put ??? next to the part I'm unsure about. If anyone knows how to change this code to produce a 3-channel output, essentially an RGB image, please let me know.

```python
from tf_unet import image_util, unet, util

data_provider = image_util.ImageDataProvider("test_images2/*.png", data_suffix=".png", mask_suffix="_LABEL.png")

net = unet.Unet(layers=3, features_root=64, channels=1, n_class=3)

x_test, _ = data_provider(1)

prediction = net.predict("checkpoints/model.ckpt", x_test)

img = util.to_rgb(prediction[..., 1].reshape(-1, prediction.shape[2], 1))  # ???
util.save_image(img, "test_prediction.jpg")
```

I-CANT-CODE commented 5 years ago

Quick edit: I changed my dataset. The labels were not one-hot encoded before; they were RGB images with pixel values up to 255. I converted them to one-hot encoding, e.g. [0,0,1], [1,0,0], [0,1,0] for each pixel. The predictions saved each epoch are no longer completely black, but they still aren't very good. Any ideas?
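A minimal sketch of that conversion with NumPy, assuming each label pixel is (close to) pure red, green, or blue and that channel i of the RGB label marks class i. The helper name `rgb_label_to_one_hot` is hypothetical, not part of tf_unet. Taking the argmax over the colour channels and indexing an identity matrix guarantees exactly one 1 per pixel, even if a few pixels are not perfectly saturated.

```python
import numpy as np


def rgb_label_to_one_hot(label_rgb: np.ndarray) -> np.ndarray:
    """Convert an (H, W, 3) RGB label image into an (H, W, 3) one-hot float array.

    Assumption (hypothetical helper): channel i of the RGB image marks class i,
    so the class of each pixel is its brightest channel.
    """
    if label_rgb.ndim != 3 or label_rgb.shape[-1] != 3:
        raise ValueError(f"expected an (H, W, 3) RGB label image, got {label_rgb.shape}")
    class_idx = np.argmax(label_rgb, axis=-1)        # (H, W) ints in {0, 1, 2}
    return np.eye(3, dtype=np.float32)[class_idx]    # (H, W, 3), one 1 per pixel
```

For a label loaded with Pillow this would be called as `rgb_label_to_one_hot(np.array(Image.open(path).convert("RGB")))`; the `.convert("RGB")` drops a possible alpha channel.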

I-CANT-CODE commented 5 years ago

Update: I fixed a few things and now it actually seems to be learning. But my question about how to get an RGB prediction image still stands; I'm not sure this will produce a colored image:

```python
img = util.to_rgb(prediction[..., 1].reshape(-1, prediction.shape[2], 1))  # ????
```

If anyone knows how to fix this, that would be great. Also, if anyone has experience modifying the code so that the images saved each epoch in the prediction folder are RGB images, so I can see how each class is doing, that would be wonderful.
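The snippet above keeps only `prediction[..., 1]`, so it can show a single class's probability map at most. A minimal sketch that uses all three channels instead, assuming `prediction` has shape (N, H, W, 3) as returned by `net.predict`, and that classes 0/1/2 should be drawn as red/green/blue (both the colour mapping and the helper name `prediction_to_rgb` are hypothetical):

```python
import numpy as np


def prediction_to_rgb(prediction: np.ndarray, hard: bool = False) -> np.ndarray:
    """Turn a 3-class prediction of shape (N, H, W, 3) into a uint8 RGB image.

    Assumption: output channel 0/1/2 is displayed as R/G/B.
    hard=False: per-pixel class probabilities become colour intensities.
    hard=True:  each pixel gets the pure colour of its most likely class.
    """
    if prediction.ndim != 4 or prediction.shape[-1] != 3:
        raise ValueError(f"expected an (N, H, W, 3) array, got {prediction.shape}")
    if hard:
        probs = np.eye(3, dtype=np.float32)[np.argmax(prediction, axis=-1)]
    else:
        probs = np.clip(prediction, 0.0, 1.0)
    # Stack the N images vertically, as the reshape(-1, ...) in the original snippet does.
    stacked = probs.reshape(-1, prediction.shape[2], 3)
    return (stacked * 255).round().astype(np.uint8)
```

The result can be written with Pillow, e.g. `Image.fromarray(prediction_to_rgb(prediction, hard=True)).save("test_prediction_color.png")`; a lossless format such as PNG keeps the class colours crisp, whereas JPEG compression blurs them. Getting the same view for the per-epoch images would mean changing how the Trainer writes them, which this sketch does not cover.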

I-CANT-CODE commented 5 years ago

Update:

The test code didn't work. Code:

```python
from tf_unet import image_util, unet, util

data_provider = image_util.ImageDataProvider("test_images2/*.png", data_suffix=".png", mask_suffix="_LABEL.png")
data_provider.n_class = 3
net = unet.Unet(layers=3, features_root=64, channels=1, n_class=3)

x_test, _ = data_provider(1)

prediction = net.predict("checkpoints/model.ckpt", x_test)

img = util.to_rgb(prediction[..., 1].reshape(-1, prediction.shape[2], 1))
util.save_image(img, "test_prediction.jpg")
```

Error:

```
Traceback (most recent call last):
  File "test_unet.py", line 11, in <module>
    x_test, _ = data_provider(1)
  File "/home/users/rssadre/tf_unet/tf_unet/image_util.py", line 89, in __call__
    train_data, labels = self._load_data_and_label()
  File "/home/users/rssadre/tf_unet/tf_unet/image_util.py", line 56, in _load_data_and_label
    return train_data.reshape(1, ny, nx, self.channels), labels.reshape(1, ny, nx, self.n_class),
ValueError: cannot reshape array of size 262144 into shape (1,512,512,3)
```

Not sure how to fix this.
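The numbers in the error are worth unpacking: 262144 = 512 × 512, i.e. one value per pixel, while a shape of (1, 512, 512, 3) needs 512 × 512 × 3 = 786432 values. Because the target's last axis is 3 (`self.n_class`), this most likely means the test `_LABEL.png` files are still single-channel masks rather than the 3-channel one-hot images used for training; that is an inference from the arithmetic, not something confirmed in the thread. A minimal sketch of a check that could be run on each label array (loaded e.g. with `np.array(Image.open(path))`) before testing; the helper name `check_label_shape` is hypothetical:

```python
import numpy as np


def check_label_shape(label: np.ndarray, n_class: int) -> None:
    """Raise a descriptive error if `label` cannot be reshaped to (1, ny, nx, n_class).

    Hypothetical helper: it only compares element counts, mirroring the
    reshape that fails in the traceback.
    """
    ny, nx = label.shape[:2]
    expected = ny * nx * n_class
    if label.size != expected:
        per_pixel = label.size // (ny * nx)
        raise ValueError(
            f"label {label.shape} has {label.size} values ({per_pixel} per pixel), "
            f"but n_class={n_class} needs {ny}*{nx}*{n_class} = {expected}"
        )


# Reproduces the numbers from the traceback: a single-channel 512x512 mask.
try:
    check_label_shape(np.zeros((512, 512), dtype=bool), n_class=3)
except ValueError as err:
    print(err)
```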

I-CANT-CODE commented 5 years ago

Update:

So I completely rewrote the testing code. The images saved each epoch in the "prediction" folder look good, but when I use the code below to produce a prediction, the output seems to be dominated by the first class. Any idea what could be going wrong?

New code:

```python
import numpy as np
from PIL import Image as Im
from tf_unet import unet, util

test_input = Im.open("test_images/595.png")
test_input = np.array(test_input)
print(test_input.shape)

test_input = np.reshape(test_input, [1, 512, 512, 1])
print(test_input.shape)

net = unet.Unet(layers=3, features_root=128, channels=1, n_class=3)

prediction = net.predict("checkpoints/model.ckpt", test_input)

print(prediction[0].shape)
print(prediction[0])

img = util.to_rgb(prediction[0])
util.save_image(img, "test_prediction_color.jpg")
```
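To put a number on "dominated by the first class", a minimal sketch that counts how many pixels each class wins by argmax. It assumes `prediction` is an (N, H, W, n_class) array of per-class scores like the output of `net.predict` in the snippet above; the helper name `class_fractions` is hypothetical, not part of tf_unet.

```python
import numpy as np


def class_fractions(prediction: np.ndarray) -> np.ndarray:
    """Fraction of pixels assigned to each class by argmax over the last axis."""
    n_class = prediction.shape[-1]
    winners = np.argmax(prediction, axis=-1).ravel()   # one class index per pixel
    counts = np.bincount(winners, minlength=n_class)   # pixels won by each class
    return counts / winners.size
```

Calling `print(class_fractions(prediction))` right after `net.predict` would show, for example, whether class 0 wins nearly every pixel. If it does while the per-epoch images look balanced, the difference is more likely in how the test input is prepared than in the trained weights.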

I-CANT-CODE commented 5 years ago

Found the issue: I needed to normalize the input from the range 0 to 255 to the range 0 to 1.
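A minimal sketch of that fix, assuming a single-channel 8-bit test image like test_images/595.png in the code above. It rescales each image to [0, 1] using its own minimum and maximum, which is assumed (check `image_util.py` in your tf_unet version) to match how the data provider normalizes training images. Plain division by 255.0 is a simpler alternative, but it gives the same result only when an image spans the full 0 to 255 range. The helper name `normalize_input` is hypothetical.

```python
import numpy as np


def normalize_input(img) -> np.ndarray:
    """Min-max scale a 2-D grayscale image to [0, 1] and add batch/channel axes."""
    data = np.array(img, dtype=np.float32)  # always a copy, so the caller's array is untouched
    if data.ndim != 2:
        raise ValueError(f"expected a 2-D grayscale image, got shape {data.shape}")
    data -= data.min()
    peak = data.max()
    if peak != 0:  # avoid dividing by zero on a constant image
        data /= peak
    return data.reshape(1, data.shape[0], data.shape[1], 1)  # (1, ny, nx, 1) for net.predict
```

In the test script this would replace the manual reshape: `test_input = normalize_input(Im.open("test_images/595.png"))`, followed by `net.predict("checkpoints/model.ckpt", test_input)` as before.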