r-shruthi11 opened this issue 5 years ago.
I just realized that the documentation is not sufficiently accurate.
If `n_class == 2`, the implementation expects labels of shape `[n_samples, ny, nx, 1]` and transforms them into a one-hot encoded tensor.
I see, but the `SimpleDataProvider` class does not take an `n_class` argument. I believe it infers the class count from the last dimension of the label array (`label.shape[-1]`):
```python
import numpy as np

# Excerpt; BaseDataProvider (not shown) is tf_unet's base provider class.
class SimpleDataProvider(BaseDataProvider):
    """
    A simple data provider for numpy arrays.
    Assumes that the data and label are numpy arrays with the dimensions
    data `[n, X, Y, channels]`, label `[n, X, Y, classes]`, where
    `n` is the number of images and `X`, `Y` the size of the image.

    :param data: data numpy array. Shape=[n, X, Y, channels]
    :param label: label numpy array. Shape=[n, X, Y, classes]
    :param a_min: (optional) min value used for clipping
    :param a_max: (optional) max value used for clipping
    """

    def __init__(self, data, label, a_min=None, a_max=None):
        super(SimpleDataProvider, self).__init__(a_min, a_max)
        self.data = data
        self.label = label
        self.file_count = data.shape[0]
        self.n_class = label.shape[-1]  # class count inferred from the label's last axis
        self.channels = data.shape[-1]

    def _next_data(self):
        # Pick one random sample per call.
        idx = np.random.choice(self.file_count)
        return self.data[idx], self.label[idx]
```
Yes, this is correct. `self.n_class` then determines how the labels are processed:
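```python
import numpy as np

# Excerpt from the provider; only the n_class == 2 case is shown.
def _process_labels(self, label):
    if self.n_class == 2:
        nx = label.shape[1]
        ny = label.shape[0]
        labels = np.zeros((ny, nx, self.n_class), dtype=np.float32)
        labels[..., 1] = label   # foreground channel
        labels[..., 0] = ~label  # background channel (logical NOT for boolean labels)
        return labels
```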
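To see what this branch does in isolation, here is a stand-alone copy of it as a plain function (the name `process_binary_label` is hypothetical, not tf_unet's). It is fed one 2-D `[ny, nx]` label, because each channel slice `labels[..., k]` has shape `[ny, nx]`. It also shows a pitfall worth guarding against: `~label` is a logical NOT only for boolean arrays; on an integer 0/1 mask it is a bitwise NOT and produces -1 and -2 instead of 0 and 1.

```python
import numpy as np


def process_binary_label(label: np.ndarray, n_class: int = 2) -> np.ndarray:
    """Stand-alone copy of the n_class == 2 branch for a single [ny, nx] label."""
    nx = label.shape[1]
    ny = label.shape[0]
    labels = np.zeros((ny, nx, n_class), dtype=np.float32)
    labels[..., 1] = label   # foreground channel
    labels[..., 0] = ~label  # background channel: logical NOT only if label is bool
    return labels


bool_label = np.array([[True, False], [False, True]])
good = process_binary_label(bool_label)  # channel 0 is the complement of channel 1

int_label = bool_label.astype(np.int64)
bad = process_binary_label(int_label)    # ~1 == -2 and ~0 == -1 for integers
```

Casting the mask with `label.astype(bool)` before handing it to the provider avoids the integer case.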
But in your earlier comment you mentioned that the implementation expects `[n_samples, ny, nx, 1]`. If I want `n_class = 2` and can't specify it explicitly, I have to give my label array the shape `[n_samples, ny, nx, 2]`, which is incompatible with what the implementation expects.
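For anyone who, like the reporter, already has one-hot labels, a hedged sketch of going back the other way: collapse `[n_samples, ny, nx, 2]` to a boolean foreground mask with `np.argmax` over the class axis. Whether the provider then wants `[n_samples, ny, nx]` or `[n_samples, ny, nx, 1]` is exactly the ambiguity discussed in this thread, so both layouts are produced; the helper name `one_hot_to_mask` is hypothetical.

```python
import numpy as np


def one_hot_to_mask(one_hot: np.ndarray) -> np.ndarray:
    # argmax over the class axis: True where channel 1 (foreground) wins, else False
    return np.argmax(one_hot, axis=-1).astype(bool)


one_hot = np.zeros((3, 4, 4, 2), dtype=np.float32)
one_hot[..., 0] = 1.0            # every pixel background ...
one_hot[:, 0, 0] = [0.0, 1.0]    # ... except pixel (0, 0) in each sample

mask = one_hot_to_mask(one_hot)               # shape [n_samples, ny, nx]
mask_with_channel = mask[..., np.newaxis]     # shape [n_samples, ny, nx, 1]
```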
Yes, you're right. That's a bug. In the meantime, you could do something like this:
```python
data_provider = SimpleDataProvider(....)
data_provider.n_class = 2
```
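The workaround relies on `n_class` being a plain instance attribute set in `__init__`, so it can be overwritten after construction. Since tf_unet may not be installed, the sketch below uses a hypothetical `StandInProvider` that copies only the constructor logic quoted above (it is not the real class):

```python
import numpy as np


class StandInProvider:
    """Hypothetical stand-in: reproduces only the attribute setup of SimpleDataProvider.__init__."""

    def __init__(self, data: np.ndarray, label: np.ndarray):
        self.data = data
        self.label = label
        self.file_count = data.shape[0]
        self.n_class = label.shape[-1]  # class count inferred from the label's last axis
        self.channels = data.shape[-1]


data = np.random.rand(4, 8, 8, 1).astype(np.float32)
label = np.random.rand(4, 8, 8, 1) > 0.5  # boolean foreground mask, one channel

provider = StandInProvider(data, label)
inferred = provider.n_class  # 1, so an `n_class == 2` branch would be skipped
provider.n_class = 2         # the suggested workaround: override after construction
```

Because `_process_labels` reads `self.n_class` each time it runs, the override should take effect as long as it happens before training starts.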
Will do, thanks!
I'm using the `SimpleDataProvider` class for my training data and labels, which are NumPy arrays. I initialize a network that uses the `SimpleDataProvider`, but during training I get an error about the shape of the labels array. My training data has shape `[n_samples, ny, nx, channels]` and my labels have shape `[n_samples, ny, nx, n_class]` with `n_class = 2`; I use one-hot encoding to label foreground and background pixels.
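For a setup like this one, a quick pre-flight check can confirm which layout the arrays actually have before they reach the provider. The helper `check_training_arrays` below is hypothetical (not part of tf_unet); it assumes 4-D `[n_samples, ny, nx, channels]` data and `[n_samples, ny, nx, n_class]` labels and returns the class count that `label.shape[-1]` implies:

```python
import numpy as np


def check_training_arrays(data: np.ndarray, label: np.ndarray) -> int:
    """Hypothetical pre-flight check; returns the class count implied by label.shape[-1]."""
    if data.ndim != 4 or label.ndim != 4:
        raise ValueError(f"expected 4-D arrays, got data {data.shape} and label {label.shape}")
    if data.shape[:3] != label.shape[:3]:
        raise ValueError(f"[n_samples, ny, nx] mismatch: {data.shape[:3]} vs {label.shape[:3]}")
    n_class = label.shape[-1]
    # With more than one channel, every pixel should carry exactly one active class.
    if n_class > 1 and not np.allclose(label.sum(axis=-1), 1.0):
        raise ValueError("label is not one-hot: class channels do not sum to 1 per pixel")
    return n_class


data = np.random.rand(2, 4, 4, 3).astype(np.float32)
fg = np.random.rand(2, 4, 4) > 0.5                       # boolean foreground mask
label = np.stack([~fg, fg], axis=-1).astype(np.float32)  # one-hot: [background, foreground]
n_class = check_training_arrays(data, label)
```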