athena-team / athena

An open-source implementation of a sequence-to-sequence based speech processing engine
https://athena-team.readthedocs.io
Apache License 2.0

Loss is so big when I use label_smoothing=0.1. #322

Closed · TeaPoly closed this issue 4 years ago

TeaPoly commented 4 years ago

When I try the aishell ASR task, the loss is larger than 2000 when I use label smoothing of 0.1. Is this normal?

Some-random commented 4 years ago

It's a bug we haven't really had enough time to fix. My guess is that the label smoothing implementation in tf.keras.losses.CategoricalCrossentropy has some problems, and that applying label smoothing ourselves right before the loss calculation might just work. We will fix this ASAP.
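
For illustration, here is a minimal sketch of the two routes being discussed: the built-in label_smoothing argument of tf.keras.losses.CategoricalCrossentropy versus smoothing the one-hot targets manually before a plain cross entropy. The shapes and values below are made up; the manual route simply makes the smoothing explicit instead of relying on the label_smoothing argument.

import tensorflow as tf

num_classes = 5
logits = tf.random.normal([2, 7, num_classes])                      # [batch, time, vocab]
labels = tf.random.uniform([2, 7], 0, num_classes, dtype=tf.int32)
one_hot = tf.one_hot(labels, depth=num_classes)

# Route 1: let Keras smooth the targets internally.
builtin = tf.keras.losses.CategoricalCrossentropy(
    from_logits=True, label_smoothing=0.1, reduction="none")
loss_builtin = builtin(one_hot, logits)

# Route 2: smooth the one-hot targets ourselves, then use plain cross entropy.
smoothed = (1.0 - 0.1) * one_hot + 0.1 / num_classes
plain = tf.keras.losses.CategoricalCrossentropy(from_logits=True, reduction="none")
loss_manual = plain(smoothed, logits)

print(loss_builtin.shape, loss_manual.shape)  # both (2, 7): one loss value per token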

TeaPoly commented 4 years ago

I tried to add label smoothing and it seems to work.

import tensorflow as tf

# insert_eos_in_labels is athena's existing helper (exact import path omitted here).


class Seq2SeqSparseCategoricalCrossentropy(tf.keras.losses.CategoricalCrossentropy):
    """ Seq2SeqSparseCategoricalCrossentropy LOSS
    CategoricalCrossentropy calculated at each character for each sequence in a batch
    """

    def __init__(self, num_classes, eos=-1, by_token=False, by_sequence=True,
                 from_logits=True, label_smoothing=0.0):
        super().__init__(from_logits=from_logits, reduction="none")
        self.by_token = by_token
        self.by_sequence = by_sequence
        self.num_classes = num_classes
        self.eos = num_classes + eos if eos < 0 else eos
        self.label_smoothing_scale = label_smoothing

    def __call__(self, logits, samples, logit_length=None):
        def apply_label_smoothing(inputs, K, epsilon=0.1):
            """Applies label smoothing. See https://arxiv.org/abs/1512.00567.

            Args:
              inputs: A 3-d tensor with shape [N, T, V], where V is the vocabulary size.
              K: The vocabulary size V.
              epsilon: Smoothing rate.

            For example, with V=3 and epsilon=0.1, the one-hot row [0, 0, 1]
            becomes [0.0333, 0.0333, 0.9333].
            """
            return ((1 - epsilon) * inputs) + (epsilon / K)

        labels = insert_eos_in_labels(samples["output"], self.eos, samples["output_length"])
        mask = tf.math.logical_not(tf.math.equal(labels, 0))
        # previously: labels = tf.one_hot(indices=labels, depth=self.num_classes)
        labels = apply_label_smoothing(
            tf.one_hot(indices=labels, depth=self.num_classes),
            self.num_classes,
            self.label_smoothing_scale)
        seq_len = tf.shape(labels)[1]
        logits = logits[:, :seq_len, :]
        loss = self.call(labels, logits)
        # previously: loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels)
        mask = tf.cast(mask, dtype=loss.dtype)
        loss *= mask
        if self.by_token:
            return tf.divide(tf.reduce_sum(loss), tf.reduce_sum(mask))
        if self.by_sequence:
            loss = tf.reduce_sum(loss, axis=-1)
        return tf.reduce_mean(loss)
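
For reference, a minimal sketch of how this loss might be exercised; the samples dict keys follow the code above, the shapes and vocabulary size are made-up values, and athena's insert_eos_in_labels helper must be importable for __call__ to run.

import tensorflow as tf

num_classes = 4233  # assumed vocabulary size, e.g. a character inventory for aishell
loss_fn = Seq2SeqSparseCategoricalCrossentropy(
    num_classes=num_classes, eos=-1, label_smoothing=0.1)

batch, max_len = 8, 30
samples = {
    "output": tf.random.uniform([batch, max_len], 1, num_classes - 1, dtype=tf.int32),
    "output_length": tf.fill([batch], max_len),
}
# Decoder logits are made one step longer than the labels so that the
# slice logits[:, :seq_len, :] still covers the inserted eos position.
logits = tf.random.normal([batch, max_len + 1, num_classes])

loss = loss_fn(logits, samples)
print(float(loss))  # per-sequence summed cross entropy, averaged over the batch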
Some-random commented 4 years ago

Great! Could you share your results with us using label smoothing this way? I think it will give similar or slightly better accuracy but a larger loss during training and decoding.

TeaPoly commented 4 years ago

I'll report back with the latest results as soon as I have them.

TeaPoly commented 4 years ago

Here are the development-set results (character error rate) for CTC greedy search and attention ground-truth greedy search, all after 13 epochs:

label_smoothing  attention_weight  CTC    ATTENTION  EPOCHS
0.1              0.7               8.20%  5.78%      13
0.1              0.5               7.91%  6.05%      13
0                0.5               8.07%  6.49%      13
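
For readers of the table: attention_weight presumably refers to the weight on the attention branch in the joint CTC/attention training objective. A generic sketch of that kind of combination follows; the exact formula and names used inside athena may differ.

def joint_ctc_attention_loss(ctc_loss, attention_loss, attention_weight):
    # Assumed convex combination of the two branch losses; athena's actual
    # multi-task implementation may weight them differently.
    return attention_weight * attention_loss + (1.0 - attention_weight) * ctc_loss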
Some-random commented 4 years ago

Great! Can you open a PR for us? I think we should move the apply_label_smoothing function outside of __call__, and we probably don't need so many comments on it (maybe a paper reference and a short description of what it does is enough?).
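
If it helps, one possible shape for that refactor, just as a sketch of the suggested structure rather than the final PR: a module-level helper with a short docstring, which __call__ would then invoke as apply_label_smoothing(tf.one_hot(labels, depth=self.num_classes), self.num_classes, self.label_smoothing_scale).

def apply_label_smoothing(inputs, num_classes, epsilon=0.1):
    """Label smoothing (Szegedy et al., 2016, https://arxiv.org/abs/1512.00567):
    mix the one-hot targets with a uniform distribution over the vocabulary.
    """
    return (1.0 - epsilon) * inputs + epsilon / num_classes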