apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

F.contrib.SoftmaxOHEMOutput #20067

Closed: BebDong closed this issue 3 years ago

BebDong commented 3 years ago

Description


```python
from gluoncv import loss as gloss


class OHEMCrossEntropyLoss(gloss.SoftmaxCrossEntropyLoss):
    """
    OHEM (online hard example mining) cross-entropy loss.
    Only supports a single GPU. hybrid_forward calls asnumpy() and branches on
    NDArray values in Python, so the block must run imperatively (no hybridize()).
    Adapted from:
        https://github.com/PaddlePaddle/PaddleSeg/blob/release/v2.0/
        paddleseg/models/losses/ohem_cross_entropy_loss.py
    """

    def __init__(self, thresh=0.7, min_kept=10000, num_classes=21, height=None, width=None,
                 crop_size=480, sparse_label=True, batch_axis=0, ignore_label=-1,
                 size_average=True, **kwargs):
        super(OHEMCrossEntropyLoss, self).__init__(sparse_label, batch_axis, ignore_label,
                                                   size_average, **kwargs)
        self._thresh = thresh      # pixels whose true-class prob is below this are "hard"
        self._min_kept = min_kept  # minimum number of hard pixels to aim for
        self._nclass = num_classes
        self._height = height if height is not None else crop_size
        self._width = width if width is not None else crop_size

    def hybrid_forward(self, F, logit, label):
        # Flatten labels to (N*H*W,) and map ignored pixels to class 0 for F.pick.
        label = F.reshape(label, shape=(-1,))
        valid_mask = (label != self._ignore_label)
        num_valid = F.sum(valid_mask)
        label = label * valid_mask

        # Softmax over classes, then NCHW -> (C, N*H*W) to line up with the flat labels.
        prob = F.softmax(logit, axis=1)
        prob = F.reshape(F.transpose(prob, axes=(1, 0, 2, 3)), shape=(self._nclass, -1))

        if self._min_kept < num_valid and num_valid > 0:
            # Push ignored pixels' probabilities above 1 so they are never selected.
            prob = prob + (1 - valid_mask)
            # Probability assigned to the ground-truth class at each pixel.
            prob = F.pick(prob, label, axis=0, keepdims=False)

            threshold = self._thresh
            if self._min_kept > 0:
                # Raise the threshold to the min_kept-th lowest probability if it is larger.
                index = F.argsort(prob)
                threshold_index = index[min(len(index), self._min_kept) - 1]
                threshold_index = int(threshold_index.asnumpy()[0])
                if prob[threshold_index] > self._thresh:
                    threshold = prob[threshold_index]
                # Keep only hard pixels (true-class probability below the threshold).
                kept_mask = (prob < threshold)
                label = label * kept_mask
                valid_mask = valid_mask * kept_mask

        # Mark ignored and easy (dropped) pixels with ignore_label.
        label = label + (1 - valid_mask) * self._ignore_label
        label = F.reshape(label, shape=(-1, self._height, self._width))
        return super(OHEMCrossEntropyLoss, self).hybrid_forward(F, logit, label)
```
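To make the pixel-selection rule in `hybrid_forward` easier to follow, here is a framework-free sketch of the `min_kept`/`thresh` branch using plain Python lists instead of MXNet NDArrays. The function name `ohem_labels` and its parameters are hypothetical; it assumes the ground-truth-class probabilities (what `F.pick` extracts) have already been computed and flattened to one value per pixel.

```python
from typing import List, Sequence


def ohem_labels(true_class_prob: Sequence[float], labels: Sequence[int],
                thresh: float = 0.7, min_kept: int = 10,
                ignore_label: int = -1) -> List[int]:
    """Hypothetical helper: return `labels` with easy pixels set to `ignore_label`.

    true_class_prob[i] is the softmax probability of pixel i's ground-truth class;
    its value is irrelevant for pixels that are already ignored.
    """
    valid = [lab != ignore_label for lab in labels]
    num_valid = sum(valid)
    keep = list(valid)
    if min_kept < num_valid and num_valid > 0:
        # Push ignored pixels above 1 so they can never fall under the threshold.
        prob = [p if v else p + 1.0 for p, v in zip(true_class_prob, valid)]
        threshold = thresh
        if min_kept > 0:
            # Probability of the min_kept-th lowest-confidence pixel.
            kth = sorted(prob)[min(len(prob), min_kept) - 1]
            if kth > thresh:
                threshold = kth
            # Keep only valid pixels the model is not yet confident about.
            keep = [v and p < threshold for p, v in zip(prob, valid)]
    # Dropped pixels get ignore_label, so the downstream loss skips them.
    return [lab if k else ignore_label for lab, k in zip(labels, keep)]


if __name__ == "__main__":
    probs = [0.9, 0.2, 0.95, 0.5, 0.3]
    labels = [1, 2, -1, 0, 1]
    print(ohem_labels(probs, labels, thresh=0.7, min_kept=2))  # [-1, 2, -1, 0, 1]
```

Because the comparison is strict (`p < threshold`), when the min_kept-th probability itself becomes the threshold, that pixel (and any ties) is dropped, so slightly fewer than `min_kept` pixels can survive; the sketch mirrors the class above in this respect. Also note that with `min_kept=0` no pixels are filtered at all, since the masking sits inside the `min_kept > 0` branch.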



szha commented 3 years ago

cc @zhreshold who maintains gluoncv