qubvel / segmentation_models

Segmentation models with pretrained backbones. Keras and TensorFlow Keras.
MIT License

Training with non RGB images #405

eddienko opened this issue 3 years ago

eddienko commented 3 years ago

I am training the network with grayscale images and I believe I found an issue with the documentation. The docs say:

from segmentation_models import Unet
from keras.layers import Input, Conv2D
from keras.models import Model

# read/scale/preprocess data
x, y = ...

# define number of channels
N = x.shape[-1]

base_model = Unet(backbone_name='resnet34', encoder_weights='imagenet')

inp = Input(shape=(None, None, N))
l1 = Conv2D(3, (1, 1))(inp) # map N channels data to 3 channels
out = base_model(l1)

model = Model(inp, out, name=base_model.name)

# continue with usual steps: compile, fit, etc..
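The Conv2D(3, (1, 1)) layer in the snippet is a per-pixel linear map: every output pixel is a weighted sum of that same pixel's N input channels, plus a bias. Below is a minimal NumPy sketch of that arithmetic, assuming channels-last data. The weights are random placeholders rather than anything Keras would learn. A Keras 1x1 kernel has shape (1, 1, N, 3); here it is squeezed to (N, 3).

```python
import numpy as np


def conv1x1(x, weights, bias):
    """Apply a 1x1 convolution to a channels-last batch.

    x: (batch, H, W, N), weights: (N, C_out), bias: (C_out,).
    Each output pixel depends only on the N channels of the same input pixel.
    """
    return np.einsum("bhwn,nc->bhwc", x, weights) + bias


rng = np.random.default_rng(0)
x = rng.random((2, 4, 5, 1))  # two 4x5 grayscale images (N = 1)
w = rng.random((1, 3))        # placeholder weights mapping 1 -> 3 channels
b = np.zeros(3)               # placeholder bias
y = conv1x1(x, w, b)
print(y.shape)  # (2, 4, 5, 3)
```

Because the map is learned during training, the network can choose its own mix of input channels for the three pretrained RGB inputs, instead of a fixed copy of one channel.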

This does not work for me when using the TensorFlow backend. Instead I need to do:

from tensorflow.keras.layers import Input, Conv2D
from tensorflow.keras.models import Model
JordanMakesMaps commented 3 years ago

What doesn't work for you exactly?

eddienko commented 3 years ago

Sorry, my bad, it does work as written. I have an extra line in the code:

import segmentation_models as sm
sm.set_framework('tf.keras') 

In this case the code in the docs produces the following error message:

...
AttributeError                            Traceback (most recent call last)
<ipython-input-6-e0b3e3f58b6b> in <module>
      3 inp = Input(shape=(None, None, N))
      4 l1 = Conv2D(3, (1, 1))(inp) # map N channels data to 3 channels
----> 5 out = base_model(l1)
      6 
      7 model = Model(inp, out, name=base_model.name)

...

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py in <lambda>(t)
   2327             `call` method of the layer at the call that created the node.
   2328     """
-> 2329     inbound_layers = nest.map_structure(lambda t: t._keras_history.layer,
   2330                                         input_tensors)
   2331     node_indices = nest.map_structure(lambda t: t._keras_history.node_index,

AttributeError: 'tuple' object has no attribute 'layer'
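With sm.set_framework('tf.keras'), the Unet is built from tensorflow.keras, while the docs snippet still imports Input and Conv2D from standalone keras. Passing a keras tensor into a tf.keras model appears to be what triggers this AttributeError, which matches the thread: switching the imports to tensorflow.keras makes it work. Below is a minimal sketch of the consistent version, assuming TensorFlow 2.x. make_stand_in_base and wrap_for_n_channels are hypothetical helpers. The stand-in is a tiny 3-channel model used in place of sm.Unet(...) so the example runs without downloading weights.

```python
import numpy as np
# Every layer and model comes from tensorflow.keras, never from standalone keras.
from tensorflow.keras.layers import Input, Conv2D
from tensorflow.keras.models import Model


def make_stand_in_base(num_classes=1):
    """Hypothetical 3-channel-input model standing in for the Unet."""
    inp = Input(shape=(None, None, 3))
    out = Conv2D(num_classes, (3, 3), padding="same", activation="sigmoid")(inp)
    return Model(inp, out, name="stand_in_base")


def wrap_for_n_channels(base_model, n_channels):
    """Hypothetical helper: prepend a 1x1 Conv2D mapping N channels to 3."""
    inp = Input(shape=(None, None, n_channels))
    x = Conv2D(3, (1, 1))(inp)  # learned per-pixel N -> 3 channel mapping
    out = base_model(x)         # tf.keras tensor into a tf.keras model
    return Model(inp, out, name=base_model.name)


model = wrap_for_n_channels(make_stand_in_base(), n_channels=1)
pred = model.predict(np.zeros((2, 32, 32, 1), dtype="float32"), verbose=0)
print(pred.shape)  # (2, 32, 32, 1)
```

In real use, call sm.set_framework('tf.keras') before building the Unet and pass that model as base_model. The important point is that the wrapper layers come from the same framework as the base model.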
LAUBENicolas commented 3 years ago

You can also artificially create a 3-channel image from your 2D grayscale image with image_3d = np.stack([image, image, image], axis=-1), which gives the channels-last shape (H, W, 3) that Keras expects.
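The same replication can be applied to a whole training batch, as long as the channel axis ends up last to match the (None, None, 3) input of the base model. Below is a minimal NumPy sketch, assuming grayscale batches shaped (batch, H, W) or (batch, H, W, 1). gray_to_rgb_batch is a hypothetical helper name.

```python
import numpy as np


def gray_to_rgb_batch(images):
    """Hypothetical helper: copy one grayscale channel into 3 identical channels."""
    images = np.asarray(images)
    if images.ndim == 4 and images.shape[-1] == 1:
        images = images[..., 0]  # drop the singleton channel axis
    if images.ndim != 3:
        raise ValueError(f"expected (batch, H, W) or (batch, H, W, 1), got {images.shape}")
    # Add a trailing channel axis, then repeat it 3 times (channels-last).
    return np.repeat(images[..., np.newaxis], 3, axis=-1)


batch = np.arange(2 * 4 * 5, dtype=np.float32).reshape(2, 4, 5)
rgb = gray_to_rgb_batch(batch)
print(rgb.shape)  # (2, 4, 5, 3)
```

Unlike the 1x1 Conv2D adapter, this adds no trainable parameters. Each pretrained RGB filter simply sees the same intensity on all three channels.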