artemmavrin / focal-loss

TensorFlow implementation of focal loss
https://focal-loss.readthedocs.io
Apache License 2.0

Unknown y_true tensor #9

Open deploy-soon opened 4 years ago

deploy-soon commented 4 years ago

Hi, I'm trying to use categorical_focal_loss in my image segmentation task. The dataset is a tf.data.Dataset object and the model is a Keras Model. The model is compiled like this:

loss_gamma = [0.5, 1., ...]
model.compile(
    optimizer=tf.keras.optimizers.Adam(lr=lr),
    loss=SparseCategoricalFocalLoss(gamma=loss_gamma),
...)
model.fit(...)

While training the segmentation task, an assertion exception is raised because the shape of the y_true tensor is unknown. https://github.com/artemmavrin/focal-loss/blob/master/src/focal_loss/_categorical_focal_loss.py#L136-L141 How do I define the true tensor? In my task, the true tensor has shape (BATCH, HEIGHT, WIDTH). My virtual environment runs Ubuntu 18.04 with TensorFlow 2.2.0.

artemmavrin commented 4 years ago

Hi @deploy-soon can you please share a minimal example that replicates the error?

deploy-soon commented 4 years ago

Here is a short example of the segmentation task. When generating the labels, I add some noise and randomly crop the images and labels, so I use a map_label wrapper for the dataset.

import numpy as np
import tensorflow as tf
from focal_loss import SparseCategoricalFocalLoss

def map_label(x):
    def wrapper(param):
        # add some noise with np
        field = np.zeros((640, 480))
        return field
    return tf.py_function(func=wrapper, inp=[x], Tout=tf.float32)

ipt = tf.zeros([100, 640, 480, 3], dtype=tf.dtypes.float32)
images = tf.data.Dataset.from_tensor_slices(ipt)
labels = tf.data.Dataset.range(100).map(map_label)
dataset = tf.data.Dataset.zip((images, labels)).batch(2)

images = tf.keras.Input(shape=(640, 480, 3), name="ipt")
xs = tf.keras.layers.Conv2D(20, (3, 3), padding="same")(images)
labels = tf.keras.layers.Activation("softmax", name="opt")(xs)

model = tf.keras.Model(inputs=images, outputs=labels)
model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.01),
              loss=SparseCategoricalFocalLoss(gamma=1.0),
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
model.fit(dataset, epochs=1)

When you run this code with TensorFlow 2.2.0, you may see a NotImplementedError. @artemmavrin

artemmavrin commented 4 years ago

Sorry for the delay. I'm able to replicate your error.

It looks like TensorFlow can't infer the rank of the values in the labels dataset. SparseCategoricalFocalLoss needs the ground truth tensor rank to be statically known for its reshaping logic: https://github.com/artemmavrin/focal-loss/blob/9e023dee0a0a2b0cfde12104906917ab26e4b056/src/focal_loss/_categorical_focal_loss.py#L137-L147

A workaround that seems to work is to manually force the label shape to be known:

import numpy as np
import tensorflow as tf
from focal_loss import SparseCategoricalFocalLoss

def map_label(x):
    def wrapper(param):
        # add some noise with np
        field = np.zeros((640, 480))
        return field
    return tf.py_function(func=wrapper, inp=[x], Tout=tf.float32)

ipt = tf.zeros([100, 640, 480, 3], dtype=tf.dtypes.float32)
images = tf.data.Dataset.from_tensor_slices(ipt)
labels = tf.data.Dataset.range(100).map(map_label)
labels = labels.map(lambda label: tf.reshape(label, [640, 480]))  # New line
dataset = tf.data.Dataset.zip((images, labels)).batch(2)

images = tf.keras.Input(shape=(640, 480, 3), name="ipt")
xs = tf.keras.layers.Conv2D(20, (3, 3), padding="same")(images)
labels = tf.keras.layers.Activation("softmax", name="opt")(xs)

model = tf.keras.Model(inputs=images, outputs=labels)
model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.01),
              loss=SparseCategoricalFocalLoss(gamma=1.0),
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
model.fit(dataset, epochs=1)
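
A quick way to see the difference between the two label pipelines is to inspect each dataset's element_spec before training. The following is only a minimal sketch: make_label is a hypothetical stand-in for map_label above, and it assumes that tf.py_function outputs carry no static shape, which is the behavior described in this thread.

```python
import numpy as np
import tensorflow as tf


def make_label(index):
    # Hypothetical stand-in for map_label: labels are built with NumPy,
    # so TensorFlow cannot infer the output shape statically.
    def wrapper(_):
        return np.zeros((640, 480), dtype=np.float32)

    return tf.py_function(func=wrapper, inp=[index], Tout=tf.float32)


# Labels straight out of tf.py_function: static shape (and rank) unknown.
raw = tf.data.Dataset.range(4).map(make_label)

# Same labels after the reshape workaround: static shape is known.
fixed = raw.map(lambda label: tf.reshape(label, [640, 480]))

raw_shape = raw.element_spec.shape
fixed_shape = fixed.element_spec.shape

print(raw_shape.rank)         # None
print(fixed_shape.as_list())  # [640, 480]
```

If the rank printed for the label dataset is None, the loss will most likely hit the same rank check, so the shape needs to be pinned somewhere in the pipeline.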

deploy-soon commented 3 years ago

I finally found a hint for solving this error. Since the output shape of tf.py_function is not fixed, it should be set explicitly. In the sparse categorical case, the shape of y_true can be set from y_pred as shown below, before the rank of the true tensor is checked.

y_true.set_shape(y_pred.get_shape()[:3])

@artemmavrin
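
The set_shape idea above can also be applied without editing the library, by wrapping the loss. This is a rough sketch rather than anything the library provides: with_static_label_shape is a hypothetical helper name, and it assumes y_pred has shape (batch, height, width, classes) while y_true holds sparse labels of shape (batch, height, width).

```python
import tensorflow as tf


def with_static_label_shape(loss_fn):
    """Hypothetical helper: copy y_pred's leading dims onto y_true.

    Assumes sparse labels, i.e. y_true has the shape of y_pred
    without the trailing class axis.
    """

    def wrapped(y_true, y_pred):
        # Attach the static shape before the wrapped loss inspects the rank.
        # set_shape only refines static shape info; it raises if the known
        # shape of y_true is incompatible with y_pred's leading dims.
        y_true.set_shape(y_pred.shape[:-1])
        return loss_fn(y_true, y_pred)

    return wrapped
```

Under these assumptions it could be passed to compile as loss=with_static_label_shape(SparseCategoricalFocalLoss(gamma=1.0)), since Keras accepts any callable taking (y_true, y_pred). Pinning the shape in the tf.data pipeline, as in the reshape workaround, keeps the model code untouched and is probably the simpler option when the label size is fixed.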