deploy-soon opened this issue 4 years ago.
Hi @deploy-soon, can you please share a minimal example that reproduces the error?
Here is a short snippet for the segmentation task. While generating the labels, I add some noise and randomly crop the images and labels, so I use a map_label wrapper on the dataset.
import numpy as np
import tensorflow as tf
from focal_loss import SparseCategoricalFocalLoss

def map_label(x):
    def wrapper(param):
        # add some noise with np
        field = np.zeros((640, 480), dtype=np.float32)
        return field
    return tf.py_function(func=wrapper, inp=[x], Tout=tf.float32)

ipt = tf.zeros([100, 640, 480, 3], dtype=tf.dtypes.float32)
images = tf.data.Dataset.from_tensor_slices(ipt)
labels = tf.data.Dataset.range(100).map(map_label)
dataset = tf.data.Dataset.zip((images, labels)).batch(2)

images = tf.keras.Input(shape=(640, 480, 3), name="ipt")
xs = tf.keras.layers.Conv2D(20, (3, 3), padding="same")(images)
labels = tf.keras.layers.Activation("softmax", name="opt")(xs)
model = tf.keras.Model(inputs=images, outputs=labels)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
              loss=SparseCategoricalFocalLoss(gamma=1.0),
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
model.fit(dataset, epochs=1)
When you run this code with TensorFlow 2.2.0, you may see a NotImplementedError.
@artemmavrin
Sorry for the delay. I'm able to reproduce your error.
It looks like TensorFlow can't infer the rank of the values in the labels dataset. SparseCategoricalFocalLoss needs the rank of the ground-truth tensor to be statically known for its reshaping logic:
https://github.com/artemmavrin/focal-loss/blob/9e023dee0a0a2b0cfde12104906917ab26e4b056/src/focal_loss/_categorical_focal_loss.py#L137-L147
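One way to see the missing rank, sketched here assuming TensorFlow 2.x, is to inspect the dataset's element_spec. A value produced by tf.py_function carries a dtype but no static shape, so its rank comes back as None, which is exactly the case the loss cannot handle.

```python
import numpy as np
import tensorflow as tf


def map_label(x):
    # Same pattern as the repro: NumPy code wrapped in tf.py_function.
    def wrapper(param):
        return np.zeros((640, 480), dtype=np.float32)
    return tf.py_function(func=wrapper, inp=[x], Tout=tf.float32)


labels = tf.data.Dataset.range(100).map(map_label)
# tf.py_function cannot tell TensorFlow the shape of its result ahead of time,
# so the dataset knows the dtype but not the shape or even the rank.
print(labels.element_spec.dtype)       # <dtype: 'float32'>
print(labels.element_spec.shape.rank)  # None
```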
A workaround that seems to work is to manually force the label shape to be known:
import numpy as np
import tensorflow as tf
from focal_loss import SparseCategoricalFocalLoss

def map_label(x):
    def wrapper(param):
        # add some noise with np
        field = np.zeros((640, 480), dtype=np.float32)
        return field
    return tf.py_function(func=wrapper, inp=[x], Tout=tf.float32)

ipt = tf.zeros([100, 640, 480, 3], dtype=tf.dtypes.float32)
images = tf.data.Dataset.from_tensor_slices(ipt)
labels = tf.data.Dataset.range(100).map(map_label)
labels = labels.map(lambda label: tf.reshape(label, [640, 480]))  # New line: gives the labels a static shape and rank
dataset = tf.data.Dataset.zip((images, labels)).batch(2)

images = tf.keras.Input(shape=(640, 480, 3), name="ipt")
xs = tf.keras.layers.Conv2D(20, (3, 3), padding="same")(images)
labels = tf.keras.layers.Activation("softmax", name="opt")(xs)
model = tf.keras.Model(inputs=images, outputs=labels)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
              loss=SparseCategoricalFocalLoss(gamma=1.0),
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
model.fit(dataset, epochs=1)
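A variant of the same workaround, sketched under the same assumptions, uses tf.ensure_shape inside the map function instead of tf.reshape. It records the static shape [640, 480] in the same way, but raises an error at runtime if the NumPy code ever returns a different shape, whereas tf.reshape would silently reinterpret, say, a 480 x 640 array with the same number of elements.

```python
import numpy as np
import tensorflow as tf


def map_label(x):
    def wrapper(param):
        return np.zeros((640, 480), dtype=np.float32)
    label = tf.py_function(func=wrapper, inp=[x], Tout=tf.float32)
    # ensure_shape sets the static shape and adds a runtime shape check,
    # instead of reshaping data that might have a different layout.
    return tf.ensure_shape(label, [640, 480])


labels = tf.data.Dataset.range(100).map(map_label)
print(labels.batch(2).element_spec.shape.as_list())  # [None, 640, 480]
```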
Finally, I found a way to solve this error. Since the output shape of tf.py_function is not known in advance, it has to be set explicitly. In the sparse categorical loss, y_true can be given a static shape as below, before the rank of the true tensor is checked:
y_true.set_shape(y_pred.get_shape()[:3])
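Applying that set_shape call without editing the installed package could look roughly like the sketch below. with_static_label_shape is a hypothetical helper, not part of focal-loss or Keras: it wraps any (y_true, y_pred) loss callable and gives y_true the static shape of y_pred without its last (class) axis, which for (batch, height, width, classes) predictions is the same as y_pred.get_shape()[:3].

```python
import tensorflow as tf


def with_static_label_shape(loss_fn):
    """Hypothetical wrapper: set y_true's static shape from y_pred, then call loss_fn."""
    def wrapped(y_true, y_pred):
        # set_shape only refines static shape information; it does not move data.
        # For 4-D segmentation predictions, y_pred.shape[:-1] == y_pred.shape[:3].
        y_true.set_shape(y_pred.shape[:-1])
        return loss_fn(y_true, y_pred)
    return wrapped
```

With the repro above, the loss would then be passed as loss=with_static_label_shape(SparseCategoricalFocalLoss(gamma=1.0)) in model.compile. This has not been checked against focal-loss itself; fixing the label shape in the dataset avoids touching the loss at all.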
@artemmavrin
Hi, I'm trying to use categorical_focal_loss in my image segmentation task. The dataset is defined as a tf.data.Dataset object and the model is defined as a Keras Model. The model is compiled like …
While training on the segmentation task, an assertion exception is raised because the shape of the y_true tensor is unknown: https://github.com/artemmavrin/focal-loss/blob/master/src/focal_loss/_categorical_focal_loss.py#L136-L141
How do I define the true tensor? In my task, the true tensor has shape (BATCH, HEIGHT, WIDTH). My environment is Ubuntu 18.04 with TensorFlow 2.2.0.
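If the noise and cropping can be written with TensorFlow ops instead of NumPy, tf.py_function is not needed and the (BATCH, HEIGHT, WIDTH) label shape stays static on its own. This is a swapped-in approach rather than the one used in the thread, and the augment function, Gaussian noise, and sizes below are all hypothetical. Stacking the label as an extra channel lets tf.image.random_crop pick one crop window for image and label together, and because the crop size is a constant, the output shape is known statically.

```python
import tensorflow as tf

HEIGHT, WIDTH, CROP = 64, 48, 32  # hypothetical sizes for this sketch


def augment(image, label):
    """Hypothetical augmentation: image noise plus a joint random crop, in TF ops only."""
    image = image + tf.random.normal(tf.shape(image), stddev=0.1)
    # Append the label as a 4th channel so both get the same crop window.
    stacked = tf.concat([image, label[..., tf.newaxis]], axis=-1)  # (H, W, 4)
    cropped = tf.image.random_crop(stacked, size=[CROP, CROP, 4])
    return cropped[..., :3], cropped[..., 3]


images = tf.data.Dataset.from_tensor_slices(tf.zeros([10, HEIGHT, WIDTH, 3]))
labels = tf.data.Dataset.from_tensor_slices(tf.zeros([10, HEIGHT, WIDTH]))
dataset = tf.data.Dataset.zip((images, labels)).map(augment).batch(2)
print(dataset.element_spec[1].shape.as_list())  # [None, 32, 32]
```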