qubvel / segmentation_models

Segmentation models with pretrained backbones. Keras and TensorFlow Keras.
MIT License

BatchNorm error when calling `model = sm.Unet(BACKBONE, encoder_weights='imagenet')` #478

Open xavierdcruz0 opened 3 years ago

xavierdcruz0 commented 3 years ago

I'm getting an error from a BatchNormalization layer as soon as I try to instantiate a Unet model:

import segmentation_models as sm
# sm.set_framework('tf.keras')
# sm.framework()
# import keras
from tensorflow import keras

print('Setting image data format...')
# keras.backend.set_image_data_format('channels_last')
keras.backend.set_image_data_format('channels_first')

print('Retrieving preprocessing function...')
BACKBONE = 'resnet34'
preprocess_input = sm.get_preprocessing(BACKBONE)

print('Defining model...')
# define model
model = sm.Unet(BACKBONE, encoder_weights='imagenet')

Results in:

trainval_keras.py
2021-06-17 19:29:15.101726: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2021-06-17 19:29:15.101749: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
Segmentation Models: using `tf.keras` framework.
Setting image data format...
Retrieving preprocessing function...
Defining model...
Traceback (most recent call last):
  File "/home/sal9000/PycharmProjects/roi_segmentation/roi_segmentation/trainval_keras.py", line 19, in <module>
    model = sm.Unet(BACKBONE, encoder_weights='imagenet')
  File "/home/sal9000/virtualenvs/roisegenv/lib/python3.6/site-packages/segmentation_models/__init__.py", line 34, in wrapper
    return func(*args, **kwargs)
  File "/home/sal9000/virtualenvs/roisegenv/lib/python3.6/site-packages/segmentation_models/models/unet.py", line 226, in Unet
    **kwargs,
  File "/home/sal9000/virtualenvs/roisegenv/lib/python3.6/site-packages/segmentation_models/backbones/backbones_factory.py", line 103, in get_backbone
    model = model_fn(*args, **kwargs)
  File "/home/sal9000/virtualenvs/roisegenv/lib/python3.6/site-packages/classification_models/models_factory.py", line 78, in wrapper
    return func(*args, **new_kwargs)
  File "/home/sal9000/virtualenvs/roisegenv/lib/python3.6/site-packages/classification_models/models/resnet.py", line 321, in ResNet34
    **kwargs
  File "/home/sal9000/virtualenvs/roisegenv/lib/python3.6/site-packages/classification_models/models/resnet.py", line 231, in ResNet
    x = layers.BatchNormalization(name='bn_data', **no_scale_bn_params)(img_input)
  File "/home/sal9000/virtualenvs/roisegenv/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py", line 970, in __call__
    input_list)
  File "/home/sal9000/virtualenvs/roisegenv/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1108, in _functional_construction_call
    inputs, input_masks, args, kwargs)
  File "/home/sal9000/virtualenvs/roisegenv/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py", line 840, in _keras_tensor_symbolic_call
    return self._infer_output_signature(inputs, args, kwargs, input_masks)
  File "/home/sal9000/virtualenvs/roisegenv/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py", line 878, in _infer_output_signature
    self._maybe_build(inputs)
  File "/home/sal9000/virtualenvs/roisegenv/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py", line 2625, in _maybe_build
    self.build(input_shapes)  # pylint:disable=not-callable
  File "/home/sal9000/virtualenvs/roisegenv/lib/python3.6/site-packages/tensorflow/python/keras/layers/normalization.py", line 387, in build
    (tuple(input_shape), self.axis))
ValueError: Input has undefined `axis` dimension. Received input with shape (None, None, None, 3). Axis value: ListWrapper([1])
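The last line of the traceback appears to describe a layout mismatch. The input tensor has shape (None, None, None, 3), which is channels-last, but BatchNormalization was given axis 1, the channel axis under the `channels_first` format set by `keras.backend.set_image_data_format('channels_first')`. The pure-Python sketch below reproduces that check without TensorFlow. It assumes that the backbone picks its BatchNormalization axis from the image data format (3 for `channels_last`, 1 for `channels_first`) and that `sm.Unet` is using its default `input_shape=(None, None, 3)`. The helper names `bn_axis_for` and `check_bn_input` are hypothetical and are not part of Keras, segmentation_models, or classification_models.

```python
# Hypothetical helpers (not part of Keras, segmentation_models, or
# classification_models) that mimic the shape check BatchNormalization.build
# performs in the traceback above.

def bn_axis_for(data_format):
    """Return the channel axis assumed for a 4-D image batch.

    Assumption: mirrors the backbone choosing axis 3 for 'channels_last'
    and axis 1 for 'channels_first'.
    """
    if data_format == "channels_last":
        return 3
    if data_format == "channels_first":
        return 1
    raise ValueError(f"unknown data format: {data_format!r}")


def check_bn_input(batch_shape, data_format):
    """Return the channel count, or raise if the normalized axis is None."""
    axis = bn_axis_for(data_format)
    if batch_shape[axis] is None:
        raise ValueError(f"axis {axis} is undefined for input shape {batch_shape}")
    return batch_shape[axis]


# Unet's assumed default input_shape (None, None, 3), plus the batch dimension.
default_batch_shape = (None,) + (None, None, 3)

print(check_bn_input(default_batch_shape, "channels_last"))  # 3

try:
    # Same input shape, but with the channels_first axis: axis 1 is None.
    check_bn_input(default_batch_shape, "channels_first")
except ValueError as err:
    print("channels_first:", err)

# A channels-first input shape defines axis 1, so the check passes.
print(check_bn_input((None, 3, None, None), "channels_first"))  # 3
```

If this reading is right, keeping the commented-out `keras.backend.set_image_data_format('channels_last')` line instead should let `sm.Unet(BACKBONE, encoder_weights='imagenet')` build. Staying with `channels_first` would at least require an `input_shape` whose axis 1 is defined, such as `(3, None, None)`; whether the rest of the Unet and its pretrained encoder work in that layout is not verified here.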

Packages I have installed in this virtualenv:

$ pip freeze
absl-py==0.13.0
astor==0.8.1
astunparse==1.6.3
cached-property==1.5.2
cachetools==4.2.2
certifi==2021.5.30
chardet==4.0.0
cycler==0.10.0
dataclasses==0.8
decorator==4.4.2
efficientnet==1.0.0
flatbuffers==1.12
gast==0.4.0
google-auth==1.31.0
google-auth-oauthlib==0.4.4
google-pasta==0.2.0
grpcio==1.34.1
h5py==3.1.0
idna==2.10
image-classifiers==1.0.0
imageio==2.9.0
importlib-metadata==4.5.0
install==1.3.4
joblib==1.0.1
Keras==2.4.3
Keras-Applications==1.0.8
keras-nightly==2.5.0.dev2021032900
Keras-Preprocessing==1.1.2
kiwisolver==1.3.1
Markdown==3.3.4
matplotlib==3.3.4
networkx==2.5.1
numpy==1.19.5
oauthlib==3.1.1
opencv-python==4.5.2.52
opt-einsum==3.3.0
pandas==1.1.5
Pillow==8.2.0
pkg-resources==0.0.0
protobuf==3.17.3
pyasn1==0.4.8
pyasn1-modules==0.2.8
pyparsing==2.4.7
python-dateutil==2.8.1
pytz==2021.1
PyWavelets==1.1.1
PyYAML==5.4.1
requests==2.25.1
requests-oauthlib==1.3.0
rsa==4.7.2
scikit-image==0.17.2
scikit-learn==0.24.2
scipy==1.4.1
segmentation-models==1.0.1
six==1.15.0
synthdocs==0.1.dev0
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.5.0
tensorflow-estimator==2.5.0
termcolor==1.1.0
threadpoolctl==2.1.0
tifffile==2020.9.3
torch==1.8.1
torchvision==0.9.1
tqdm==4.61.1
typing-extensions==3.7.4.3
urllib3==1.26.5
Werkzeug==2.0.1
wrapt==1.12.1
zipp==3.4.1

Is this caused by incompatible versions of TF, Keras, or segmentation-models in my environment?

aaprasad commented 2 years ago

Did you ever figure out a solution to this problem?