jakeret / tf_unet

Generic U-Net Tensorflow implementation for image segmentation
GNU General Public License v3.0

Training Problem #228

Closed: DoraUniApp closed this issue 5 years ago

DoraUniApp commented 5 years ago

Hi, thanks for putting up a clean and neat implementation of u-net.

I've been playing around with your code and managed to adapt the data_provider for my multi-class problem and run the training without any errors. However, the results I'm getting from training are rather strange and not right. The training finishes with what seems to be OK performance:

18:22:22,965 Iter 6397, Minibatch Loss= 0.2209, Training Accuracy= 0.9514, Minibatch error= 4.9%
18:22:23,229 Iter 6398, Minibatch Loss= 0.5043, Training Accuracy= 0.8385, Minibatch error= 16.2%
18:22:23,510 Iter 6399, Minibatch Loss= 0.1701, Training Accuracy= 0.9685, Minibatch error= 3.1%
18:22:23,511 Epoch 99, Average loss: 0.3064, learning rate: 0.0012
18:22:23,560 Verification error= 3.0%, loss= 0.1624
18:22:25,814 Optimization Finished!

but when I look at the prediction folder and the epoch images, the prediction column looks very strange for all the epochs. When I tried to run a prediction, it also came out all blank, with no meaningful results.

[Image: epoch_55]

I realised some people had similar problems, so I tried their advice: adding batch normalisation and increasing the depth, number of features, iterations and batch size, but none of it seemed to make a difference. When I increased the batch size from the default of 1 to 4, the epoch images also changed to a smaller window.

[Image: epoch_0]

Could this be a problem of an unbalanced dataset? I feel I'm doing something wrong and was wondering if anyone can help.

jakeret commented 5 years ago

It seems like the model is not picking up any signal, and an unbalanced dataset is certainly not helping. Looks like you tried much of the common advice. I would verify that the data provider is working 100% correctly and then, if possible, try to get a prediction for only one class to see if that works. Have you checked TensorBoard? Do the weight distributions (histograms) look "normal", i.e. not too skewed? (Skewed distributions might indicate that something is not right with the input.) Furthermore, you could try to preprocess the data, e.g. by clipping and normalizing it or inverting the colors.
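One way to verify the data provider is to inspect what a few batches actually contain before they reach the network. A minimal sketch, assuming the provider returns NumPy arrays shaped (batch, ny, nx, channels) for images and one-hot (batch, ny, nx, n_class) for labels; `describe_batch` is a hypothetical helper, not part of tf_unet:

```python
import numpy as np

def describe_batch(images, labels):
    """Summarise one batch from a data provider (hypothetical helper).

    images: array of shape (batch, ny, nx, channels)
    labels: one-hot array of shape (batch, ny, nx, n_class)
    """
    # Fraction of pixels assigned to each class across the whole batch.
    class_fractions = labels.reshape(-1, labels.shape[-1]).mean(axis=0)
    # An image whose values are all identical (e.g. completely black) has zero
    # range, which breaks min-max normalisation.
    flat = images.reshape(images.shape[0], -1)
    constant_images = int(np.sum(flat.max(axis=1) == flat.min(axis=1)))
    return {
        "class_fractions": class_fractions,
        "constant_images": constant_images,
        "has_nan": bool(np.isnan(images).any()),
    }
```

Printing this for the first few hundred batches makes very rare classes and degenerate (constant or NaN) inputs easy to spot.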

DoraUniApp commented 5 years ago

Thanks for your reply. My original data is 3D, but I slice it and feed 2D images to the network. One thing I noticed is that when I feed in only slices with foreground pixels present (i.e. no slice that is all background), the network learns something and the epoch images do not show that strange pattern. This is why I thought it might have something to do with imbalanced data. I then modified the data provider so that it randomly returns a slice with at least one of the class labels present. When I ran the training with the modified data provider, after a few epochs the validation images went back to the same strange patterns as in the first case. I think I'm missing something obvious; do you have any suggestions I could try? This is the screenshot from TensorBoard:

[Image: tensorboard]
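A data provider that only returns slices containing labels, as described above, might look roughly like this. A sketch assuming the 3D volume and its integer label map are NumPy arrays with slices along the first axis, and that label 0 means background (i.e. "class labels present" is read as "foreground present"); both function names are hypothetical:

```python
import numpy as np

def foreground_slice_indices(label_volume):
    """Indices of slices that contain at least one non-background (label > 0) voxel."""
    n_slices = label_volume.shape[0]
    has_fg = (label_volume > 0).reshape(n_slices, -1).any(axis=1)
    return np.flatnonzero(has_fg)

def sample_foreground_slice(volume, label_volume, rng):
    """Return a random (image, label) 2D slice pair that contains foreground."""
    candidates = foreground_slice_indices(label_volume)
    if candidates.size == 0:
        raise ValueError("no slice with foreground labels found")
    idx = rng.choice(candidates)
    return volume[idx], label_volume[idx]
```

Precomputing the candidate indices avoids an open-ended rejection loop when foreground slices are rare.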

DoraUniApp commented 5 years ago

Hey guys, can anybody help? I'm kind of stuck with this now. I can see that the problem occurs when I train on slices that do not have a significant number of voxels with foreground labels. I thought that if tf_unet uses a sliding-window approach, maybe I end up feeding patches with no labels present at all, hence the problem. Does tf_unet use a sliding window, or is the entire image fed into the network? If it uses a window, how do I control its size? Any hints are much appreciated.

jakeret commented 5 years ago

Seeing everything you have tried, you are probably not missing anything obvious, so it's a bit hard to tell what is going wrong. The only thing that looks a bit weird is that the cross-entropy first decreases and then goes up again, which is rather unusual. Which loss function are you using, Dice?

tf_unet is a CNN that is applied to the entire input image, hence there is no sliding window to adjust.

Maybe a wild idea: first pre-train the model on inputs containing the relevant voxels, and once the network starts to learn the relevant features, slowly start feeding it images with fewer relevant voxels.
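The pre-training idea can be expressed as a sampling schedule: start with slices that contain the relevant voxels only and gradually mix in background-only slices. A minimal sketch; the linear ramp and the `warmup_epochs` and `max_fraction` parameters are illustrative assumptions, not tf_unet options:

```python
import numpy as np

def background_fraction(epoch, warmup_epochs=10, max_fraction=0.5):
    """Hypothetical schedule: probability of drawing a background-only slice.

    Starts at 0 (foreground slices only) and grows linearly to max_fraction.
    """
    return max_fraction * min(epoch / warmup_epochs, 1.0)

def sample_slice_index(fg_indices, bg_indices, epoch, rng):
    # With probability background_fraction(epoch) pick a background-only slice,
    # otherwise pick a slice that contains foreground voxels.
    if len(bg_indices) > 0 and rng.random() < background_fraction(epoch):
        return int(rng.choice(bg_indices))
    return int(rng.choice(fg_indices))
```

A linear ramp is just the simplest choice; a step schedule (foreground only for N epochs, then everything) follows the same pattern.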

DoraUniApp commented 5 years ago

Hey Jakeret, thanks for your reply. I am using the Dice score as the loss function. I will try your idea and see what happens. But please correct me if I'm wrong: I thought U-Net should see examples of both negative and positive labels to train properly. Does this mean these labels need to be present in every single batch fed into the network, and that it is not enough for them to be present somewhere in the entire training dataset?
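With a Dice-type loss, batches without any foreground are a special case worth checking. In the plain soft Dice below, a perfect all-background prediction on an empty slice scores 0, the worst possible value, unless a smoothing term is added to both numerator and denominator. This is an illustrative NumPy sketch of the general formula, not necessarily the exact loss tf_unet computes; `smooth` and `eps` are assumptions:

```python
import numpy as np

def soft_dice(pred, target, smooth=0.0, eps=1e-5):
    """Soft Dice coefficient between probability map `pred` and binary `target`.

    `smooth` is added to numerator and denominator; `eps` only to the
    denominator, to avoid dividing by zero.
    """
    intersection = np.sum(pred * target)
    denom = np.sum(pred) + np.sum(target)
    return (2.0 * intersection + smooth) / (denom + smooth + eps)
```

For an empty slice predicted perfectly as empty, `soft_dice` returns 0.0 with `smooth=0` but close to 1.0 with `smooth=1.0`.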

jakeret commented 5 years ago

In principle it should be sufficient for them to be present in the training dataset. However, this does not seem to work in this particular case, for reasons not yet known. Hence the idea to support the learning process by presenting the network with more images that contain the relevant voxels at the beginning.

DoraUniApp commented 5 years ago

Hey Jakeret, I think I found the culprit! I modified your code to visualise the input batches to the network and realised that the strange error happens when one of the batches has no signal intensity, i.e. it is completely black. This in turn causes the normalisation code to fail, I guess because of a division by zero:

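import numpy as np

def _process_data(self, data):
    # normalization: clip absolute values to [a_min, a_max], then rescale to [0, 1]
    data = np.clip(np.fabs(data), self.a_min, self.a_max)
    data -= np.amin(data)
    # guard against all-zero (completely black) inputs, which would divide by zero
    if np.amax(data) != 0:
        data /= np.amax(data)
    return data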

So I added the condition above to prevent divisions by zero, and since then training seems to handle all types of input batches, with no need to artificially balance the class labels. Do you recommend my approach? If so, I think we can close this issue by saying that an input image with no intensity was causing a division-by-zero error, which caused the rest of the training issues.
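To see why the guard matters, the same normalisation can be run on a completely black image with and without the check. A small reproduction written as a free function with assumed default clipping bounds, rather than the provider method itself: without the guard the result is all NaN, which would typically propagate into the loss and the gradients.

```python
import numpy as np

def normalize(data, a_min=-np.inf, a_max=np.inf, guard=True):
    """Min-max normalisation as in _process_data; default bounds are assumptions."""
    data = np.clip(np.fabs(data), a_min, a_max)
    data -= np.amin(data)
    if not guard or np.amax(data) != 0:
        # Without the guard, an all-black image gives 0 / 0 -> NaN everywhere.
        data /= np.amax(data)
    return data

black = np.zeros((3, 3))
with np.errstate(invalid="ignore"):  # silence the 0/0 RuntimeWarning
    unguarded = normalize(black, guard=False)
guarded = normalize(black)
```

Here `unguarded` is entirely NaN while `guarded` stays all zeros; ordinary images are rescaled to [0, 1] either way.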

jakeret commented 5 years ago

Sweet. Glad you managed to find the actual issue, and sorry for not being able to help you more. This division by zero sounds like a bug. Would you mind sending me a pull request so that we can prevent others from going through the same process?

DoraUniApp commented 5 years ago

Sure, I will send you a pull request soon. Thanks for your help.