bdzyubak / tensorflow-sandbox

A repository for studying applications of Deep Learning across fields, and demonstrating samples of my code and project management

Implement function to add decoder side to arbitrary network #7

Open bdzyubak opened 2 years ago

bdzyubak commented 2 years ago

Classification architectures outnumber segmentation networks. Broadly, a classification network is shaped as a tapering pyramid which progressively reduces the x-y dimensions of the feature map while increasing the number of channels. This allows for a complex model which predicts a large number of classes based on a large spatial footprint in the image.
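To make the pyramid shape concrete, here is a rough sketch that lists the feature map size after each stage. It assumes an idealized network where every stage halves the x-y size and doubles the channels. `pyramid_shapes` and its parameters are made up for illustration, and real architectures deviate from this pattern.

```python
def pyramid_shapes(input_size, base_channels, num_stages):
    """Return (height, width, channels) after each downsampling stage.

    Assumption: every stage halves the x-y size and doubles the channel
    count, which simplifies what real classification networks do.
    """
    shapes = []
    size, channels = input_size, base_channels
    for _ in range(num_stages):
        size = size // 2                      # stride-2 downsampling
        shapes.append((size, size, channels))
        channels *= 2                         # widen as the map shrinks
    return shapes


shapes = pyramid_shapes(input_size=256, base_channels=32, num_stages=5)
for shape in shapes:
    print(shape)
```

With these example values the map shrinks from 128x128x32 after the first stage to 8x8x512 after the fifth.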

To explore the benefits of the sometimes very complicated, potentially machine-derived, building blocks in these networks, it would be useful to have functionality that automatically turns a classification model into a segmentation one. The simplest way to do this would be to add a fully-connected layer at the end to map, e.g., a 9x9x1400 feature map to a 260x260x1 image. However, spatial information has been lost by the end of a classification network, so this approach, similar in spirit to a Fully Convolutional Network (FCN), would produce coarse segmentations.
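As a back-of-the-envelope check on the single-layer idea, the sketch below counts the parameters of one dense layer that maps the 9x9x1400 feature map to a 260x260x1 mask. `dense_head_parameters` is a hypothetical helper, and it assumes a plain fully-connected layer with one bias per output pixel.

```python
def dense_head_parameters(feature_shape, mask_shape):
    """Count weights and biases of one dense layer from features to mask."""
    height, width, channels = feature_shape
    inputs = height * width * channels
    mask_height, mask_width, mask_channels = mask_shape
    outputs = mask_height * mask_width * mask_channels
    return inputs * outputs + outputs         # weight matrix plus biases


params = dense_head_parameters((9, 9, 1400), (260, 260, 1))
upsampling_factor = 260 / 9                   # output pixels per feature cell, per axis

print(f"{params:,} parameters")               # 7,665,907,600 parameters
print(f"about {upsampling_factor:.1f}x upsampling per axis")
```

Besides the parameter count, each of the 9x9 feature cells has to account for a patch of roughly 29x29 output pixels, which is one way to see why the result would be coarse.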

Instead, a U-Net-like architecture can be added, where each downsampling (encoder) layer is matched by an upsampling (decoder) layer and the two are concatenated (skip connection). This can be generated procedurally for many networks. The resulting network will be roughly twice the size of the classification network, but given that the segmentation (localization + classification) task is more complex, this may be acceptable. At minimum, such a module is useful for exploration.
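The encoder/decoder pairing can be sketched without any framework by working only with feature map shapes. This is a planning sketch, not an implementation: `plan_unet_decoder` is a hypothetical name, and it assumes each upsampling layer outputs as many channels as the skip it is concatenated with.

```python
def plan_unet_decoder(encoder_shapes):
    """Pair each encoder stage with a mirrored upsampling stage.

    encoder_shapes: (height, width, channels) per encoder stage,
    shallowest first. Returns one dict per decoder stage, deepest first.
    """
    *skips, bottleneck = encoder_shapes
    current = bottleneck
    plan = []
    for skip_h, skip_w, skip_c in reversed(skips):
        # Assumption: the upsampled tensor gets skip_c channels, so the
        # concatenation along the channel axis has 2 * skip_c channels.
        concatenated = (skip_h, skip_w, skip_c * 2)
        plan.append({
            "upsample_from": current,
            "skip": (skip_h, skip_w, skip_c),
            "concatenated": concatenated,
        })
        current = concatenated
    return plan


encoder_shapes = [(128, 128, 32), (64, 64, 64), (32, 32, 128),
                  (16, 16, 256), (8, 8, 512)]
decoder_plan = plan_unet_decoder(encoder_shapes)
for stage in decoder_plan:
    print(stage)
```

Five encoder stages give four decoder stages here, ending at 128x128. A final upsampling step (not shown) would still be needed to get back to the input resolution.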

bdzyubak commented 2 years ago

The current TensorFlow image segmentation tutorial uses this very approach: https://www.tensorflow.org/tutorials/images/segmentation. It treats the U-Net as an overall concept consisting of an encoder, a decoder, and skip connections, where the encoder and decoder can be arbitrary networks. The image segmentation tutorial uses a pretrained MobileNetV2 as the encoder and the upsample block from the pix2pix tutorial as the decoder. The GAN tutorial https://www.tensorflow.org/tutorials/generative/pix2pix uses a modified U-Net as its generator.

So, as theorized in the opening post, any classification network can be turned into a segmentation network by viewing it as an encoder. The easiest implementation would add a fixed-architecture decoder that procedurally matches its layer sizes to the encoder outputs. A more sophisticated approach would try to mirror the encoder layers, replacing downsampling with upsampling.
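One piece of the "procedurally match layer sizes" step is deciding which encoder outputs to use as skip connections. A possible heuristic, sketched below, is to take the last layer at each x-y resolution. `select_skip_layers` and the layer names are hypothetical, and the sketch assumes the resolution never increases along the encoder.

```python
def select_skip_layers(layers):
    """Pick the last layer at each spatial resolution as a skip source.

    layers: ordered (name, (height, width, channels)) pairs, input to
    output. Returns (skips, bottleneck), with skips shallowest first.
    """
    last_at_resolution = {}
    for name, (height, width, channels) in layers:
        # A later layer with the same x-y size overwrites the earlier one.
        last_at_resolution[(height, width)] = (name, (height, width, channels))
    # Dicts keep the order in which keys were first inserted, so entries
    # stay sorted from shallow to deep as long as resolution only shrinks.
    ordered = list(last_at_resolution.values())
    return ordered[:-1], ordered[-1]


# Hypothetical layer listing for a 260x260 input; names are illustrative.
layers = [
    ("stem_conv", (130, 130, 32)),
    ("block1_out", (130, 130, 16)),
    ("block2_conv", (65, 65, 24)),
    ("block2_out", (65, 65, 24)),
    ("block3_out", (33, 33, 40)),
    ("block4_out", (17, 17, 112)),
    ("block5_conv", (9, 9, 320)),
    ("top_conv", (9, 9, 1400)),
]
skips, bottleneck = select_skip_layers(layers)
print("skips:", [name for name, _ in skips])
print("bottleneck:", bottleneck)
```

In a real framework the same listing could be derived from the model's layer output shapes, and the selected names would then feed a fixed decoder. That part is left out here because it depends on the library and the architecture.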