hellochick / PSPNet-tensorflow

TensorFlow-based implementation of "Pyramid Scene Parsing Network".

Issue with the tutorial (https://modeldepot.io/hellochick/pspnet) - shapeless image #37

Open PascalWB opened 6 years ago

PascalWB commented 6 years ago

Thank you for all the input. However, I'm having trouble getting the tutorial (https://modeldepot.io/hellochick/pspnet) running on my machine. Unfortunately, I'm very new to Python.

Minimal example:

    import tensorflow as tf
    import numpy as np
    from scipy import misc
    import matplotlib.pyplot as plt
    from PIL import Image

    from model import PSPNet101, PSPNet50
    from tools import *

    image_path = './input/test1.png'

    img_np, filename = load_img(image_path)
    plt.imshow(img_np)


This leads to an error message:

    img_np, filename = load_img(image_path)
    plt.imshow(img_np)
    successful load img: ./input/test1.png
    Traceback (most recent call last):

      File "", line 2, in
        plt.imshow(img_np)

      File "C:\Anaconda3\lib\site-packages\matplotlib\pyplot.py", line 3080, in imshow
        **kwargs)

      File "C:\Anaconda3\lib\site-packages\matplotlib\__init__.py", line 1710, in inner
        return func(ax, *args, **kwargs)

      File "C:\Anaconda3\lib\site-packages\matplotlib\axes\_axes.py", line 5194, in imshow
        im.set_data(X)

      File "C:\Anaconda3\lib\site-packages\matplotlib\image.py", line 600, in set_data
        raise TypeError("Image data cannot be converted to float")

    TypeError: Image data cannot be converted to float


The problem seems to be that img_np has no fully defined shape:

    img_np
    Out[3]: <tf.Tensor 'DecodePng:0' shape=(?, ?, 3) dtype=uint8>


This issue has been discussed at length (https://github.com/tensorflow/tensorflow/issues/9356), but I could not figure out the solution. Does anybody have a hint?

I'd like to read in images of different sizes.

All the best, Pascal
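For reference, the shape=(?, ?, 3) in the output above is not an error by itself: it only means that height and width are unknown while the graph is being built, and that is exactly what lets a single graph decode images of different sizes. The following is a minimal sketch, assuming TensorFlow 1.14+ or 2.x (graph mode through tf.compat.v1) and Pillow. It does not use the tutorial's load_img; png_bytes is a hypothetical helper that builds test images in memory instead of reading files from ./input/.

```python
import io

import tensorflow as tf
from PIL import Image

# The tutorial is TF1-style (graph mode + sessions); on TF 2.x this restores
# graph mode so tensors behave as in the issue.
tf.compat.v1.disable_eager_execution()


def png_bytes(height, width):
    """Hypothetical helper: a black RGB image of the given size, PNG-encoded."""
    buf = io.BytesIO()
    Image.new("RGB", (width, height)).save(buf, format="PNG")  # PIL size is (w, h)
    return buf.getvalue()


# One graph for every input: the PNG contents are fed at run time, so the
# static shape keeps height and width unknown.
png_input = tf.compat.v1.placeholder(tf.string, shape=[])
image = tf.image.decode_png(png_input, channels=3)
runtime_shape = tf.shape(image)  # dynamic shape, only known when the graph runs

shapes = []
with tf.compat.v1.Session() as sess:
    for h, w in [(32, 48), (100, 60)]:
        arr, shp = sess.run(
            [image, runtime_shape], feed_dict={png_input: png_bytes(h, w)}
        )
        # Each evaluated image is a NumPy array with a concrete shape.
        shapes.append((arr.shape, tuple(int(x) for x in shp)))

print(image.shape)  # (?, ?, 3) on TF 1.x, (None, None, 3) on TF 2.x
print(shapes)
```

The unknown dimensions only exist at graph-construction time; every sess.run call returns an array whose height and width match the image that was fed in.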

PascalWB commented 6 years ago

For anyone facing issues with the tutorial:

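A couple of lines of code are missing from the tutorial. You have to run the session before plotting the images: img_np and pred are symbolic TensorFlow tensors, and only sess.run turns them into NumPy arrays that matplotlib can display.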

Insert the following just before plotting:

    # Run and get result image
    preds = sess.run(pred)
    data = sess.run(img_np)

Then pass data (instead of img_np) to plt.imshow.
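The fix can be sketched end to end as follows, assuming TensorFlow 1.14+ or 2.x (graph mode through tf.compat.v1), Pillow and matplotlib. The tutorial's load_img is replaced here by a plain tf.image.decode_png on an in-memory PNG built by make_png_bytes, a hypothetical helper, and pred is left out because it needs the PSPNet model. The only point shown is that plt.imshow receives the NumPy array returned by sess.run, not the tf.Tensor.

```python
import io

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image

tf.compat.v1.disable_eager_execution()  # TF1-style graph mode, as in the tutorial


def make_png_bytes(height, width):
    """Hypothetical stand-in for ./input/test1.png: random RGB pixels as PNG bytes."""
    rng = np.random.default_rng(0)
    pixels = rng.integers(0, 256, size=(height, width, 3), dtype=np.uint8)
    buf = io.BytesIO()
    Image.fromarray(pixels).save(buf, format="PNG")
    return buf.getvalue(), pixels


png, original = make_png_bytes(48, 64)

# Symbolic tensor, like img_np in the issue; passing it straight to
# plt.imshow is what fails.
img_np = tf.image.decode_png(tf.constant(png), channels=3)

with tf.compat.v1.Session() as sess:
    # Evaluate the tensor: the result is a uint8 NumPy array of shape (H, W, 3).
    data = sess.run(img_np)

plt.imshow(data)  # works with the evaluated array; call plt.show() interactively
```

PNG is lossless, so data should match the original pixels exactly.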