jakeret / tf_unet

Generic U-Net Tensorflow implementation for image segmentation
GNU General Public License v3.0

SimpleDataProvider in image_util #241

Open r-shruthi11 opened 5 years ago

r-shruthi11 commented 5 years ago

I'm using the SimpleDataProvider class for my training data and labels, which are numpy arrays. I initialize a network that uses the SimpleDataProvider, but during training I see an error about the shape of the labels array. My training data has shape [n_samples, ny, nx, channels] and my labels have shape [n_samples, ny, nx, n_class], where n_class = 2 and I use one-hot encoding to label foreground and background pixels.


jakeret commented 5 years ago

I just realized that the documentation is not sufficiently accurate. If n_class == 2, the implementation expects [n_samples, ny, nx, 1] and transforms it into a one-hot encoded tensor.

r-shruthi11 commented 5 years ago

I see, but the SimpleDataProvider class takes no argument for n_class. I believe it infers it from the last dimension of the label array (label.shape[-1]):

class SimpleDataProvider(BaseDataProvider):
    """
    A simple data provider for numpy arrays. 
    Assumes that the data and label are numpy array with the dimensions
    data `[n, X, Y, channels]`, label `[n, X, Y, classes]`. Where
    `n` is the number of images, `X`, `Y` the size of the image.
    :param data: data numpy array. Shape=[n, X, Y, channels]
    :param label: label numpy array. Shape=[n, X, Y, classes]
    :param a_min: (optional) min value used for clipping
    :param a_max: (optional) max value used for clipping
    """

    def __init__(self, data, label, a_min=None, a_max=None):
        super(SimpleDataProvider, self).__init__(a_min, a_max)
        self.data = data
        self.label = label
        self.file_count = data.shape[0]
        self.n_class = label.shape[-1]
        self.channels = data.shape[-1]

    def _next_data(self):
        idx = np.random.choice(self.file_count)
        return self.data[idx], self.label[idx]

jakeret commented 5 years ago

Yes, this is correct. self.n_class then determines how the labels are processed:

    def _process_labels(self, label):
        if self.n_class == 2:
            nx = label.shape[1]
            ny = label.shape[0]
            labels = np.zeros((ny, nx, self.n_class), dtype=np.float32)
            labels[..., 1] = label
            labels[..., 0] = ~label
            return labels
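
For reference, here is a standalone NumPy sketch of the same expansion applied to a single 2-D mask. The function name `to_one_hot_binary` is hypothetical and not part of tf_unet. It assumes the mask can be treated as boolean, because `~` on an integer array is a bitwise NOT (so `~1` is `-2`) and on a float array it raises a TypeError.

```python
import numpy as np

def to_one_hot_binary(mask):
    """Expand a (ny, nx) foreground mask into (ny, nx, 2) one-hot floats.

    Hypothetical helper for illustration; it mirrors the n_class == 2 branch
    quoted above but is not the library's own code.
    """
    mask = np.asarray(mask).astype(bool)  # make sure ~ is a logical negation
    ny, nx = mask.shape
    labels = np.zeros((ny, nx, 2), dtype=np.float32)
    labels[..., 1] = mask    # channel 1: foreground
    labels[..., 0] = ~mask   # channel 0: background
    return labels

mask = np.array([[True, False], [False, True]])
one_hot = to_one_hot_binary(mask)
print(one_hot.shape)  # (2, 2, 2)
```

Each pixel ends up with exactly one channel set to 1, which is why the input has to be a single-channel mask and not an array that is already one-hot encoded.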

r-shruthi11 commented 5 years ago

But in your earlier comment you say the implementation expects [n_samples, ny, nx, 1]. Since n_class = 2 can't be specified explicitly, I would have to shape my label array as [n_samples, ny, nx, 2] for it to be inferred, and that is incompatible with what the implementation expects.
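
For anyone whose labels are already one-hot encoded, a minimal NumPy sketch of collapsing them back to a single-channel boolean mask might look like the following. The helper `one_hot_to_mask` and its `keep_channel_axis` flag are hypothetical, and the sketch assumes channel 1 is the foreground, as in the `_process_labels` snippet. This thread does not settle whether the provider needs the trailing singleton axis kept or dropped, so both are supported.

```python
import numpy as np

def one_hot_to_mask(one_hot_labels, keep_channel_axis=True):
    """Collapse [n, ny, nx, 2] one-hot labels to a boolean foreground mask.

    Hypothetical helper; assumes channel 1 marks the foreground.
    """
    foreground = one_hot_labels[..., 1] > 0.5      # shape (n, ny, nx)
    if keep_channel_axis:
        foreground = foreground[..., np.newaxis]   # shape (n, ny, nx, 1)
    return foreground

# Toy labels: all background except one foreground pixel per sample.
one_hot_labels = np.zeros((3, 4, 5, 2), dtype=np.float32)
one_hot_labels[..., 0] = 1.0
one_hot_labels[:, 1, 2, :] = [0.0, 1.0]

mask = one_hot_to_mask(one_hot_labels)
print(mask.shape, mask.dtype)  # (3, 4, 5, 1) bool
```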

jakeret commented 5 years ago

Yes, you're right. That's a bug... In the meantime you could do something like this:

data_provider = SimpleDataProvider(....)
data_provider.n_class = 2

r-shruthi11 commented 5 years ago

Will do, thanks!