innat closed this issue.
PR #1119 includes a fix:
import tensorflow as tf
from keras_cv.models.segmentation.deeplab import DeepLabV3

model = DeepLabV3(classes=25,
                  include_rescaling=True,
                  backbone="resnet50_v2",
                  input_shape=(256, 256, 1),
                  backbone_weights=None)
Then:
input_image = tf.random.uniform(shape=[2, 256, 256, 1])
output = model(input_image)
Which outputs:
{'output': <tf.Tensor: shape=(2, 256, 256, 25), dtype=float32, numpy=
array([[[[0.04 , 0.03999998, 0.04 , ..., 0.03999998,
0.04 , 0.03999998], ...
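The near-constant 0.04 values are what one would expect if the head ends in a softmax over 25 classes and the randomly initialized (backbone_weights=None), untrained model produces nearly equal logits: a softmax over equal logits gives 1/25 = 0.04 per class. A minimal NumPy sketch, assuming a softmax output activation (the snippet above does not show it):

```python
import numpy as np


def softmax(logits):
    # Subtract the max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


num_classes = 25
# Assumption: an untrained model emits (nearly) identical logits per pixel
equal_logits = np.zeros(num_classes)
probs = softmax(equal_logits)
print(probs[0])  # 0.04

# Small perturbations keep every probability close to 1 / num_classes
rng = np.random.default_rng(0)
noisy_probs = softmax(rng.normal(scale=1e-3, size=num_classes))
print(np.allclose(noisy_probs, 1 / num_classes, atol=1e-3))  # True
```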
One thing to note: you can't use pre-trained ImageNet weights, because those images were, well, in color. If you pass a grayscale input and try to load ImageNet weights, a ValueError is raised:
The input shape is set up for greyscale images with one channel, but backbone weights are trained on colored images and cannot be loaded.
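The check behind this error can be sketched as a guard on the channel count. The function below is hypothetical (it is not the keras-cv source); it rejects pre-trained weights for any input that doesn't have exactly 3 channels, which covers both the grayscale case and inputs with more than 3 channels.

```python
def validate_backbone_weights(input_shape, backbone_weights):
    """Hypothetical guard: pre-trained weights require 3-channel (RGB) input."""
    channels = input_shape[-1]
    if backbone_weights is not None and channels != 3:
        raise ValueError(
            f"Input shape has {channels} channel(s), but backbone weights "
            f"'{backbone_weights}' were trained on 3-channel images and "
            "cannot be loaded."
        )


# Random initialization (weights=None) works for any channel count
validate_backbone_weights((256, 256, 1), None)
validate_backbone_weights((512, 512, 4), None)
# Pre-trained weights are only accepted for 3 channels
validate_backbone_weights((256, 256, 3), "imagenet")

for shape in [(256, 256, 1), (512, 512, 4)]:
    try:
        validate_backbone_weights(shape, "imagenet")
    except ValueError as err:
        print(err)
```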
@tanzhenyu cc. @DavidLandup0 I think PR https://github.com/keras-team/keras-cv/pull/1119 handled grayscale input but not inputs with more than 3 channels.
If you want to use DeepLab with channels != 3, I guess you have to pass your own backbone anyway. I'm not sure whether DeepLab itself makes any assumptions about the number of channels.
Adjust the backbone; DeepLab itself is fully agnostic to the number of input channels:
import tensorflow as tf
import keras_cv
from keras_cv.models.segmentation.deeplab import DeepLabV3

# 4-channel backbone; weights=None because pre-trained weights expect 3 channels
backbone = keras_cv.models.ResNet50V2(include_rescaling=True,
                                      stackwise_dilations=[1, 1, 1, 2],
                                      input_shape=(512, 512, 4),
                                      include_top=False,
                                      weights=None)

model = DeepLabV3(classes=21,
                  backbone=backbone,
                  weights=None)

input_image = tf.random.uniform(shape=[1, 512, 512, 4])
output = model(input_image)
{'output': <tf.Tensor: shape=(1, 512, 512, 21), dtype=float32, numpy=
array([[[[9.9327528e-01, 3.1995063e-05, 4.1178692e-04, ...,
Naturally, you can't reuse pre-trained weights when the number of channels isn't 3, because both the backbone and the DeepLab weights were trained on 3-channel images.
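The underlying reason is that the first convolution's kernel has one slice per input channel, so its shape changes with the channel count, while every later layer sees the same number of feature maps. A NumPy sketch of that shape check; the 7x7, 64-filter stem matches the usual ResNet design but is an assumption here, and `load_stem_kernel` is a hypothetical helper:

```python
import numpy as np


def stem_kernel_shape(input_channels, filters=64, kernel_size=7):
    # Keras Conv2D kernels are laid out as (height, width, in_channels, out_channels)
    return (kernel_size, kernel_size, input_channels, filters)


def load_stem_kernel(pretrained_kernel, input_channels):
    """Hypothetical helper: copy a pre-trained stem kernel into a new model."""
    kernel_size, _, _, filters = pretrained_kernel.shape
    expected = stem_kernel_shape(input_channels, filters=filters,
                                 kernel_size=kernel_size)
    if pretrained_kernel.shape != expected:
        raise ValueError(
            f"Pre-trained kernel has shape {pretrained_kernel.shape}, "
            f"but a {input_channels}-channel input needs {expected}."
        )
    return pretrained_kernel.copy()


# Stand-in for ImageNet weights: a stem trained on 3-channel RGB images
imagenet_like_kernel = np.random.default_rng(0).normal(size=stem_kernel_shape(3))

load_stem_kernel(imagenet_like_kernel, input_channels=3)  # shapes match
for channels in (1, 4):
    try:
        load_stem_kernel(imagenet_like_kernel, input_channels=channels)
    except ValueError as err:
        print(err)
```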
@DavidLandup0 Thanks for the confirmation. @tanzhenyu Closing it.
Short Description
With the current implementation, the model can be built for RGB images, but building it fails for grayscale images and for non-RGB multi-channel images.
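For the grayscale case specifically, one common workaround (an assumption here, not necessarily one of the existing implementations listed in the issue) is to repeat the single channel three times so the input matches what RGB-trained weights expect. This does not help for non-RGB multi-channel inputs such as 4-channel images, which still need a backbone built with weights=None. A minimal NumPy sketch with a hypothetical `grayscale_to_three_channels` helper:

```python
import numpy as np


def grayscale_to_three_channels(images):
    """Hypothetical helper: tile a (..., H, W, 1) batch to (..., H, W, 3)."""
    if images.shape[-1] != 1:
        raise ValueError(f"Expected 1 channel, got {images.shape[-1]}")
    return np.repeat(images, 3, axis=-1)


rng = np.random.default_rng(0)
gray_batch = rng.uniform(size=(2, 256, 256, 1)).astype(np.float32)
rgb_batch = grayscale_to_three_channels(gray_batch)
print(rgb_batch.shape)  # (2, 256, 256, 3)
# Every channel carries the same intensity values
print(np.array_equal(rgb_batch[..., 0], rgb_batch[..., 2]))  # True
```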
Existing Implementations
Integrate any of the above into the model if input.shape[-1] == 1. Or, we can use
@qlzh727 @tanzhenyu