Golbstein / Keras-segmentation-deeplab-v3.1

An awesome semantic segmentation model that runs in real time

Train on a custom dataset #5

wave-transmitter opened this issue 5 years ago

wave-transmitter commented 5 years ago


Is it possible to train this model on a dataset other than VOC, or at least fine-tune it? Have you tried something similar?

As mentioned in this repository, there is a problem with training or fine-tuning DeepLab models in Keras implementations.

Golbstein commented 5 years ago

Hi, the repository you mentioned claims to have successfully trained this model. I've also fine-tuned it on Berkeley DeepDrive.

If you have concerns, I suggest taking a subset of ~100 images from your new dataset and trying to overfit their segmentation masks, to be confident that it works.
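A minimal sketch of that sanity check, assuming your data is a list of (image_path, mask_path) pairs and that you can get predicted class maps as integer arrays. `pick_overfit_subset` and `mean_iou` are hypothetical helpers, not part of this repository. The idea is to train only on the fixed subset and then measure mean IoU on that same subset; if the pipeline works, it should climb close to 1.0.

```python
import random

import numpy as np


def pick_overfit_subset(pairs, n=100, seed=0):
    """Return a reproducible random subset of (image_path, mask_path) pairs."""
    rng = random.Random(seed)
    return rng.sample(pairs, min(n, len(pairs)))


def mean_iou(pred, target, num_classes, void_id=None):
    """Mean IoU over classes that appear in pred or target, ignoring void pixels."""
    pred = np.asarray(pred)
    target = np.asarray(target)
    # Pixels labelled void in the ground truth are excluded from every class.
    valid = np.ones(target.shape, dtype=bool) if void_id is None else target != void_id
    ious = []
    for c in range(num_classes):
        p = (pred == c) & valid
        t = (target == c) & valid
        union = np.logical_or(p, t).sum()
        if union == 0:
            continue  # class absent from both maps: skip rather than score it
        ious.append(np.logical_and(p, t).sum() / union)
    return float(np.mean(ious)) if ious else float("nan")
```

If the loss plateaus and mean IoU on the training subset stays well below 1.0, the problem is more likely in the labels or the training setup than in the dataset size.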

Another issue is that label 0 here is "background". Some people say it's not smart to use background as a label, and some datasets, like DeepDrive, don't even have this class, so you may want to retrain on VOC with all background labels changed to the "void" label.
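A rough sketch of that relabelling with NumPy, assuming background is ID 0, the object classes are 1..20, and void uses a single ID (21 here, matching the convention mentioned later in this thread; the original VOC PNGs use 255). `background_to_void` is a hypothetical helper. It also shifts the remaining classes down by one, on the assumption that your training code expects contiguous class IDs starting at 0 once background is no longer a class.

```python
import numpy as np


def background_to_void(mask, void_id=21):
    """Turn background (ID 0) into void and shift classes 1..N down to 0..N-1.

    Assumptions: background is 0, void is a single ID, and the training code
    wants contiguous class IDs starting at 0 (so one fewer output channel).
    """
    mask = np.asarray(mask)
    out = mask.copy()
    is_void = (mask == 0) | (mask == void_id)
    # Every remaining ID is >= 1, so subtracting 1 cannot underflow uint8 masks.
    out[~is_void] = mask[~is_void] - 1
    out[is_void] = void_id
    return out
```

The model's final layer would then need 20 output channels instead of 21, and the loss has to ignore the void ID.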

swarmt commented 5 years ago

I have replaced the values of PASCAL_VOC_classes in utils.py with 16 new classes for my own dataset.

When I try to train, I get the error: tensorflow.python.framework.errors_impl.InvalidArgumentError: Dimension 0 in both shapes must be equal, but are 1 and 21. Shapes are [1,1,256,16] and [21,256,1,1]. for 'Assign_270' (op: 'Assign') with input shapes: [1,1,256,16], [21,256,1,1].

I find this strange. Is the tensor order reversed as well as mismatched?
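One plausible explanation, assuming the pretrained weights go through Keras 2's HDF5 weight loader (not confirmed for this repository): when a saved Conv2D kernel's shape does not match the layer's, Keras's weight preprocessing tries transposing the kernel with axes (3, 2, 0, 1), as if it had been saved channels-first. Keras stores a Conv2D kernel as (kernel_h, kernel_w, in_channels, out_channels), so the pretrained 21-class head is (1, 1, 256, 21). The snippet below only reproduces the shape arithmetic with NumPy:

```python
import numpy as np

# Keras Conv2D kernel layout: (kernel_h, kernel_w, in_channels, out_channels).
saved_kernel = np.zeros((1, 1, 256, 21), dtype=np.float32)  # pretrained VOC head, 21 classes
new_layer_shape = (1, 1, 256, 16)                            # modified head, 16 classes

# Assumption: on a shape mismatch the loader attempts this axis reordering.
attempted = np.transpose(saved_kernel, (3, 2, 0, 1))

print(attempted.shape)                     # (21, 256, 1, 1), the "reversed" shape in the error
print(attempted.shape == new_layer_shape)  # False: the real mismatch is 21 vs 16 classes
```

If this is what happens, the reversed order is only a side effect; the underlying problem is loading a 21-class head into a 16-class layer. A common workaround is to skip the mismatched final layer, for example Keras 2.x's `model.load_weights(path, by_name=True, skip_mismatch=True)` for HDF5 files, but whether that applies depends on how this repository loads its weights.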

lattard commented 4 years ago

@swarmt Did you manage to solve this problem? I am trying to retrain the network with just one class label and I'm getting the same error you mentioned.

curryJ commented 3 years ago

Could you tell me your Keras and TensorFlow versions? Thanks a lot!
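A small sketch for answering that, using only the Python standard library (3.8+). It reads the installed package metadata without importing TensorFlow or Keras; the distribution names listed are assumptions about how they were installed.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("tensorflow", "tensorflow-gpu", "keras")):
    """Return {distribution_name: version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None  # not installed in this environment
    return found


print(installed_versions())
```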

Mps24-7uk commented 3 years ago

@swarmt If you check the masks, they have label IDs from 0 to 21. Since you want to train the model for 16 labels, change the label ID of the excluded classes to void, i.e. 21; for example, np.where(mask==19, 21, mask). Alternatively, you can remove the image/mask pairs that contain a label ID you want to exclude. The mask dataset for Pascal VOC can be downloaded here: https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip
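A sketch generalising that `np.where` call to several excluded classes at once, plus a filter for the second option. It assumes masks are uint8 arrays with void = 21, as described above. `make_contiguous` is a hypothetical extra step: if the training code maps label IDs directly to output channels, the kept classes may also need renumbering to 0..K-1 so that a 16-class head sees labels 0..15.

```python
import numpy as np

VOID_ID = 21  # void label ID, following the convention in this comment


def exclude_classes(mask, excluded, void_id=VOID_ID):
    """Relabel every pixel whose class is in `excluded` as void."""
    mask = np.asarray(mask)
    return np.where(np.isin(mask, list(excluded)), void_id, mask).astype(mask.dtype)


def make_contiguous(mask, kept, void_id=VOID_ID):
    """Hypothetical: map kept class IDs to 0..len(kept)-1; every other ID becomes void."""
    lut = np.full(256, void_id, dtype=np.uint8)  # lookup table indexed by old ID
    for new_id, old_id in enumerate(kept):
        lut[old_id] = new_id
    return lut[np.asarray(mask, dtype=np.uint8)]


def keep_pair(mask, excluded):
    """Second option: True if the mask contains none of the excluded IDs."""
    return not np.isin(np.asarray(mask), list(excluded)).any()
```

Relabelling keeps every image but discards the excluded pixels from training; filtering pairs with `keep_pair` keeps labels clean at the cost of fewer images.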