p2irc / deepplantphenomics

Deep learning for plant phenotyping.
GNU General Public License v2.0

Multi-class Semantic Segmentation #46

Closed: donovanlavoie closed this 4 years ago

donovanlavoie commented 5 years ago

This adds the ability for the SemanticSegmentationModel to perform multi-class segmentation on top of binary segmentation.

The number of segmentation classes can be set with a new public setter, set_num_segmentation_classes, which enforces a requirement of 2+ classes. This value is used to construct the output layer so that the logits for each class at each pixel are generated as separate channels. It also controls how the label images are read: to preserve the integer label values with multiple classes, they are explicitly cast from uint8 to float32 instead of being converted with tf.image.convert_image_dtype, which would scale them to the 0-1 range.
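A minimal sketch of the two behaviours described above, using NumPy in place of TensorFlow: a setter that enforces the 2+ class rule, and the difference between scaling uint8 labels the way tf.image.convert_image_dtype does for uint8 to float32 (dividing by 255) and explicitly casting them. `SegmentationSettingsSketch`, `scaled_labels`, and `cast_labels` are hypothetical names for illustration, not the model's actual code.

```python
import numpy as np


class SegmentationSettingsSketch:
    """Hypothetical stand-in for the class-count handling; not DPP's implementation."""

    def __init__(self):
        self.num_seg_class = 2  # assumed default: binary segmentation

    def set_num_segmentation_classes(self, num_class):
        # Enforce the rule described in the PR: at least 2 segmentation classes
        if not isinstance(num_class, int):
            raise TypeError("num_class must be an int")
        if num_class < 2:
            raise ValueError("num_class must be at least 2")
        self.num_seg_class = num_class


def scaled_labels(label_u8):
    # Mimics convert_image_dtype for uint8 -> float32: values are divided by 255
    return label_u8.astype(np.float32) / 255.0


def cast_labels(label_u8):
    # Explicit cast: integer class indices are preserved as floats
    return label_u8.astype(np.float32)


label = np.array([[0, 1], [2, 2]], dtype=np.uint8)
print(scaled_labels(label))  # class 2 becomes ~0.0078, no longer a usable index
print(cast_labels(label))    # class 2 stays 2.0
```

With scaling, a class index of 2 turns into roughly 0.0078, which is why the explicit cast matters once there are more than two classes.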

Because the classes are mutually exclusive, multi-class segmentation required adding support for the softmax cross entropy loss function. Because the two supported loss functions are not strictly interchangeable, the expected number of classes is checked when calculating the prediction losses.
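For reference, the per-pixel softmax cross entropy can be sketched in NumPy (TensorFlow has its own ops for this; the code below only illustrates the math and is not the PR's implementation). Logits have shape (H, W, C) with one channel per class, and labels are integer class indices of shape (H, W). The guard `check_loss_matches_classes` is a hypothetical version of the class-count check, and the loss name "sigmoid cross entropy" for the binary case is an assumption.

```python
import numpy as np


def log_softmax(logits):
    # Subtract the per-pixel max for numerical stability
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def softmax_cross_entropy(logits, labels):
    """Mean over pixels of -log p(true class).

    logits: float array (H, W, C); labels: int array (H, W).
    """
    log_probs = log_softmax(logits)
    # Pick the log-probability of each pixel's true class
    picked = np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    return float(-picked.mean())


def check_loss_matches_classes(loss_name, num_classes):
    # Hypothetical rule: the binary loss only makes sense for exactly 2 classes
    if loss_name == "sigmoid cross entropy" and num_classes != 2:
        raise ValueError(f"{loss_name!r} requires 2 classes, got {num_classes}")
```

With all-zero logits every class has probability 1/C, so the loss is log(C); confident, correct logits drive it toward 0.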

Inference works largely the same way with the expanded support for multiple classes, but inference outputs have a single channel: the outputs are run through a softmax and each pixel is set to its highest-probability class.
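The inference step described above (softmax over the class channels, then the most likely class per pixel) can be sketched in NumPy as follows; `logits_to_class_map` is a hypothetical helper, not a function from the library.

```python
import numpy as np


def logits_to_class_map(logits):
    """Collapse (H, W, C) logits to a single-channel (H, W) map of class indices."""
    # Softmax over the class axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Each pixel takes its highest-probability class
    return probs.argmax(axis=-1).astype(np.uint8)


logits = np.array([[[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]],
                   [[-1.0, 4.0, 0.0], [0.3, 1.0, 0.9]]])
print(logits_to_class_map(logits))
```

Because softmax is monotonic, taking the argmax of the raw logits gives the same map; the softmax only matters if the per-class probabilities themselves are needed.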

Old semantic segmentation examples should still run for the binary case. Meanwhile, a working multi-class test case was made by copying the existing test case and changing the masks so that the plant/foreground regions are split between two classes (roughly whole vs cut-off plants, but any arbitrary distinction would work).
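Relabelling a binary mask into a multi-class one can be sketched with NumPy. The split criterion here, a hypothetical `cut_off` boolean mask marking cut-off plants, is an assumption; the comment above only says the foreground was divided into two classes.

```python
import numpy as np


def binary_to_multiclass(fg_mask, cut_off):
    """Turn a binary foreground mask into labels 0 (background), 1, and 2.

    fg_mask: bool (H, W), True where a plant is present.
    cut_off: bool (H, W), hypothetical marker for cut-off plants.
    """
    labels = np.zeros(fg_mask.shape, dtype=np.uint8)  # 0 = background
    labels[fg_mask & ~cut_off] = 1                     # 1 = whole plant
    labels[fg_mask & cut_off] = 2                      # 2 = cut-off plant
    return labels


fg = np.array([[True, True, False],
               [False, True, True]])
cut = np.array([[False, False, False],
                [False, True, True]])
print(binary_to_multiclass(fg, cut))
```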

Other test cases and examples should still run as well, and new tests cover the new setter, the semantic segmentation errors, and the loss calculation.

jubbens commented 4 years ago

Is it possible to automatically set the number of classes and corresponding loss function based on the number of unique labels at the point when they're loaded?

donovanlavoie commented 4 years ago

Setting the loss function automatically based on the class count should be easy enough. Automatically getting the class count itself is more difficult; the exact count is (only) needed when creating the output conv layer, which is done before loading and parsing the actual images (instead of just their names).
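Choosing the loss from the class count could look like the sketch below, where an explicit user choice still wins. `AutoLossSketch` and `default_loss_for` are hypothetical, and the loss names "sigmoid cross entropy" (binary) and "softmax cross entropy" (multi-class) are assumed rather than taken from the PR.

```python
def default_loss_for(num_classes):
    # Assumed mapping: binary loss for 2 classes, softmax for more
    if num_classes < 2:
        raise ValueError("need at least 2 classes")
    return "sigmoid cross entropy" if num_classes == 2 else "softmax cross entropy"


class AutoLossSketch:
    """Hypothetical sketch: derive the loss from the class count unless the user chose one."""

    def __init__(self):
        self.num_seg_class = 2
        self.loss_fn = None  # None means "not chosen by the user"

    def set_loss_function(self, name):
        self.loss_fn = name

    def set_num_segmentation_classes(self, num_class):
        if num_class < 2:
            raise ValueError("num_class must be at least 2")
        self.num_seg_class = num_class

    def effective_loss(self):
        # Resolve lazily, so the class count can be set in any order
        if self.loss_fn is not None:
            return self.loss_fn
        return default_loss_for(self.num_seg_class)
```

Resolving the loss when it is needed, instead of inside the setter, keeps the result correct regardless of whether the class count or the loss is set first.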

The obvious solution would be to quickly load every image to get its max value (since there's no guarantee that any given sample from the set would have the max class index), but this doesn't scale well and I'm struggling to come up with a way of getting the max value of an image without loading it.
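The "obvious solution" can be sketched as one streaming pass that loads a single label image at a time, so memory stays constant while time still grows linearly with the dataset; it does not remove the scaling cost raised above. `infer_num_classes` is hypothetical, and `load_label` stands in for whatever decodes a label file into an integer array; it is injected so the sketch stays self-contained.

```python
import numpy as np


def infer_num_classes(label_paths, load_label):
    """Return (max label value + 1) over all label images.

    label_paths: iterable of paths; load_label: callable path -> integer array.
    """
    max_label = -1
    for path in label_paths:
        label = load_label(path)  # only one image held in memory at a time
        max_label = max(max_label, int(label.max()))
    if max_label < 0:
        raise ValueError("no label images found")
    # Assumes classes are indexed 0..max_label
    return max_label + 1


fake_store = {
    "a.png": np.array([[0, 1], [1, 0]], dtype=np.uint8),
    "b.png": np.array([[0, 2], [2, 2]], dtype=np.uint8),
}
print(infer_num_classes(fake_store, fake_store.__getitem__))  # 3
```

If the highest class index never appears anywhere in the labels, this undercounts, which is one reason an explicit setter remains useful even with automatic detection.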