jakeret / tf_unet

Generic U-Net Tensorflow implementation for image segmentation
GNU General Public License v3.0

problem with using own probability maps to improve segmentation #171

Open femonk opened 6 years ago

femonk commented 6 years ago

Hi,

I tried to improve the segmentation results of my UNet using individual 'weightmaps' (+) during training. These weightmaps should increase the probabilities for some pixels of the segmentation, so that narrow edges are detected / segmented more reliably. But all my attempts to implement these weightmaps in the UNet code have been unsuccessful so far. Does anyone have suggestions?

(+) weightmaps: I call them 'weightmaps', but maybe something like 'probabilitymap-prefactor-matrix' would describe the purpose of the maps better.

To give some details: I have some ground-truth segmentations with really narrow areas, like the 'residue' shown here in the uppermost corner: github_inputimg

The idea is to 'smooth' the borders of the segmented area using a Gaussian filter, so that the probabilities of the pixels next to the narrow area are slightly increased, which hopefully improves the segmentation of these small areas. The calculation of the weightmaps is defined as a function in 'training.py', def own_weightmap(input_img): [...], and the weightmap for the example shown above looks like this:

github_finalweightmap
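The body of own_weightmap is not shown in this thread, so the following is only a minimal sketch of the idea described above (blur the foreground mask, then add it to a base weight of 1). The function names, `sigma`, `boost` and the NumPy-only blur are all assumptions made for illustration and are not the code used in the issue.

```python
import numpy as np


def gaussian_kernel_1d(sigma, truncate=4.0):
    """Normalised 1-D Gaussian kernel covering +/- truncate * sigma pixels."""
    radius = int(truncate * sigma + 0.5)
    x = np.arange(-radius, radius + 1, dtype=np.float64)
    kernel = np.exp(-0.5 * (x / sigma) ** 2)
    return kernel / kernel.sum()


def gaussian_weightmap_sketch(mask, sigma=2.0, boost=1.0):
    """Hypothetical stand-in for own_weightmap (not the author's code).

    mask  : 2-D array, 1 for foreground pixels and 0 for background.
    sigma : blur width in pixels (assumed value).
    boost : how strongly the blurred foreground raises the weight (assumed).
    Assumes both image sides are longer than the kernel.
    """
    mask = np.asarray(mask, dtype=np.float64)
    kernel = gaussian_kernel_1d(sigma)
    # Separable blur: convolve along each row, then along each column.
    blurred = np.apply_along_axis(
        lambda r: np.convolve(r, kernel, mode="same"), 1, mask)
    blurred = np.apply_along_axis(
        lambda c: np.convolve(c, kernel, mode="same"), 0, blurred)
    # Normalise so the largest blurred value is 1 (skip for empty masks).
    peak = blurred.max()
    if peak > 0:
        blurred = blurred / peak
    # Base weight 1 everywhere; pixels on or near the foreground get more.
    return 1.0 + boost * blurred


# Toy example: a one-pixel-wide vertical structure in a 32x32 image.
mask = np.zeros((32, 32))
mask[10:22, 15] = 1.0
weights = gaussian_weightmap_sketch(mask, sigma=2.0, boost=4.0)
```

In this sketch the weight falls off smoothly with distance from the narrow structure and stays at the base value of 1 far away from it.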

The weightmaps are calculated in 'training.py' using this code:

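# Assumed imports (not shown in the original snippet)
from scipy import io
from tf_unet import image_util, unet

# Load images (raw string, so that "\t" in the path is not read as a tab)
imageProvider = image_util.ImageDataProvider(r"D:\train\*.png", a_min=None, a_max=None, data_suffix="_gw.png", mask_suffix="_mask.png")

# Calculate the weights from the labels of one batch of 10 images
data, label = imageProvider(10)
io.savemat("./results/label.mat", mdict={"label": label})
weights = own_weightmap(label)
io.savemat("./results/weights.mat", mdict={"weights": weights})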

These weightmaps (called 'weights' in the code snippet) should be fed into the network right before the assignment to the respective classes is done. In the UNet architecture of Ronneberger et al. (https://arxiv.org/abs/1505.04597) this would be 'at the last turquoise arrow', right before the '1x1 conv' and the 'output segmentation map'.

At the moment, the calculated weightmaps are fed into the network (also in training.py) by:

# Set up net & training
net = unet.Unet(layers=3, features_root=64, channels=1, n_class=2, clstm=False, cost="cross_entropy", cost_kwargs=dict(class_weights=weights))
trainer = unet.Trainer(net, optimizer="momentum")
# Raw string, so that "\t" in the output path is not read as a tab
trainer.train(imageProvider, r".\training", training_iters=1000, epochs=10)

To improve the segmentation, the predictions are multiplied pixelwise with the respective element of the weightmap. This is done in 'unet.py' in lines 220-232 ('if class_weights is not None...'):

        flat_logits = tf.reshape(logits, [-1, self.n_class])
        flat_labels = tf.reshape(self.y, [-1, self.n_class])
        if cost_name == "cross_entropy":
            class_weights = cost_kwargs.pop("class_weights", None)

            if class_weights is not None:
                class_weights = tf.constant(np.array(class_weights, dtype=np.float32))
                class_weights = tf.reshape(class_weights, [-1, self.n_class])

                #weight_map = tf.multiply(flat_labels, class_weights)
                #weight_map = tf.reduce_sum(class_weights, axis=1)

                loss_map = tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits, labels=flat_labels)

                weighted_loss = tf.multiply(loss_map, class_weights[:, 0]+class_weights[:, 1])

                loss = tf.reduce_mean(weighted_loss)

            else:
                [...]

Lines 224 & 225 (#weight_map = ...) are commented out, since otherwise the error 'loss=not a number' occurs.

But somehow this approach doesn't work. As a result, the network is trained towards irrelevant areas, as the probability map shows:
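For reference, the weighting performed in the snippet above (a per-pixel cross-entropy multiplied by one scalar per pixel, then averaged) can be written as a small NumPy sketch. This is an illustration only, not tf_unet code; `weighted_pixel_loss`, the shape check and the toy values are assumptions.

```python
import numpy as np


def softmax_cross_entropy(logits, labels):
    """Per-pixel cross-entropy for flat arrays of shape (n_pixels, n_class)."""
    # Subtract the row maximum before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(labels * log_probs).sum(axis=1)


def weighted_pixel_loss(logits, labels, pixel_weights):
    """Mean of the per-pixel loss scaled by one weight per pixel."""
    loss_map = softmax_cross_entropy(logits, labels)
    # The weights must line up one-to-one with the pixels of the current batch.
    if pixel_weights.shape != loss_map.shape:
        raise ValueError("expected weights of shape %s, got %s"
                         % (loss_map.shape, pixel_weights.shape))
    return float(np.mean(loss_map * pixel_weights))


# Three pixels, two classes; the third pixel is misclassified.
logits = np.array([[2.0, 0.0], [0.0, 2.0], [2.0, 0.0]])
labels = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])

plain = weighted_pixel_loss(logits, labels, np.ones(3))
boosted = weighted_pixel_loss(logits, labels, np.array([1.0, 1.0, 5.0]))
```

Here `boosted` is larger than `plain` because the misclassified pixel carries five times the weight. Note that such a weight only rescales how much each pixel contributes to the loss and its gradient. The softmax output itself is not multiplied by it, so the predicted probability map need not resemble the weightmap.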

ground truth input: github_inputimg

probability map (yellow: high values, blue: low values): github_logits5_bearb. Normally this should bear some similarity to the calculated weightmap, since it is the 'loss_map' multiplied with the weightmap (see code above). But as you can see, it doesn't look as expected.

As a result, all or most pixels are wrongly assigned to the class 'background' (= black), so that the prediction images are entirely black or have only a small area assigned to the right class:

resulting segmentation: github_labels5_bearb

Does anybody have suggestions as to where or what the problem with my code is? Thanks!

jakeret commented 6 years ago

Sorry for the late reply. Maybe you have found a solution in the meantime. Anyway, I'm a bit unsure about this line: weighted_loss = tf.multiply(loss_map, class_weights[:, 0]+class_weights[:, 1])

Also the image_util.ImageDataProvider is shuffling randomly through the training files. Do you account for that when computing the weight maps?

femonk commented 6 years ago

Hi,

Thank you for your reply, @jakeret. What do you mean by 'Anyway, I'm a bit unsure about weighted_loss = [...]'?

As far as I can see, the image_util.ImageDataProvider isn't shuffling randomly through the training files, so this shouldn't be the problem. Right? einleseroutine_github

jakeret commented 6 years ago

What I meant was that I don't understand the implementation.

It's shuffling the file names.
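To illustrate why the file order matters: weights computed once from one batch only line up with later batches if those batches arrive in the same order. The sketch below uses toy NumPy arrays rather than the real ImageDataProvider; `weight_from_label`, the fixed permutation and the array shapes are assumptions for illustration.

```python
import numpy as np


def weight_from_label(label):
    """Hypothetical per-pixel weight: 1 on background, 5 on foreground.

    label is a one-hot map of shape (H, W, 2) with foreground in channel 1.
    """
    return 1.0 + 4.0 * label[..., 1]


# Four toy label maps, each with a single foreground pixel at (i, i).
labels = np.zeros((4, 8, 8, 2))
labels[..., 0] = 1.0
for i in range(4):
    labels[i, i, i] = (0.0, 1.0)

# Weights computed once, in the original file order.
precomputed = np.stack([weight_from_label(l) for l in labels])

# A fixed permutation standing in for shuffled file names.
order = np.array([2, 0, 3, 1])
shuffled = labels[order]

# Weights derived from the labels that are actually being trained on.
on_the_fly = np.stack([weight_from_label(l) for l in shuffled])


def foreground_is_boosted(weights, batch):
    """True if every foreground pixel in the batch carries the weight 5."""
    return all(weights[k][batch[k, ..., 1] == 1].min() == 5.0
               for k in range(len(batch)))


aligned_on_the_fly = foreground_is_boosted(on_the_fly, shuffled)
aligned_precomputed = foreground_is_boosted(precomputed, shuffled)
```

In this toy setup only the weights derived from the current labels stay aligned. One possible design choice, under these assumptions, is to compute the weightmap from each label batch at training time rather than passing a constant computed beforehand.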