NifTK / NiftyNet

[unmaintained] An open-source convolutional neural networks platform for research in medical image analysis and image-guided therapy
http://niftynet.io
Apache License 2.0

Loss doesn't converge when using WeightedSampler #89

Closed kwang-greenhand closed 5 years ago

kwang-greenhand commented 6 years ago

When dealing with a severely class-imbalanced segmentation problem, I wanted to ensure that each volume sampled from the image has at least one foreground voxel by using the WeightedSampler (instead of UniformSampler, which samples uniformly at random from the whole image). But when I use WeightedSampler, the loss never converges to a reasonable value. I'm using dice_loss, and it always stays above 0.2, which is worse than before.

What raises a yellow flag for me is that I intentionally tried to make it overfit by using only one subject, in which case the training loss at least should drop to a very low level. In fact, I double-checked by using UniformSampler and one subject. But when I used WeightedSampler, the problem persisted. I attached the loss here. I think there might be a problem.

Has anyone had a similar issue? Has anyone ever used the weighted sampler?

Any comment would be greatly appreciated!

[image: loss_with_one_sub]

YilinLiu97 commented 6 years ago

I also have the imbalanced class problem. I used the "balanced" sampler and it works fine for me. However, I also found it hard to overfit just one subject, with the "balanced" sampler.

wyli commented 6 years ago

Hi @kwang-greenhand the curve looks nice, how about the volume level performance?

kwang-greenhand commented 6 years ago

Hi @wyli the performance is horrible. It gives me nonsense... in fact I feel that for a dice loss, a training loss above 0.25 is already kind of intolerable... Below is the training (orange) and validation (blue) loss for 60 subjects. At least on the training set the performance is very good: it's overfitting of course, but at least it's working.

[image: plan7_official1]

@YilinLiu97 Oh is there a "balanced" sampler? Could you tell me how to call that? I may give it a shot.

Thank you guys for your reply!

mjorgecardoso commented 6 years ago

A validation loss of 0.2 is not horrible for most Segmentation problems, but I can see why that might be the case for unbalanced problems. So many problems might be explaining this. The loss itself, the augmentation, the network, the data itself... Can you please paste your config file?

Regarding the sampler, you can use the “weighted” sampler by providing a frequency sampling map as described in readthedocs.
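
For readers unfamiliar with the idea, here is a minimal sketch of what sampling from a frequency map means in general. It is not NiftyNet's implementation: the `sample_centres` helper and the dictionary-based map format are hypothetical, chosen only to keep the example dependency-free.

```python
import random

def sample_centres(weight_map, n_samples, seed=0):
    """Draw voxel coordinates with probability proportional to their weight.

    weight_map: dict mapping (x, y, z) -> non-negative weight
                (hypothetical format, for illustration only).
    """
    rng = random.Random(seed)
    coords = list(weight_map)
    weights = [weight_map[c] for c in coords]
    return rng.choices(coords, weights=weights, k=n_samples)

# Toy 4x4x1 "image" whose binary mask marks two foreground voxels.
mask = {(x, y, 0): 0.0 for x in range(4) for y in range(4)}
mask[(1, 2, 0)] = 1.0
mask[(2, 2, 0)] = 1.0

centres = sample_centres(mask, n_samples=5)
```

With a purely binary mask as the map, every drawn centre is a foreground voxel, so patches centred on background are never drawn.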

kwang-greenhand commented 6 years ago

@mjorgecardoso I guess you didn't read my post thoroughly. The problem is that when I use the weighted sampler and keep everything else unchanged, the loss is worse than with the uniform random sampler, and it cannot even overfit to one subject, where it gives a dice loss of 0.2.

mjorgecardoso commented 6 years ago

Weighted sampler is normally "worse" visually (when looking at training curves) as it tends to draw "harder"/complex samples.

In a highly unbalanced problem, if one draws patches uniformly, the most likely label will be background. Thus, if the learning system always predicts background, the loss function will be artificially low when averaged over multiple samples. It will look like the learning is working better, but it is not. The loss is just providing biased statistics. On the other hand, a balanced sampler (e.g. weighted) will make the problem harder by drawing samples from pathological regions more often, making the average loss numerically higher, but the reported value will be less biased.
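
The bias described above can be illustrated with a toy calculation. This is a sketch under stated assumptions: `dice_loss` is a generic smoothed foreground Dice, not necessarily the one NiftyNet implements, and the 1-in-20 ratio of lesion patches is made up.

```python
def dice_loss(pred, truth, eps=1e-5):
    """Smoothed soft Dice loss for the foreground class of one patch."""
    intersection = sum(p * t for p, t in zip(pred, truth))
    denom = sum(pred) + sum(truth)
    return 1.0 - (2.0 * intersection + eps) / (denom + eps)

def mean_loss(patches):
    # The "model" here always predicts background (all zeros).
    losses = [dice_loss([0.0] * len(truth), truth) for truth in patches]
    return sum(losses) / len(losses)

background_patch = [0.0] * 8
lesion_patch = [0.0] * 6 + [1.0] * 2

# Uniform sampling: assume only 1 patch in 20 contains any foreground.
uniform_batch = [background_patch] * 19 + [lesion_patch]
# Weighted sampling: assume every patch contains foreground.
weighted_batch = [lesion_patch] * 20

uniform_mean = mean_loss(uniform_batch)
weighted_mean = mean_loss(weighted_batch)
```

The same useless predictor scores a mean loss of about 0.05 on the uniform batch and about 1.0 on the weighted batch, so the two training curves are not directly comparable.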

As Wenqi asked, how about the volume level performance? When you run inference on a full image and estimate the Dice of the full image (and not of random patches), is the performance still worse?

kwang-greenhand commented 6 years ago

@mjorgecardoso what you said really makes sense. But still, I checked the performance, and it's really bad: not even close. I tried Dice and Generalized Dice, and neither worked. I really can't think of any way this could happen.

As I mentioned above, when I was using UniformSampler, I could overfit to a small training set very well, by which I mean not only the loss but also the performance: I modified data-split.csv so that it uses the training set as the "inference set", and the results are very close to ground truth. In my opinion, with the weighted sampler (the weighting being a binary mask), the number of possible volumes from an image is smaller than before, which should lead to more severe overfitting, right? However, that's not what happens...

At first I was worried there was a problem with the implementation of WeightedSampler, but after checking the code and printing the sampled volume locations, it seems alright.

Did you ever encounter such a problem? Any idea what happened?

mjorgecardoso commented 6 years ago

What do your frequency/sampling maps look like? My go-to strategy is to create frequency maps by estimating the volume of each class, and then setting each pixel of an image to 1/Vol of the pixel's class. Finally, Gaussian blur the result with a kernel that is half of the patch size. It might be that you are not sampling enough from background areas, which is hurting you at inference.
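
A minimal NumPy sketch of the recipe above, for illustration only. The `frequency_map` function is hypothetical. Reading "a kernel that is half of the patch size" as the kernel width in voxels, and setting `sigma = width / 4`, are both assumptions, as is the final normalisation to a probability map.

```python
import numpy as np

def frequency_map(labels, patch_size):
    """Inverse-class-volume sampling map, blurred along every axis."""
    labels = np.asarray(labels)
    freq = np.zeros(labels.shape, dtype=float)
    for cls in np.unique(labels):
        region = labels == cls
        freq[region] = 1.0 / region.sum()   # 1 / Vol of the voxel's class

    # Assumption: kernel width (in voxels) is half the patch size, forced odd.
    width = max(3, (patch_size // 2) | 1)
    sigma = width / 4.0                     # assumed, not stated in the thread
    offsets = np.arange(width) - width // 2
    kernel = np.exp(-0.5 * (offsets / sigma) ** 2)
    kernel /= kernel.sum()

    # Separable Gaussian blur; assumes each image dimension >= kernel width.
    for axis in range(freq.ndim):
        freq = np.apply_along_axis(np.convolve, axis, freq, kernel, mode="same")
    return freq / freq.sum()                # normalise to a probability map

# Toy label volume: a 4x4x4 foreground cube inside a 16x16x16 background.
labels = np.zeros((16, 16, 16), dtype=int)
labels[6:10, 6:10, 6:10] = 1
fmap = frequency_map(labels, patch_size=8)
```

Before blurring, each class carries the same total weight, so foreground and background are drawn about equally often. The blur spreads foreground weight into the surrounding background, which keeps nearby context in the sampling pool.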

YilinLiu97 commented 6 years ago

For the balanced sampler, you can check out the closed issue #86. @kwang-greenhand

YilinLiu97 commented 6 years ago

Hi @mjorgecardoso, I think that I have a similar issue. The segmentation result I got is reasonable, by which I mean that at least the predictions of the ROIs look fine, but it is not good enough, especially for the background. My problem is also imbalanced, so I used the "balanced" sampler, which I believe already does what you said. I also think that it might be because of the sampling of the background. Do you think that increasing the overall number of samples would help? Say, 1000 --> 2000. Otherwise I would think that it was because of the fitting power of the model itself (?)

YilinLiu97 commented 6 years ago

Although I think that 1000 is already good enough for my case since I only have 3 classes... (I mean it should cover enough samples of the background (?))

mjorgecardoso commented 6 years ago

@YilinLiu97 , what do you mean by 1000 samples? do you mean 1000 samples from the same image, i.e. setting sample_per_volume=1000?

YilinLiu97 commented 6 years ago

Yes! from a single subject @mjorgecardoso

mjorgecardoso commented 6 years ago

I think that parameter is way too high. A network will spend hundreds of iterations on the same image before moving on to the next one, meaning that the final model will overfit to the last image it sees. Some of these networks don't memorise well for many iterations. I would personally use a much lower number there (say sample_per_volume=10) to avoid memorising one single dataset. In short, you want the network to see as many patients and as many classes as possible in the minimum number of iterations.

YilinLiu97 commented 6 years ago

Could you please expand a bit why some of these architectures don't memorise well for many iterations? ( I used the deepmedic architecture). Cause I've tried another network with the same dataset, same number of samples (i.e., 1000) and nearly same hyperparameters, and got pretty good result. Thanks!

mjorgecardoso commented 6 years ago

Purely to do with high-dimensional optimisation. 1000 iterations on one subject will likely end in a local minimum that is optimal for that patient, and it might be non-trivial to escape that local minimum to optimise for another patient/population. If you are doing SGD or any other kind of stochastic algorithm on random samples from random patients, a good minimum will have to be more globally optimal across many more patients.

Imagine the problem of domain transfer. If subject 1 was acquired using an MPRAGE T1 sequence and subject 2 was acquired using an SPGR sequence, training for 1000 iterations on subject 1 might overspecialise the network to MPRAGE-like sequences, making it hard to perform well on SPGR sequences. You will see less of an effect if data/morphology/pathology is relatively similar between subjects, but you would definitely see an effect if there are significant differences in the data domain.

I'm not saying that you will see significantly better results, but balanced training will likely show better generalisability and validation performance.

YilinLiu97 commented 6 years ago

Also, with only 10 samples, in my case there will only be 3 samples for each class to be trained per iteration. Wouldn't that lead to the problem you mentioned, i.e., not sampling enough from the background areas? I guess what you suggested is to sample less from one subject and compensate for it by using as many subjects as possible (?). But I have only 14 subjects, so this may not be an option.

YilinLiu97 commented 6 years ago

I see what you meant. But I think that should be solved by shuffling the data across the subjects?

mjorgecardoso commented 6 years ago

You want samples to be class-balanced and with subject variability. Sampling 3 patches per class per subject, and going through subjects quickly, will give you balanced classes and high subject/morphology variability.
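
As a rough illustration of that schedule (not NiftyNet's sampler; `patch_schedule` and the subject names are hypothetical), a generator can take a few patches per class from each subject and move on immediately:

```python
from itertools import islice

def patch_schedule(subjects, classes, patches_per_class=3):
    """Yield (subject, class) patch requests, cycling through subjects forever."""
    while True:
        for subject in subjects:
            for cls in classes:
                for _ in range(patches_per_class):
                    yield subject, cls

subjects = [f"subject_{i:02d}" for i in range(14)]  # 14 subjects, as in the thread
classes = [0, 1, 2]                                 # 3 classes, background included

# One full pass: 14 subjects x 3 classes x 3 patches = 126 patch requests.
first_pass = list(islice(patch_schedule(subjects, classes), 14 * 3 * 3))
```

Only 9 consecutive patches come from any one subject, and each class, background included, receives the same number of patches per pass.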

When I was suggesting "sampling more from the background", it is because most people will only sample close to foreground classes, forgetting that background variability also has to be modelled.

Shuffling would work, but the way the NiftyNet sampler works means that you will sample from each subject at a time, so that will likely not solve the problem here.

kwang-greenhand commented 6 years ago

Thank you @mjorgecardoso, I think I can give it a shot. Previously I tried something similar: using generalized dice loss with 1/volume weights, which is supposed to give me equal contributions from the two classes. I guess a similar principle can be used for sampling too. I'll try that.

But one thing: if, as you said, I'm sampling foreground too much, then the model should tend to predict foreground, right? But on the contrary, it's giving me predominantly background...
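
For reference, a minimal sketch of a generalised Dice loss with the 1/volume class weights mentioned above. It is an illustration, not NiftyNet's code: the `generalised_dice_loss` function and its list-of-lists input format are hypothetical, and some formulations weight by the inverse squared volume instead.

```python
def generalised_dice_loss(pred, truth, eps=1e-5):
    """pred, truth: per-class lists of voxel values, e.g. pred[c][i]."""
    numerator = 0.0
    denominator = 0.0
    for p_c, g_c in zip(pred, truth):
        weight = 1.0 / (sum(g_c) + eps)     # 1 / volume of class c
        numerator += weight * sum(p * g for p, g in zip(p_c, g_c))
        denominator += weight * (sum(p_c) + sum(g_c))
    return 1.0 - 2.0 * numerator / (denominator + eps)

# 10 voxels, one of which is foreground (class 1).
truth = [[1.0] * 9 + [0.0], [0.0] * 9 + [1.0]]
all_background = [[1.0] * 10, [0.0] * 10]

perfect_loss = generalised_dice_loss(truth, truth)
background_loss = generalised_dice_loss(all_background, truth)
```

On this toy case, predicting only background costs about 0.36, whereas an unweighted Dice summed over both classes would give 0.1, so the weighting makes the missed foreground voxel count for much more.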

kwang-greenhand commented 6 years ago

@YilinLiu97 I think for your problem I posted something on stack-overflow https://stackoverflow.com/questions/49620177/niftynet-volume-sampling

mjorge is right. If you sample 1000 volumes per subject each time, let's say if your batch-size is 100, then the last 10 iter will be run on the same subject. Shuffling will definitely help, but in my understanding it only happens in the same queue. So if your queue length is 300, then you'll have at least 3 iter on the same subject even with shuffling (Not 100% sure I'm right but you got the idea)
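
The arithmetic above, as a tiny sketch. It assumes the simplified model described in this comment, where all patches from one subject are consumed back to back; `consecutive_iterations` is a hypothetical helper, not a NiftyNet function.

```python
def consecutive_iterations(n_patches, batch_size):
    """Full batches drawn from one subject if its patches are consumed back to back."""
    return n_patches // batch_size

many = consecutive_iterations(1000, 100)     # sample_per_volume=1000
few = consecutive_iterations(10, 100)        # sample_per_volume=10
per_queue = consecutive_iterations(300, 100) # a queue of 300 patches from one subject
```

With 1000 samples per volume and a batch size of 100, 10 full batches come from a single subject. With 10 samples per volume the result is 0, meaning every batch already mixes patches from several subjects.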

kwang-greenhand commented 6 years ago

@mjorgecardoso After using your suggested weight map, the loss becomes like this... I'm using a very simple network so there shouldn't be overfitting.

It's driving me crazy... It doesn't make any sense...

[image]

mmodat commented 5 years ago

@mjorgecardoso Looks like this ticket can be closed. Could you confirm and do it?