linkedin / dagli

Framework for defining machine learning models, including feature generation and transformations, as directed acyclic graphs (DAGs).
BSD 2-Clause "Simplified" License

More control over loss functions #12

Open cyberbeat opened 2 years ago

cyberbeat commented 2 years ago

I've been reading about multilabel problems with imbalanced label distributions in the training data:

https://arxiv.org/abs/2109.04712

The paper proposes re-weighting the loss function per label. It seems that DL4J at least supports applying a weights array? How could I do that with Dagli?
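To make the idea concrete, a per-label weights array scales each label's term in the loss. The plain-Java sketch below computes a per-label weighted binary cross-entropy for a single multilabel example. It does not use DL4J or Dagli, the class and method names are hypothetical, and it assumes predicted probabilities (e.g. sigmoid outputs) with 0/1 labels.

```java
class WeightedMultilabelLoss {
    // Per-label weighted binary cross-entropy for one example (illustrative sketch).
    // labels[i] is 1.0 or 0.0, probs[i] is the predicted probability of label i,
    // and weights[i] scales label i's contribution (e.g. larger for rare labels).
    static double weightedBinaryCrossEntropy(double[] labels, double[] probs, double[] weights) {
        if (labels.length != probs.length || labels.length != weights.length) {
            throw new IllegalArgumentException("labels, probs and weights must have the same length");
        }
        double eps = 1e-12; // clamp probabilities so log(0) never occurs
        double total = 0.0;
        for (int i = 0; i < labels.length; i++) {
            double p = Math.min(Math.max(probs[i], eps), 1.0 - eps);
            double perLabel = -(labels[i] * Math.log(p) + (1.0 - labels[i]) * Math.log(1.0 - p));
            total += weights[i] * perLabel;
        }
        return total;
    }

    public static void main(String[] args) {
        double[] labels = {1.0, 0.0, 0.0};
        double[] probs = {0.5, 0.5, 0.5};
        // Uniform weights: every label contributes ln(2).
        System.out.println(weightedBinaryCrossEntropy(labels, probs, new double[] {1.0, 1.0, 1.0}));
        // Up-weighting the first label makes its error count five times as much.
        System.out.println(weightedBinaryCrossEntropy(labels, probs, new double[] {5.0, 1.0, 1.0}));
    }
}
```

With all probabilities at 0.5, each label contributes ln(2), so the totals are 3·ln(2) and 7·ln(2) respectively.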

jeffpasternack commented 2 years ago

Do you have an example of using a per-label weights array in DL4J?

Most models in Dagli support per-example weights, but not neural networks, precisely because DL4J did not support per-example weights (and I'm also unaware of any support for per-label weighting). If per-label weighting is possible, it should be reasonably easy to subclass the NeuralNetwork class to add a hook that configures the DL4J graph to use your chosen weighting before training begins.

A...less good...option is of course to duplicate examples according to their desired weights; e.g. repeat an example 5 times if you want it to have a weight of 5 (and if you shuffle the duplicates into the rest of the data it's at least somewhat less wasteful and more akin to making multiple passes over the data as you might do anyway.)
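A minimal sketch of the duplication workaround, assuming integer weights and a generic list of examples (the helper name is hypothetical, not a Dagli API): each example is repeated according to its weight, and the result is shuffled so the copies are spread through the data.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

class DuplicateByWeight {
    // Hypothetical helper: repeats examples.get(i) weights[i] times, then shuffles the result
    // with a seeded Random so the copies are interleaved with the rest of the data.
    static <T> List<T> duplicateAndShuffle(List<T> examples, int[] weights, long seed) {
        if (examples.size() != weights.length) {
            throw new IllegalArgumentException("expected one weight per example");
        }
        List<T> result = new ArrayList<>();
        for (int i = 0; i < examples.size(); i++) {
            if (weights[i] < 0) {
                throw new IllegalArgumentException("weights must be non-negative");
            }
            for (int k = 0; k < weights[i]; k++) {
                result.add(examples.get(i));
            }
        }
        Collections.shuffle(result, new Random(seed));
        return result;
    }

    public static void main(String[] args) {
        List<String> out = duplicateAndShuffle(List.of("a", "b"), new int[] {5, 1}, 42L);
        // 6 examples in total, 5 of which are copies of "a".
        System.out.println(out.size() + " " + Collections.frequency(out, "a"));
    }
}
```

This only handles integer weights; fractional weights would need rounding or sampling, which is one reason the approach is "less good".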

cyberbeat commented 2 years ago

See here: https://github.com/linkedin/dagli/blob/13ebe37f13535e706f1f5ae128a4c93b8bbf7150/nn-dl4j/src/main/java/com/linkedin/dagli/dl4j/LossFunctionConverterVisitor.java#L18

where you use "LossMCXENT":

https://github.com/eclipse/deeplearning4j/blob/fc735d30023981ebbb0fafa55ea9520ec44292e0/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java

Perhaps you could add an option to pass a weights array instead of a single double? You're converting the double to an array later anyway (see your code above).

By the way, duplicating examples is difficult with multilabel data.

jeffpasternack commented 2 years ago

Thanks--so per-label weightings are certainly possible then. Unfortunately, this wouldn't be quite so trivial to implement in Dagli because we'd need to communicate the mapping of labels to indices to the weights(...) method, in addition to applying the right logic for, e.g. binary problems, making sure it works when there are multiple loss "layers", and investigating to make sure that DL4J honors the label weights for other loss functions, too. It's doable, but DL4J support for per-example weights would greatly simplify things.
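The label-to-index problem can be illustrated with a small sketch: a per-output weights vector is only meaningful once each label's position in the network output is known. The helper below is hypothetical and assumes a label-to-index map whose indices run contiguously from 0; labels without an explicit weight default to 1.0.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class LabelWeights {
    // Hypothetical helper: position i of the result holds the weight of the label mapped to index i.
    // Assumes labelToIndex maps its labels onto 0..size-1; labels missing from labelWeights get 1.0.
    static double[] weightVector(Map<String, Integer> labelToIndex, Map<String, Double> labelWeights) {
        double[] weights = new double[labelToIndex.size()];
        Arrays.fill(weights, 1.0);
        for (Map.Entry<String, Integer> e : labelToIndex.entrySet()) {
            Double w = labelWeights.get(e.getKey());
            if (w != null) {
                weights[e.getValue()] = w;
            }
        }
        return weights;
    }

    public static void main(String[] args) {
        Map<String, Integer> labelToIndex = new LinkedHashMap<>();
        List<String> labels = List.of("sports", "politics", "rare-topic");
        for (int i = 0; i < labels.size(); i++) {
            labelToIndex.put(labels.get(i), i);
        }
        double[] w = weightVector(labelToIndex, Map.of("rare-topic", 10.0));
        System.out.println(Arrays.toString(w)); // [1.0, 1.0, 10.0]
    }
}
```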

jeffpasternack commented 2 years ago

Incidentally, another (inconvenient) workaround would be to rephrase your multilabel problem by using a different binary NNClassification corresponding to each label, which you could then weight as desired. It'd be inconvenient because you'd have to convert your label sets to sequences of booleans, but it should work.
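A sketch of the conversion this workaround requires, assuming string labels and a fixed, known list of all labels (names hypothetical): each example's label set becomes one boolean per label, and column j then supplies the binary label for the j-th classifier.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

class LabelSetsToBooleans {
    // Hypothetical helper: for each label set, produce a boolean[] in the order of allLabels,
    // where entry j is true if allLabels.get(j) is in that example's label set.
    static List<boolean[]> toBooleans(List<Set<String>> labelSets, List<String> allLabels) {
        List<boolean[]> rows = new ArrayList<>();
        for (Set<String> labelSet : labelSets) {
            boolean[] row = new boolean[allLabels.size()];
            for (int j = 0; j < allLabels.size(); j++) {
                row[j] = labelSet.contains(allLabels.get(j));
            }
            rows.add(row);
        }
        return rows;
    }

    public static void main(String[] args) {
        List<boolean[]> rows = toBooleans(
            List.of(Set.of("a", "c"), Set.of()), List.of("a", "b", "c"));
        for (boolean[] row : rows) {
            System.out.println(java.util.Arrays.toString(row));
        }
    }
}
```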

cyberbeat commented 2 years ago

I now get an exception when using a weight:

Exception in thread "main" java.lang.RuntimeException: MultithreadedDAGExecutor terminated execution because it encountered an unexpected exception in a worker thread: java.lang.IllegalArgumentException: Weights array must be a row vector
    at com.linkedin.dagli.dag.MultithreadedDAGExecutor.executeUnsafe(MultithreadedDAGExecutor.java:1546)
    at com.linkedin.dagli.dag.MultithreadedDAGExecutor.prepareUnsafeImpl(MultithreadedDAGExecutor.java:1497)
    at com.linkedin.dagli.dag.LocalDAGExecutor.prepareUnsafeImpl(LocalDAGExecutor.java:71)
    at com.linkedin.dagli.dag.AbstractDAGExecutor.prepareUnsafe(AbstractDAGExecutor.java:99)
    at com.linkedin.dagli.dag.DAG1x5.prepare(DAG1x5.java:271)
...
Caused by: java.lang.IllegalArgumentException: Weights array must be a row vector
    at org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT.<init>(LossBinaryXENT.java:100)
    at org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT.<init>(LossBinaryXENT.java:72)
    at com.linkedin.dagli.dl4j.LossFunctionConverterVisitor.visit(LossFunctionConverterVisitor.java:50)
    at com.linkedin.dagli.dl4j.LossFunctionConverterVisitor.visit(LossFunctionConverterVisitor.java:18)
    at com.linkedin.dagli.nn.loss.BinaryCrossEntropyLoss.accept(BinaryCrossEntropyLoss.java:105)
    at com.linkedin.dagli.dl4j.NetworkBuilderLayerVisitor.visit(NetworkBuilderLayerVisitor.java:113)
    at com.linkedin.dagli.dl4j.NetworkBuilderLayerVisitor.visit(NetworkBuilderLayerVisitor.java:76)
    at com.linkedin.dagli.nn.layer.NNClassification.accept(NNClassification.java:192)
    at com.linkedin.dagli.dl4j.NeuralNetwork$Preparer.lambda$initialize$2(NeuralNetwork.java:156)
    at java.base/java.util.ArrayList.forEach(ArrayList.java:1511)
    at com.linkedin.dagli.dl4j.NeuralNetwork$Preparer.initialize(NeuralNetwork.java:156)
    at com.linkedin.dagli.nn.AbstractNeuralNetwork$Preparer.processUnsafe(AbstractNeuralNetwork.java:1211)
    at com.linkedin.dagli.dag.MultithreadedDAGExecutor$PreparationTask.onRun(MultithreadedDAGExecutor.java:765)
    at com.linkedin.dagli.dag.MultithreadedDAGExecutor$Task.run(MultithreadedDAGExecutor.java:368)
    at com.linkedin.dagli.dag.MultithreadedDAGExecutor$Scheduler.lambda$schedule$4(MultithreadedDAGExecutor.java:329)
    at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1130)
    at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:630)
    at java.base/java.lang.Thread.run(Thread.java:831)

Maybe your weights method is wrong?

  /**
   * @param lossFunction the loss function whose weight {@link INDArray} should be returned
   * @return the weight INDArray corresponding to a loss function, or null if it is not required (weight == 1)
   */
  private INDArray weights(LossFunction lossFunction) {
    return lossFunction.getWeight() == 1 ? null : Nd4j.valueArrayOf(_inputWidth, lossFunction.getWeight());
  }
jeffpasternack commented 2 years ago

Thanks for reporting this. The weights(...) method itself does what's intended--it returns a vector of length _inputWidth (also the output width). Unfortunately, on investigation I discovered that ND4J considers a vector with a single element to not be a "row", and there's not really any workaround I can see short of reimplementing each loss function (hacks such as using a longer vector won't work because DL4J later checks that the weights vector isn't "too long").

We'll add a more informative error when the number of elements is 1 indicating that custom weights aren't supported by DL4J in this case, but that obviously won't fix your problem.
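A sketch of the kind of early check described here; the class, method name, and message are hypothetical, not Dagli's actual code. It fails with a clear message when a non-default weight is requested for an output of width 1, instead of letting ND4J raise the less helpful "row vector" error.

```java
class WeightWidthCheck {
    // Hypothetical guard: a weight of 1 needs no weights array, so only non-default
    // weights on single-element outputs are rejected.
    static void checkCustomWeightWidth(int outputWidth, double weight) {
        if (weight != 1.0 && outputWidth == 1) {
            throw new IllegalArgumentException(
                "Custom loss weights are not supported by this DL4J version when the output width is 1 "
                    + "(DL4J rejects single-element weight vectors)");
        }
    }

    public static void main(String[] args) {
        checkCustomWeightWidth(2, 5.0); // fine: width 2 yields a proper row vector
        checkCustomWeightWidth(1, 1.0); // fine: default weight, no weights array needed
        try {
            checkCustomWeightWidth(1, 5.0);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```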

One workaround that may work for you is to use withMultilabelLabelsInput(...) instead of withBinaryLabelInput(...) with your binary labels. This should cause "true" and "false" to be treated as separate labels, so the corresponding input/output width (and thus the vector length) will be 2. However, this is hacky because, in the corner case where all the labels are true, or all are false, the vector length will again be 1 and I believe you'll get the same "row vector" error as before.

jeffpasternack commented 2 years ago

(You can prevent this corner case by adding two arbitrary "dummy examples", one with all "true" labels and one with all "false", but that strikes me as even more hacky).
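A sketch of the dummy-example trick on the label side only, assuming each example's labels are a boolean[] with one entry per binary output (names hypothetical; the features for the dummy rows are omitted). After appending one all-true and one all-false row, every output sees both values, so neither label can be missing.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class DummyExamples {
    // Hypothetical helper: appends an all-true label row and an all-false label row.
    static List<boolean[]> withDummyLabelRows(List<boolean[]> labelRows, int numOutputs) {
        List<boolean[]> result = new ArrayList<>(labelRows);
        boolean[] allTrue = new boolean[numOutputs];
        Arrays.fill(allTrue, true);
        result.add(allTrue);
        result.add(new boolean[numOutputs]); // boolean arrays default to all false
        return result;
    }

    // True if every output column contains at least one true and at least one false value.
    static boolean everyOutputSeesBothValues(List<boolean[]> labelRows, int numOutputs) {
        for (int j = 0; j < numOutputs; j++) {
            boolean sawTrue = false;
            boolean sawFalse = false;
            for (boolean[] row : labelRows) {
                if (row[j]) {
                    sawTrue = true;
                } else {
                    sawFalse = true;
                }
            }
            if (!(sawTrue && sawFalse)) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        List<boolean[]> rows = List.of(new boolean[] {true, true}, new boolean[] {true, false});
        System.out.println(everyOutputSeesBothValues(rows, 2)); // false: output 0 is always true
        System.out.println(everyOutputSeesBothValues(withDummyLabelRows(rows, 2), 2)); // true
    }
}
```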

cyberbeat commented 2 years ago

Mhm, this seems to be a DL4J bug: the isRowVector check should be replaced by isRowVectorOrScalar, right?
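A simplified model of the two shape checks, written against plain shape arrays such as {1, 5}. It reproduces the behavior described in this thread, where a single-element vector is not considered a "row vector", but it is an illustration only, not ND4J's implementation.

```java
class RowVectorCheck {
    // Total number of elements for a given shape.
    static long length(long[] shape) {
        long n = 1;
        for (long d : shape) {
            n *= d;
        }
        return n;
    }

    // Simplified model: rank 1, or rank 2 with a single row, and more than one element.
    static boolean isRowVector(long[] shape) {
        boolean rowShaped = shape.length == 1 || (shape.length == 2 && shape[0] == 1);
        return rowShaped && length(shape) > 1;
    }

    // Simplified model of the proposed check: also accept a single element.
    static boolean isRowVectorOrScalar(long[] shape) {
        return isRowVector(shape) || length(shape) == 1;
    }

    public static void main(String[] args) {
        long[] single = {1, 1};
        System.out.println(isRowVector(single));         // false: rejected by the original check
        System.out.println(isRowVectorOrScalar(single)); // true: accepted by the proposed check
    }
}
```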

cyberbeat commented 2 years ago

This is now fixed in dl4j:

https://github.com/eclipse/deeplearning4j/issues/9582#issuecomment-1001112361

A weighted loss is also available in SameDiff:

https://deeplearning4j.konduit.ai/samediff/reference/operation-namespaces/loss#weightedcrossentropywithlogits

cyberbeat commented 2 years ago

I also found this:

https://community.konduit.ai/t/per-sample-weights-or-label-fractions/706/5

https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/features/customizingdl4j/lossfunctions/CustomLossUsageEx.java

jeffpasternack commented 2 years ago

Thanks for the link regarding the fix in DL4J; you might thus be able to sidestep this by building their latest code and using it in your project in place of the version Dagli uses (beta7), although, in our past experience, DL4J doesn't tend to maintain API compatibility between versions.

It's certainly also possible to implement new loss layers to overcome the bug, but unfortunately not very practical: we'd have to fork a new loss layer for every existing one, and this would ultimately become moot once we update Dagli's DL4J dependency. Hopefully the workaround of using a multinomial loss for binomial labels is sufficient until then.

cyberbeat commented 2 years ago

I've now tried using a custom neural network, but I'm having trouble feeding the inputs to the DL4J EmbeddingSequenceLayer:

new CustomNeuralNetwork()
...
.withFeaturesInputFromNumberSequence("char_sequence", new Indices<>().withMaxUniqueObjects(1000).withMinimumFrequency(10).withInput(truncatedNodeTexts), 50, DataType.INT32)

The Indices transformer produces Iterables of different sizes. How does Dagli add padding/masking to the input so that it can be fed to the EmbeddingSequenceLayer?

I got this exception:

Exception in thread "main" java.lang.RuntimeException: MultithreadedDAGExecutor terminated execution because it encountered an unexpected exception in a worker thread: java.lang.IllegalStateException: Cannot pull rows into destination array: expected destination array of shape [25000, 16] but got destination array of shape [10100, 16]
    at com.linkedin.dagli.dag.MultithreadedDAGExecutor.executeUnsafe(MultithreadedDAGExecutor.java:1546)
    at com.linkedin.dagli.dag.MultithreadedDAGExecutor.prepareUnsafeImpl(MultithreadedDAGExecutor.java:1497)
    at com.linkedin.dagli.dag.LocalDAGExecutor.prepareUnsafeImpl(LocalDAGExecutor.java:71)
    at com.linkedin.dagli.dag.AbstractDAGExecutor.prepareUnsafe(AbstractDAGExecutor.java:99)
    at com.linkedin.dagli.dag.DAG1x1.prepare(DAG1x1.java:253)
...
Caused by: java.lang.IllegalStateException: Cannot pull rows into destination array: expected destination array of shape [25000, 16] but got destination array of shape [10100, 16]
    at org.nd4j.linalg.jcublas.JCublasNDArrayFactory.pullRows(JCublasNDArrayFactory.java:505)
    at org.nd4j.linalg.factory.Nd4j.pullRows(Nd4j.java:4813)
    at org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingSequenceLayer.preOutput(EmbeddingSequenceLayer.java:159)
    at org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingSequenceLayer.activate(EmbeddingSequenceLayer.java:176)
    at org.deeplearning4j.nn.graph.vertex.impl.LayerVertex.doForward(LayerVertex.java:110)
    at org.deeplearning4j.nn.graph.ComputationGraph.ffToLayerActivationsInWS(ComputationGraph.java:2135)
    at org.deeplearning4j.nn.graph.ComputationGraph.computeGradientAndScore(ComputationGraph.java:1372)
    at org.deeplearning4j.nn.graph.ComputationGraph.computeGradientAndScore(ComputationGraph.java:1341)
    at org.deeplearning4j.optimize.solvers.BaseOptimizer.gradientAndScore(BaseOptimizer.java:174)
    at org.deeplearning4j.optimize.solvers.StochasticGradientDescent.optimize(StochasticGradientDescent.java:61)
    at org.deeplearning4j.optimize.Solver.optimize(Solver.java:52)
    at org.deeplearning4j.nn.graph.ComputationGraph.fitHelper(ComputationGraph.java:1165)
    at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1115)
    at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1082)
    at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1044)
    at com.linkedin.dagli.dl4j.AbstractCustomNeuralNetwork$AbstractPreparer.train(AbstractCustomNeuralNetwork.java:744)
    at com.linkedin.dagli.dl4j.AbstractCustomNeuralNetwork$AbstractPreparer.finishUnsafe(AbstractCustomNeuralNetwork.java:762)
    at com.linkedin.dagli.dag.MultithreadedDAGExecutor$PreparationFinishTask.onRun(MultithreadedDAGExecutor.java:792)
    at com.linkedin.dagli.dag.MultithreadedDAGExecutor$Task.run(MultithreadedDAGExecutor.java:368)
    at com.linkedin.dagli.dag.MultithreadedDAGExecutor$Scheduler.lambda$schedule$4(MultithreadedDAGExecutor.java:329)
    at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1130)
    at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:630)
    at java.base/java.lang.Thread.run(Thread.java:831)
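On the padding question: a common way to feed variable-length index sequences into a sequence embedding is to pad (and truncate) them to a fixed length and keep a mask marking the real positions. The sketch below shows that general technique in plain Java. It is not how Dagli implements withFeaturesInputFromNumberSequence, and the padding index of 0 and the class names are assumptions.

```java
import java.util.List;

class PadAndMask {
    // Fixed-size index matrix plus a mask (1 = real token, 0 = padding).
    static final class Padded {
        final int[][] indices;
        final float[][] mask;

        Padded(int[][] indices, float[][] mask) {
            this.indices = indices;
            this.mask = mask;
        }
    }

    // Pads shorter sequences with padIndex and truncates longer ones to maxLength.
    static Padded padAndMask(List<List<Integer>> sequences, int maxLength, int padIndex) {
        int n = sequences.size();
        int[][] indices = new int[n][maxLength];
        float[][] mask = new float[n][maxLength]; // initialized to 0 (padding)
        for (int i = 0; i < n; i++) {
            List<Integer> seq = sequences.get(i);
            for (int t = 0; t < maxLength; t++) {
                if (t < seq.size()) {
                    indices[i][t] = seq.get(t);
                    mask[i][t] = 1f;
                } else {
                    indices[i][t] = padIndex;
                }
            }
        }
        return new Padded(indices, mask);
    }

    public static void main(String[] args) {
        Padded p = padAndMask(List.of(List.of(3, 4, 5), List.of(7)), 2, 0);
        System.out.println(java.util.Arrays.deepToString(p.indices)); // [[3, 4], [7, 0]]
        System.out.println(java.util.Arrays.deepToString(p.mask));    // [[1.0, 1.0], [1.0, 0.0]]
    }
}
```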
cyberbeat commented 2 years ago

I think I found a workaround in your example:

    // Create the neural network from our computation graph, specifying its feature and label inputs.
    //
    // In reality, the data type we use for the tokenIndices input *should* be INT32, not INT64.  However, when INT32 is
    // used, the current version of DL4J has a bug that causes data to be read past the end of the final minibatch in an
    // epoch when that minibatch has fewer examples than others.  The workaround we adopt is to use INT64, which seems
    // to force DL4J to create a (properly sized) copy of the original data.  This is obviously less efficient than
    // ideal but should have minimal practical impact.
    CustomNeuralNetwork neuralNetwork = new CustomNeuralNetwork().withComputationGraph(graph)
        .withFeaturesInputFromNumberSequence("tokenIndices", tokenIndices, maxTokenLength, DataType.INT64)
        .withLabelInputFromVector("classification", new ManyHotVector().withInputs(labelIndex), maxLabels,
            DataType.FLOAT)
        .withMaxEpochs(50)
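The comment in this example hinges on the final minibatch of an epoch being smaller than the others whenever the number of examples is not a multiple of the minibatch size. A small sketch of that arithmetic (the helper name is hypothetical):

```java
import java.util.Arrays;

class MinibatchSizes {
    // Sizes of the minibatches in one epoch: full batches, then a smaller remainder batch if any.
    static int[] minibatchSizes(int numExamples, int batchSize) {
        int full = numExamples / batchSize;
        int remainder = numExamples % batchSize;
        int[] sizes = new int[full + (remainder > 0 ? 1 : 0)];
        Arrays.fill(sizes, 0, full, batchSize);
        if (remainder > 0) {
            sizes[full] = remainder;
        }
        return sizes;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(minibatchSizes(10, 4))); // [4, 4, 2]
        System.out.println(Arrays.toString(minibatchSizes(8, 4)));  // [4, 4]
    }
}
```

For an EmbeddingSequenceLayer, the number of embedding rows pulled for a minibatch should be the minibatch size times the sequence length. The shapes in the earlier exception fit this pattern: with the sequence length of 50 from the earlier withFeaturesInputFromNumberSequence call, 25000 rows would correspond to 500 examples and 10100 rows to 202. The minibatch size of 500 is an inference from those numbers, not something stated in the thread.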

You wrote about a bug in DL4J. Is there a DL4J bug report about it? I'm using DL4J M.1.1, and the bug still seems to be there.

jeffpasternack commented 2 years ago

This workaround is from quite a while ago, but I assume we discovered this issue ourselves. Reporting this bug to the DL4J team would have been the nice thing to do, but we did not do so.