keras-team / keras-cv

Industry-strength Computer Vision workflows with Keras

Segmentation Models for Non-RGB Image #1081

Closed innat closed 1 year ago

innat commented 1 year ago

Short Description

As far as I can tell, the current implementation works for RGB images, but the model fails to build for grayscale images and for non-RGB multi-channel images.

import tensorflow as tf
from keras_cv.models.segmentation.deeplab import DeepLabV3

def get_arch():
    model = DeepLabV3(
        classes=21,
        backbone="resnet50_v2",
        include_rescaling=True,
        backbone_weights=None,
    )
    return model

model = get_arch()
rgb_input = tf.ones(shape=(1, 100, 100, 3))
model(rgb_input).shape  # OK

model = get_arch()
gray_input = tf.ones(shape=(1, 100, 100, 1))
model(gray_input).shape  # fails to build

model = get_arch()
multi_channel_input = tf.ones(shape=(1, 100, 100, 4))
model(multi_channel_input).shape  # fails to build

Existing Implementations

  1. For a grayscale image, maybe we can do
input = tf.ones(shape=(1, 100, 100, 1))

tf.repeat(input, repeats=3, axis=-1) # (non-trainable)
<tf.Tensor: shape=(1, 100, 100, 3), dtype=float32
or,
tf.keras.layers.Conv2D(3, 1, 1)(input) # (trainable)
<tf.Tensor: shape=(1, 100, 100, 3), dtype=float32

Integrate either of the above into the model if input.shape[-1] == 1.

  2. For a non-RGB multi-channel input tensor, maybe we can do,
def get_arch(encoder_wg='imagenet'):
    model = DeepLabV3(
        classes=21,
        backbone="resnet50_v2",
        include_rescaling=True,
        backbone_weights=encoder_wg
    )
    return model

model = get_arch(encoder_wg=None)
multi_channel_input = tf.ones(shape=(1, 100, 100, 4))
model(multi_channel_input).shape # should build

Or, we can use

def get_arch(encoder_wg='imagenet'):
    model = DeepLabV3(
         # for input.shape[-1] > 3
         tf.keras.layers.Conv2D(3, 1, 1)(input) 
         # ....
    )
    return model

multi_channel_input = tf.ones(shape=(1, 100, 100, 4))
model = get_arch(encoder_wg='imagenet')
model(multi_channel_input).shape # should build
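
The two adapter ideas above (repeating a single gray channel, or applying a 1x1 convolution for any other channel count) can be sketched as one helper. This is only an illustration under stated assumptions: `to_three_channels` is a hypothetical name, not a keras-cv function, and the 1x1 convolution is freshly initialized, so it would have to be trained together with the rest of the model.

```python
import tensorflow as tf


def to_three_channels(x, trainable=True):
    """Map a (batch, height, width, C) tensor to 3 channels.

    Hypothetical helper for illustration; not part of keras-cv.
    """
    channels = x.shape[-1]
    if channels == 3:
        # Already RGB-shaped; nothing to do.
        return x
    if channels == 1 and not trainable:
        # Non-trainable option: copy the gray channel three times.
        return tf.repeat(x, repeats=3, axis=-1)
    # Trainable option: a 1x1 convolution mixes the C input channels into 3.
    # Note: this creates a new layer on every call; in a real model the
    # layer should be created once and reused so its weights are shared.
    return tf.keras.layers.Conv2D(filters=3, kernel_size=1)(x)


gray = tf.ones(shape=(1, 100, 100, 1))
multi = tf.ones(shape=(1, 100, 100, 4))
print(to_three_channels(gray, trainable=False).shape)  # (1, 100, 100, 3)
print(to_three_channels(multi).shape)                  # (1, 100, 100, 3)
```

Either variant only changes the channel dimension, so the spatial size seen by the backbone is unchanged.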

@qlzh727 @tanzhenyu

DavidLandup0 commented 1 year ago

The #1119 PR includes a fix:

model = DeepLabV3(classes=25, 
                  include_rescaling=True, 
                  backbone="resnet50_v2", 
                  input_shape=(256, 256, 1), 
                  backbone_weights=None)

Then:

input_image = tf.random.uniform(shape=[2, 256, 256, 1])
output = model(input_image)

Which outputs:

{'output': <tf.Tensor: shape=(2, 256, 256, 25), dtype=float32, numpy=
 array([[[[0.04      , 0.03999998, 0.04      , ..., 0.03999998,
           0.04      , 0.03999998], ...

One thing to note is that you can't use pre-trained imagenet weights because those images were, well, colored. If you try to pass a greyscale input and load imagenet weights, a ValueError is raised:

The input shape is set up for greyscale images with one channel, but backbone weights are trained on colored images and cannot be loaded.
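
The rule described here can be sketched as a small standalone check. `check_backbone_weights` is a hypothetical helper written for illustration, not keras-cv's actual implementation, and applying the same rule to every channel count other than 3 is an assumption.

```python
def check_backbone_weights(input_shape, backbone_weights):
    """Raise ValueError if pretrained weights cannot match the input channels.

    Hypothetical helper; the real check in keras-cv may differ.
    """
    channels = input_shape[-1]
    # Pretrained weights are assumed to expect 3-channel (RGB) input.
    if backbone_weights is not None and channels != 3:
        raise ValueError(
            f"Input has {channels} channel(s), but the backbone weights "
            "were trained on 3-channel images and cannot be loaded."
        )


# Random initialization works with any channel count.
check_backbone_weights((256, 256, 1), backbone_weights=None)
# Pretrained weights work with RGB input.
check_backbone_weights((256, 256, 3), backbone_weights="imagenet")
# Pretrained weights with grayscale input are rejected.
try:
    check_backbone_weights((256, 256, 1), backbone_weights="imagenet")
except ValueError as err:
    print(err)
```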

innat commented 1 year ago

@tanzhenyu cc. @DavidLandup0 I think PR https://github.com/keras-team/keras-cv/pull/1119 handles grayscale input, but not inputs with more than 3 channels.

tanzhenyu commented 1 year ago

If you want to use DeepLab with channels != 3, I guess you have to pass your own backbone anyway. I'm not sure whether DeepLab itself makes any assumption about the number of channels?

DavidLandup0 commented 1 year ago

Adjust the backbone; DeepLab itself is fully agnostic to the number of input channels:

import tensorflow as tf
import keras_cv
from keras_cv.models.segmentation.deeplab import DeepLabV3

# Build a backbone that accepts 4-channel input, without pretrained weights.
backbone = keras_cv.models.ResNet50V2(include_rescaling=True,
                                      stackwise_dilations=[1, 1, 1, 2],
                                      input_shape=(512, 512, 4),
                                      include_top=False,
                                      weights=None)

model = DeepLabV3(classes=21,
                  backbone=backbone,
                  weights=None)
input_image = tf.random.uniform(shape=[1, 512, 512, 4])
output = model(input_image)
{'output': <tf.Tensor: shape=(1, 512, 512, 21), dtype=float32, numpy=
 array([[[[9.9327528e-01, 3.1995063e-05, 4.1178692e-04, ...,

DavidLandup0 commented 1 year ago

Naturally, you can't reuse the pretrained weights if the number of channels is not 3, because they were trained on 3-channel images.

innat commented 1 year ago

@DavidLandup0 Thanks for the confirmation. @tanzhenyu Closing it.