JikC closed this issue 2 years ago.
Which function can I use to replace it?
Hi,
I am sorry that MegEngine==1.6.0 does not support frequency_weighted_cross_entropy, so I wrote this function based on cross_entropy in loss.py of the original MegEngine source code. You can add the frequency_weighted_cross_entropy function below to loss.py, which should be located at a path like /data/applications/miniconda3/envs/meg_brain/lib/python3.6/site-packages/megengine/functional/loss.py. If you use the original Python environment, the location should be /usr/local/lib/python3.5/dist-packages/megengine/functional/loss.py.
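The exact path depends on the environment. Instead of guessing, a small helper built on the standard library's `importlib.util.find_spec` can report where the installed file lives. This is only a sketch: `find_module_file` is a hypothetical helper name, and looking up `megengine.functional.loss` imports MegEngine itself.

```python
import importlib.util
from typing import Optional


def find_module_file(name: str) -> Optional[str]:
    """Return the file path of an importable module, or None if it is missing.

    For a dotted name such as "megengine.functional.loss", find_spec
    imports the parent package first.
    """
    try:
        spec = importlib.util.find_spec(name)
    except ModuleNotFoundError:
        # Raised when a parent package of a dotted name is not installed.
        return None
    if spec is None:
        return None
    return spec.origin


if __name__ == "__main__":
    # Prints the loss.py path for the active environment, or None.
    print(find_module_file("megengine.functional.loss"))
```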
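```python
# Paste into megengine/functional/loss.py. The following names are assumed to be
# already imported or defined in that module: np, Tensor, log, logsumexp,
# indexing_one_hot and the _reduce_output decorator.
from .math import sum  # note: shadows Python's built-in sum inside loss.py


@_reduce_output
def frequency_weighted_cross_entropy(
    pred: Tensor,
    label: Tensor,
    weight: Tensor = None,
    axis: int = 1,
    with_logits: bool = True,
    label_smooth: float = 0,
    reduction: str = "mean",
) -> Tensor:
    n0 = pred.ndim
    n1 = label.ndim
    assert n0 == n1 + 1, (
        "target ndim must be one less than input ndim; input_ndim={} "
        "target_ndim={}".format(n0, n1)
    )

    if weight is not None:
        # Normalize the per-class weights, then look up the weight of each
        # element's target class so class_weight has the same shape as label.
        weight = weight / sum(weight)
        class_weight = weight[label.flatten().astype(np.int32)].reshape(label.shape)

    ls = label_smooth

    if with_logits:
        logZ = logsumexp(pred, axis)
        primary_term = indexing_one_hot(pred, label, axis)
    else:
        logZ = 0
        primary_term = log(indexing_one_hot(pred, label, axis))

    if ls is None or type(ls) in (int, float) and ls == 0:
        # No label smoothing.
        if weight is None:
            return logZ - primary_term
        # Weighted average of the per-element loss along axis 1; this assumes
        # label has at least two dimensions, e.g. (N, H, W) for segmentation.
        return sum((logZ - primary_term) * class_weight, axis=1, keepdims=True) / sum(
            class_weight, axis=1, keepdims=True
        )

    # Label smoothing: mix the target-class term with the mean over all classes.
    if not with_logits:
        pred = log(pred)
    if weight is None:
        return logZ - ls * pred.mean(axis) - (1 - ls) * primary_term
    return sum(
        (logZ - ls * pred.mean(axis) - (1 - ls) * primary_term) * class_weight,
        axis=1,
        keepdims=True,
    ) / sum(class_weight, axis=1, keepdims=True)
```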
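To make the arithmetic concrete, here is a plain-Python sketch (no MegEngine needed) of what the function computes for a batch of logit rows with integer labels. It is a simplified illustration, not MegEngine's API: `weighted_cross_entropy` and `log_sum_exp` are hypothetical names, and it averages over the whole batch instead of along axis 1 followed by the `reduction` step. Each element's loss is `logZ - ls * mean(logits) - (1 - ls) * logits[label]`, and with weights the result is `sum(w[y] * loss) / sum(w[y])`, so rescaling the weights (as the `weight / sum(weight)` line does) leaves the result unchanged.

```python
import math
from typing import List, Optional, Sequence


def log_sum_exp(logits: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(x))) for one row of logits."""
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits))


def weighted_cross_entropy(
    logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    weight: Optional[Sequence[float]] = None,
    label_smooth: float = 0.0,
) -> float:
    """Hypothetical pure-Python sketch: (class-weighted) mean cross entropy.

    Each row of `logits` holds C unnormalized scores for one sample and
    `labels` holds the target class index of each sample.
    """
    losses: List[float] = []
    for row, y in zip(logits, labels):
        log_z = log_sum_exp(row)
        primary = row[y]
        mean_logit = sum(row) / len(row)
        # Same formula as the label-smoothing branch; it reduces to
        # log_z - primary when label_smooth == 0.
        losses.append(log_z - label_smooth * mean_logit - (1 - label_smooth) * primary)

    if weight is None:
        return sum(losses) / len(losses)

    # Weight each sample by the weight of its target class.
    class_weight = [weight[y] for y in labels]
    return sum(l * w for l, w in zip(losses, class_weight)) / sum(class_weight)


if __name__ == "__main__":
    logits = [[2.0, 0.0], [0.0, 0.0]]
    labels = [0, 1]
    print(weighted_cross_entropy(logits, labels))
    print(weighted_cross_entropy(logits, labels, weight=[1.0, 3.0]))
```

With `weight=[1.0, 3.0]`, the second sample (class 1) counts three times as much as the first, which is how rare classes can be up-weighted.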
I will update this note in README later.