deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

LogSumExp seems to be missing in SoftmaxCrossEntropyLoss (sparseLabel=false) #520

Closed enpasos closed 3 years ago

enpasos commented 3 years ago

Description

On the latest master branch (commit 4b516196), in SoftmaxCrossEntropyLoss with sparseLabel=false, line 85

            loss = pred.mul(lab).neg().sum(new int[] {classAxis}, true);

it looks like the LogSumExp term is missing (it was shown in red in the attached formula image).

Proposed correction

            int[] axes = new int[] {classAxis};
            NDArray max = pred.max(axes, true);
            NDArray logSumExp = max.add((pred.sub(max)).exp().sum(axes, true).log());
            loss = logSumExp.sub(pred.mul(lab).sum(axes, true));
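
For reference, a sketch of the softmax cross-entropy for non-sparse labels (assuming the label probabilities y sum to 1 along the class axis), which shows where the logsumexp term enters:

    \ell(x, y) = -\sum_c y_c \log\frac{e^{x_c}}{\sum_k e^{x_k}}
               = \log\sum_k e^{x_k} - \sum_c y_c x_c

With raw logits as pred, the pred.mul(lab).neg().sum(...) expression alone only covers the second term.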
enpasos commented 3 years ago

Or maybe there is a cleverer way to add it, since the derivative of the term is already known from forward propagation.
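
The derivative of the logsumexp term is just the softmax, which the forward pass already produces:

    \frac{\partial}{\partial x_c}\,\log\sum_k e^{x_k} = \frac{e^{x_c}}{\sum_k e^{x_k}} = \operatorname{softmax}(x)_c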

enpasos commented 3 years ago

Maybe more stable:

            int[] axes = new int[]{classAxis};
            NDArray max = pred.max(axes, true);
            NDArray predSubMax = pred.sub(max);
            loss = predSubMax.exp().sum(axes, true).log().sub(predSubMax.mul(lab).sum(axes, true));
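
Subtracting the per-row maximum m keeps the exponentials in range, and (assuming the labels sum to 1 along the class axis) m cancels out of the result:

    \log\sum_k e^{x_k - m} - \sum_c y_c (x_c - m)
      = \Big(\log\sum_k e^{x_k} - m\Big) - \Big(\sum_c y_c x_c - m\Big)
      = \log\sum_k e^{x_k} - \sum_c y_c x_c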
enpasos commented 3 years ago

The "maybe more stable" variant above works well for me.

roywei commented 3 years ago

Hi @enpasos, I think the logsumexp is handled in the C++ backend by the logSoftmax operator before the loss is calculated.
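
That is, the logsumexp is already folded into log-softmax:

    \operatorname{log\_softmax}(x)_c = x_c - \log\sum_k e^{x_k}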

We are trying to be consistent with the PyTorch and MXNet CrossEntropyLoss APIs.

You can follow the definition and equation here in PyTorch's documentation: https://pytorch.org/docs/master/generated/torch.nn.CrossEntropyLoss.html?highlight=crossentropy#torch.nn.CrossEntropyLoss

You can also refer to our book: https://d2l.djl.ai/chapter_linear-networks/softmax-regression-djl.html#the-softmax

You can see the results are consistent:

PyTorch and MXNet Python API

import torch
from torch.nn import CrossEntropyLoss

import mxnet as mx
from mxnet.gluon.loss import SoftmaxCELoss

loss = CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)

loss_mx = SoftmaxCELoss()
input_mx = mx.nd.array(input.detach().numpy())
target_mx = mx.nd.array(target.detach().numpy())
output_mx = loss_mx(input_mx, target_mx).mean()

output
tensor(1.4238, grad_fn=<NllLossBackward>)

output_mx

[1.4238214]
<NDArray 1 @cpu(0)>

DJL:

    public void softmaxCrossEntropyTest() {
        try (NDManager manager = NDManager.newBaseManager()) {

            NDArray pred = manager.create(new float[] {-0.6768f, -1.1598f,  1.4133f,  1.3114f,  0.3797f,
                    0.6835f, -1.2555f, -0.7368f,  1.0889f, -1.0609f,
                    1.5516f, -2.0066f,  0.9542f,  0.7195f, -1.2068f}, new Shape(3,5));
            NDArray label = manager.create(new float[] {3, 2, 0}, new Shape(3));
            NDArray output = Loss.softmaxCrossEntropyLoss().evaluate(new NDList(label), new NDList(pred));
            Assertions.assertAlmostEquals(
                    output,
                    manager.create(1.4238214f));
            System.out.println(output);
        }
    }

ND: () cpu() float32
1.4238
roywei commented 3 years ago

However, we do not have the logsumexp operator directly exposed for users. PyTorch has this operator; MXNet does not. Feel free to add this operator and open a PR.

For the MXNet engine, you can use your implementation on the Java side; on PyTorch you can call the operator directly. Here is a sample PR adding a new operator: https://github.com/awslabs/djl/pull/579/files
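
As a starting point, a minimal sketch of such an operator composed purely from existing NDArray ops already used in this thread (the class and method names here are hypothetical, not part of the DJL API):

    import ai.djl.ndarray.NDArray;

    /** Hypothetical helper: logsumexp composed from existing NDArray ops. */
    public final class NDArrayUtils {
        private NDArrayUtils() {}

        /** Computes log(sum(exp(x))) over the given axes, keeping dimensions. */
        public static NDArray logSumExp(NDArray x, int[] axes) {
            // Subtract the per-slice max before exponentiating to avoid overflow,
            // then add it back outside the log; the result is mathematically unchanged.
            NDArray max = x.max(axes, true);
            return x.sub(max).exp().sum(axes, true).log().add(max);
        }
    }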

I'm closing this issue as it's not a bug. Please open new issues or ask in our Slack channel if you face any problems adding operators. Thanks!

enpasos commented 3 years ago

Hi @roywei, thank you for looking into the issue. I suggest reopening the issue to correct the "sparseLabel=false" part of the loss function.

Your unit test exercises the "evaluate" method of SoftmaxCrossEntropyLoss in a different context

fromLogit = false
sparseLabel = true

and does not test the problematic code part I was dealing with.

My context is

fromLogit = true
sparseLabel = false  // this is especially important, and perhaps a situation not yet covered by the existing tests

As you suggested, let's look at an example from the Python APIs as a reference for a unit test:

import tensorflow as tf

logitsInput = [[-8.0, 10.0,  3.0],  [1.0, 2.0, 3.0]]
target = [[0.3, 0.3, 0.4], [0.1, 0.1, 0.8]]

output = tf.nn.softmax_cross_entropy_with_logits(labels=target, logits=logitsInput)
print(output)

gives

tf.Tensor([8.2009115  0.70760596], shape=(2,), dtype=float32)

The unit test with the current DJL code fails:

@Test
    public void softmaxCrossEntropyNonSparseLabelsFailTest() {
        boolean fromLogit = true;
        boolean sparseLabel = false;

        try (NDManager manager = NDManager.newBaseManager()) {

            NDArray pred = manager.create(new float[] {-8.0f, 10.0f, 3.0f, 1.0f, 2.0f, 3.0f}, new Shape(2,3));
            NDArray label = manager.create(new float[] {0.3f, 0.3f, 0.4f, 0.1f, 0.1f, 0.8f}, new Shape(2,3));
            NDArray expected = manager.create(new float[] {8.2009115f,  0.70760596f}).mean();

            NDArray output = Loss.softmaxCrossEntropyLoss("SoftmaxCrossEntropyLoss", 1, 1, sparseLabel, fromLogit).evaluate(new NDList(label), new NDList(pred));

            Assertions.assertAlmostEquals(output, expected);

        }
    }

The unit test for my corrected code passes:

 @Test
    public void softmaxCrossEntropyNonSparseLabelsPassTest() {
        boolean fromLogit = true;
        boolean sparseLabel = false;

        try (NDManager manager = NDManager.newBaseManager()) {

            NDArray pred = manager.create(new float[] {-8.0f, 10.0f, 3.0f, 1.0f, 2.0f, 3.0f}, new Shape(2,3));
            NDArray label = manager.create(new float[] {0.3f, 0.3f, 0.4f, 0.1f, 0.1f, 0.8f}, new Shape(2,3));
            NDArray expected = manager.create(new float[] {8.2009115f,  0.70760596f}).mean();

            NDArray output = new MySoftmaxCrossEntropyLoss("SoftmaxCrossEntropyLoss", 1, 1, sparseLabel, fromLogit).evaluate(new NDList(label), new NDList(pred));

            Assertions.assertAlmostEquals(output, expected);

        }
    }

I have adjusted the loss function in the following way:

public class MySoftmaxCrossEntropyLoss extends Loss {

    private float weight;
    private int classAxis;
    private boolean sparseLabel;
    private boolean fromLogit;

    /**
     * Creates a new instance of {@code SoftmaxCrossEntropyLoss} with default parameters.
     */
    public MySoftmaxCrossEntropyLoss() {
        this("SoftmaxCrossEntropyLoss");
    }

    /**
     * Creates a new instance of {@code SoftmaxCrossEntropyLoss} with default parameters.
     *
     * @param name the name of the loss
     */
    public MySoftmaxCrossEntropyLoss(String name) {
        this(name, 1, -1, true, false);
    }

    /**
     * Creates a new instance of {@code SoftmaxCrossEntropyLoss} with the given parameters.
     *
     * @param name        the name of the loss
     * @param weight      the weight to apply on the loss value, default 1
     * @param classAxis   the axis that represents the class probabilities, default -1
     * @param sparseLabel whether labels are integer array or probabilities, default true
     * @param fromLogit   whether predictions are log probabilities or un-normalized numbers, default
     *                    false
     */
    public MySoftmaxCrossEntropyLoss(
            String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit) {
        super(name);
        this.weight = weight;
        this.classAxis = classAxis;
        this.sparseLabel = sparseLabel;
        this.fromLogit = fromLogit;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public NDArray evaluate(NDList label, NDList prediction) {
        NDArray pred = prediction.singletonOrThrow();
        if (!fromLogit) {
            pred = pred.logSoftmax(classAxis);
        }
        NDArray loss;
        NDArray lab = label.singletonOrThrow();
        if (sparseLabel) {
            NDIndex pickIndex =
                    new NDIndex()
                            .addAllDim(Math.floorMod(classAxis, pred.getShape().dimension()))
                            .addPickDim(lab);
            loss = pred.get(pickIndex).neg();
        } else {
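            // Non-sparse labels: compute logsumexp(pred) - sum(lab * pred) along the class axis,
            // using the max-subtraction trick for numerical stability (the max cancels when the
            // label probabilities sum to 1).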
            lab = lab.reshape(pred.getShape());
            int[] axes = new int[]{classAxis};
            NDArray max = pred.max(axes, true);
            NDArray predSubMax = pred.sub(max);
            loss = predSubMax.exp().sum(axes, true).log().sub(predSubMax.mul(lab).sum(axes, true));
        }
        if (weight != 1) {
            loss = loss.mul(weight);
        }
        return loss.mean();
    }
} 
roywei commented 3 years ago

Hi @enpasos,

Taking another look, I think the fromLogits definitions of MXNet and TensorFlow differ.

So in the case of MXNet, you need to manually apply log_softmax if you set fromLogits=True.

Reference: https://github.com/apache/incubator-mxnet/issues/12185 https://discuss.pytorch.org/t/pytorch-equivalence-to-sparse-softmax-cross-entropy-with-logits-in-tensorflow/18727/3

Non-sparse (one-hot) labels can be converted to sparse labels by argmax, so they should produce the same result.
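
For one-hot labels like the target above, that conversion is just an argmax along the class axis, for example (reusing the lab and classAxis names from the loss code above):

    // one-hot labels, e.g. [[1,0,0],[0,1,0]], become sparse class indices [0, 1]
    NDArray sparseLab = lab.argMax(classAxis);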

Here is the Python code that generates the same results in all 3 frameworks. For consistency, we will change the behavior to match TF and update our documentation.

import torch
import mxnet as mx
from mxnet.gluon.loss import SoftmaxCELoss
import tensorflow as tf

# tensorflow
logitsInput = [[4.0, 2.0, 1.0], [0.0, 5.0, 1.0]]
target = [[1, 0, 0], [0, 1, 0]]

# logits=True, sparse=False
tf_loss = tf.nn.softmax_cross_entropy_with_logits(labels=target, logits=logitsInput)
print(tf_loss)

# logits=True, sparse=True
tf_loss_sparse = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=[0, 1], logits=logitsInput)
print(tf_loss_sparse)

# pytorch
import torch.nn.functional as F
# logits=True, sparse=True
pt_loss = F.nll_loss(F.log_softmax(torch.tensor(logitsInput)), torch.tensor([0, 1]), reduction='none')
print(pt_loss)

# mxnet
# logits=True, sparse=False
loss_mx = SoftmaxCELoss(from_logits=True, sparse_label=False)
input_mx = mx.nd.array(logitsInput)
target_mx = mx.nd.array(target)
mx_loss = loss_mx(mx.nd.log_softmax(input_mx), target_mx)
print(mx_loss)

# logits=True, sparse=True
loss_mx = SoftmaxCELoss(from_logits=True, sparse_label=True)
input_mx = mx.nd.array(logitsInput)
target_mx = mx.nd.array([0, 1])
mx_loss_sparse = loss_mx(mx.nd.log_softmax(input_mx), target_mx)
print(mx_loss_sparse)
tf.Tensor([0.16984604 0.02474492], shape=(2,), dtype=float32)
tf.Tensor([0.16984604 0.02474492], shape=(2,), dtype=float32)
tensor([0.1698, 0.0247])

[0.16984604 0.02474492]
<NDArray 2 @cpu(0)>

[0.16984604 0.02474492]
<NDArray 2 @cpu(0)>
enpasos commented 3 years ago

Hi @roywei, your comment and the pull request "fix softmax flag" make the behaviour clear. Now I understand the approach :-) Thanks a lot.

To be sure that I get the correct gradients from the loss function, I compared hand-calculated gradients against autograd, once for TensorFlow and once for DJL on MXNet (see below). For TensorFlow the gradients are the same; on DJL they seem to differ by a factor of two. Please check whether you can see the cause. Hopefully it is just something wrong in my test ... then sorry for the noise.
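
For reference, the hand-calculated gradient below uses the standard per-example identity (before any reduction):

    \frac{\partial \ell}{\partial x_c} = \operatorname{softmax}(x)_c - y_c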

TensorFlow

import tensorflow as tf

logits = [[-8.0, 10.0,  3.0],  [1.0, 2.0, 3.0]]
target = [[0.3, 0.3, 0.4], [0.1, 0.1, 0.8]]

probabilities = tf.nn.softmax(
    logits, axis=None, name=None
)
print("probabilities: ")
print(probabilities)

handcalculatedGradient = probabilities - target
print("handcalculated gradient (at logits): ")
print(handcalculatedGradient)

o = tf.Variable(logits)
with tf.GradientTape() as tape:
    loss = tf.nn.softmax_cross_entropy_with_logits(labels=target, logits=o)
print("loss: ")
print(loss)

dl_do = tape.gradient(loss, o)
print("autodiff gradient (at logits): ")
print(dl_do)

print("are handcalculated gradient and autodiff gradient equal?")
print(tf.math.equal(handcalculatedGradient , dl_do, name=None))

gives

probabilities: 
tf.Tensor(
[[1.5216104e-08 9.9908900e-01 9.1105123e-04]
 [9.0030573e-02 2.4472848e-01 6.6524094e-01]], shape=(2, 3), dtype=float32)
handcalculated gradient (at logits): 
tf.Tensor(
[[-0.29999998  0.699089   -0.39908895]
 [-0.00996943  0.14472848 -0.13475907]], shape=(2, 3), dtype=float32)
loss: 
tf.Tensor([8.2009115  0.70760596], shape=(2,), dtype=float32)
autodiff gradient (at logits): 
tf.Tensor(
[[-0.29999998  0.699089   -0.39908895]
 [-0.00996943  0.14472848 -0.13475907]], shape=(2, 3), dtype=float32)
are handcalculated gradient and autodiff gradient equal?
tf.Tensor(
[[ True  True  True]
 [ True  True  True]], shape=(2, 3), dtype=bool)

DJL on MXNet

    @Test
    public void softmaxCrossEntropyNonSparseLabelsTest() {
        boolean fromLogit = true;
        boolean sparseLabel = false;

        try (NDManager manager = NDManager.newBaseManager()) {

            NDArray logits = manager.create(new float[] {-8.0f, 10.0f, 3.0f, 1.0f, 2.0f, 3.0f}, new Shape(2,3));
            NDArray label = manager.create(new float[] {0.3f, 0.3f, 0.4f, 0.1f, 0.1f, 0.8f}, new Shape(2,3));
            NDArray expectedLoss = manager.create(new float[] {8.2009115f,  0.70760596f}).mean();

            logits.attachGradient();
            try (GradientCollector gc = manager.getEngine().newGradientCollector()) {
                NDArray probabilities = logits.softmax(1);
                System.out.println("probabilities: ");
                System.out.println(probabilities);

                NDArray handcalculatedGradient = probabilities.sub(label);
                System.out.println("handcalculatedGradient: ");
                System.out.println(handcalculatedGradient);

                NDArray loss = new SoftmaxCrossEntropyLoss("SoftmaxCrossEntropyLoss", 1, 1, sparseLabel, !fromLogit).evaluate(new NDList(label), new NDList(logits));
                Assertions.assertAlmostEquals(loss, expectedLoss);
                gc.backward(loss);
                NDArray autograd = logits.getGradient();
                System.out.println("autograd: ");
                System.out.println(autograd);
                // The following assertion fails
                Assertions.assertAlmostEquals(autograd, handcalculatedGradient);
                // the assertion passes but the factor 2 is wrong
                // Assertions.assertAlmostEquals(handcalculatedGradient, autograd.mul(2));
            }
        }
    }

gives

probabilities: 
ND: (2, 3) gpu(0) float32
[[ 1.52161022e-08,  9.99088883e-01,  9.11051058e-04],
 [ 9.00305659e-02,  2.44728461e-01,  6.65240884e-01],
]

handcalculatedGradient: 
ND: (2, 3) gpu(0) float32
[[-0.3   ,  0.6991, -0.3991],
 [-0.01  ,  0.1447, -0.1348],
]

autograd: 
ND: (2, 3) gpu(0) float32
[[-0.15  ,  0.3495, -0.1995],
 [-0.005 ,  0.0724, -0.0674],
]

java.lang.AssertionError: 
Expected: ND: (2, 3) gpu(0) float32
[[-0.3   ,  0.6991, -0.3991],
 [-0.01  ,  0.1447, -0.1348],
]

Actual: ND: (2, 3) gpu(0) float32
[[-0.15  ,  0.3495, -0.1995],
 [-0.005 ,  0.0724, -0.0674],
]
enpasos commented 3 years ago

... forget about my last comment on the gradient. I didn't take into account the "mean", which introduces the factor 2 in the gradient. Thanks again and sorry for the "gradient noise".
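
A quick sketch of where the factor comes from: the returned loss is the mean over the batch, so each per-example gradient is scaled by 1/N, with N = 2 in this test:

    L = \frac{1}{N}\sum_{i=1}^{N} \ell_i
    \quad\Rightarrow\quad
    \frac{\partial L}{\partial x_{i,c}} = \frac{1}{N}\big(\operatorname{softmax}(x_i)_c - y_{i,c}\big)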