MegEngine / OMNet

OMNet: Learning Overlapping Mask for Partial-to-Partial Point Cloud Registration, ICCV 2021, MegEngine implementation.
MIT License

AttributeError: module 'megengine.functional.nn' has no attribute 'frequency_weighted_cross_entropy' #6

Closed JikC closed 2 years ago

JikC commented 2 years ago

Which function can I replace it with?

hxwork commented 2 years ago

> Which function can I replace it with?

Hi,

I am sorry, MegEngine==1.6.0 does not support `frequency_weighted_cross_entropy`, so I wrote this function based on `cross_entropy` in `loss.py` of the original MegEngine source code. You can add the following function to `loss.py`; its location should be something like `/data/applications/miniconda3/envs/meg_brain/lib/python3.6/site-packages/megengine/functional/loss.py`. If you use the original python environment, the location should be `/usr/local/lib/python3.5/dist-packages/megengine/functional/loss.py`. The other names it uses (`Tensor`, `logsumexp`, `log`, `indexing_one_hot`, `_reduce_output`) should already be available inside `loss.py`, since `cross_entropy` itself relies on them.

```python
import numpy as np  # add this import if loss.py does not already have it

from .math import sum

@_reduce_output
def frequency_weighted_cross_entropy(
    pred: Tensor,
    label: Tensor,
    weight: Tensor = None,
    axis: int = 1,
    with_logits: bool = True,
    label_smooth: float = 0,
    reduction: str = "mean",
) -> Tensor:
    n0 = pred.ndim
    n1 = label.ndim
    assert n0 == n1 + 1, (
        "target ndim must be one less than input ndim; "
        "input_ndim={} target_ndim={}".format(n0, n1)
    )

    if weight is not None:
        # Normalize the class weights, then gather a per-element weight
        # by indexing with the (flattened) integer labels.
        weight = weight / sum(weight)
        class_weight = weight[label.flatten().astype(np.int32)].reshape(label.shape)

    ls = label_smooth

    if with_logits:
        # log-partition term and the logit of the target class
        logZ = logsumexp(pred, axis)
        primary_term = indexing_one_hot(pred, label, axis)
    else:
        logZ = 0
        primary_term = log(indexing_one_hot(pred, label, axis))
    if ls is None or type(ls) in (int, float) and ls == 0:
        if weight is None:
            return logZ - primary_term
        # Weighted average of the per-element losses.
        return sum(
            (logZ - primary_term) * class_weight, axis=1, keepdims=True
        ) / sum(class_weight, axis=1, keepdims=True)
    if not with_logits:
        pred = log(pred)
    if weight is None:
        return logZ - ls * pred.mean(axis) - (1 - ls) * primary_term
    return sum(
        (logZ - ls * pred.mean(axis) - (1 - ls) * primary_term) * class_weight,
        axis=1,
        keepdims=True,
    ) / sum(class_weight, axis=1, keepdims=True)
```
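For intuition, here is a small NumPy sketch (not MegEngine code; the function name `fw_cross_entropy` and the example values are my own) of what the weighted branch computes when `with_logits=True` and no label smoothing is used:

```python
import numpy as np

def fw_cross_entropy(pred, label, weight):
    # pred: (N, C) logits, label: (N,) int class ids, weight: (C,) class weights
    weight = weight / weight.sum()            # normalize, as the patch does
    # log-partition per sample (unstabilized log-sum-exp, fine for a sketch)
    logZ = np.log(np.exp(pred).sum(axis=1))
    # logit of the target class for each sample
    primary = pred[np.arange(len(label)), label]
    cw = weight[label]                        # per-sample class weight
    # weighted average of the per-sample cross-entropy losses
    return ((logZ - primary) * cw).sum() / cw.sum()

pred = np.array([[2.0, 0.0], [0.0, 2.0]])
label = np.array([0, 1])
loss = fw_cross_entropy(pred, label, np.array([1.0, 1.0]))
```

With uniform weights this reduces to the plain mean cross-entropy, which is a quick sanity check to run after patching `loss.py`.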

I will add this note to the README later.