Waikato / wekaDeeplearning4j

Weka package for the Deeplearning4j java library
https://deeplearning.cms.waikato.ac.nz/
GNU General Public License v3.0
185 stars 202 forks source link

EfficientNet model return error #71

Open jSaso opened 2 years ago

jSaso commented 2 years ago

Issue Description

Minimal example with dummy data that reproduces the error:

package com.tim4it.ai;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.zip.Adler32;

/**
 * Minimal reproduction for wekaDeeplearning4j issue #71: fine-tuning the Keras-imported
 * EfficientNetB0 zoo model on synthetic data.
 *
 * <p>The pretrained graph is downloaded (and cached in {@code java.io.tmpdir}). Its
 * {@value #LAST_LAYER_NAME} vertex is replaced by a fresh softmax output layer with
 * {@value #OUTPUT_COUNT} classes. The resulting graph is then fitted for {@value #EPOCHS}
 * epochs on {@value #INPUT_COUNT} synthetic 224x224x3 examples.
 */
@Slf4j
public class EfficientNetExample {

    static final String MODEL_NAME = "KerasEfficientNetB0.zip";
    static final String MODEL_URL = "https://github.com/Waikato/wekaDeeplearning4j/releases/download/zoo-models/" + MODEL_NAME;
    /** Expected Adler-32 checksum of the downloaded model archive. */
    static final long MODEL_CHECK_SUM = 3915144300L;

    /** Vertex of the pretrained graph that is removed and replaced. */
    static final String LAST_LAYER_NAME = "probs";
    /** Vertex whose output feeds the new output layer. */
    static final String LAST_EXTRACTION_LAYER = "top_dropout";
    /** Name of the newly added output layer. */
    static final String OUTPUT_LAYER = "out";

    /** Number of inputs to the new output layer (its {@code nIn}). */
    static final int LAST_LAYER_OUT = 1280;
    /** Number of classes predicted by the new output layer. */
    static final int OUTPUT_COUNT = 10;
    /** Number of synthetic training examples. */
    static final int INPUT_COUNT = 20;
    /** Number of training epochs (kept separate from the example count). */
    static final int EPOCHS = 20;

    private static final long SEED = 221342347234L;
    private static final int IMAGE_SIZE = 224;
    private static final int IMAGE_CHANNELS = 3;

    /**
     * Builds the model and the synthetic data, then trains for {@link #EPOCHS} epochs.
     *
     * @param args ignored
     */
    public static void main(String... args) {
        var model = modelCreate();
        var examples = IntStream.range(0, INPUT_COUNT)
                .mapToObj(it -> new MultiDataSet(
                        new INDArray[]{modelInput(it)},
                        new INDArray[]{modelOutput(it)}))
                .collect(Collectors.toUnmodifiableList());
        // All examples are merged into one MultiDataSet; the iterator yields it as a single batch.
        var mergeDataIterator = new SingletonMultiDataSetIterator(MultiDataSet.merge(examples));
        for (int epoch = 0; epoch < EPOCHS; epoch++) {
            mergeDataIterator.reset();
            model.fit(mergeDataIterator);
            log.info("Epoch {}", epoch);
        }
    }

    /**
     * Loads the pretrained EfficientNetB0 graph and swaps its {@value #LAST_LAYER_NAME} vertex
     * for a new softmax {@link OutputLayer} fed by {@value #LAST_EXTRACTION_LAYER}.
     *
     * <p>NOTE(review): the failure reported in the issue is raised from {@code Dropout.backprop}
     * inside a {@code ConvolutionLayer}. There the dropout multiply sees one operand as
     * [N, C, H, W] and the other as [N, H, W, C]. That looks like a DL4J-side NCHW/NHWC mismatch
     * in dropout backprop, not something this configuration causes. It is not fixed here; confirm
     * against the DL4J version in use.
     *
     * @return the transfer-learning graph ready for training
     * @throws IllegalStateException if the pretrained model cannot be loaded
     */
    public static ComputationGraph modelCreate() {
        var efficientNetB0 = modelLoad();

        var fineTune = new FineTuneConfiguration.Builder()
                .seed(SEED)
                .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                .build();

        // Scale the init std-dev by this layer's actual fan-in + fan-out. The previous
        // hard-coded 4096 did not match LAST_LAYER_OUT (this layer's nIn) and appears
        // to have been copied from a different architecture's example.
        double initStd = 0.2 * (2.0 / (LAST_LAYER_OUT + OUTPUT_COUNT));

        return new TransferLearning.GraphBuilder(efficientNetB0)
                .fineTuneConfiguration(fineTune)
                .removeVertexKeepConnections(LAST_LAYER_NAME)
                .addLayer(OUTPUT_LAYER,
                        new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                .nIn(LAST_LAYER_OUT).nOut(OUTPUT_COUNT)
                                .weightInit(new NormalDistribution(0, initStd))
                                .activation(Activation.SOFTMAX).build(),
                        LAST_EXTRACTION_LAYER)
                .setOutputs(OUTPUT_LAYER)
                .build();
    }

    /**
     * Creates one synthetic image of shape [1, 224, 224, 3].
     *
     * <p>Even flat buffer positions hold {@code (15 + number) / 255} and odd positions hold
     * {@code (15 + number + 22) / 255}.
     *
     * @param number example index used to vary the pixel values
     * @return the features array after permuting to [1, H, W, C]
     */
    public static INDArray modelInput(int number) {
        int len = IMAGE_CHANNELS * IMAGE_SIZE * IMAGE_SIZE;
        // Loop-invariant: only the parity of the index selects the value.
        float evenValue = (15 + number) / 255f;
        float oddValue = (15 + number + 22) / 255f;
        var result = new float[len];
        for (int i = 0; i < len; i++) {
            result[i] = (i % 2 == 0) ? evenValue : oddValue;
        }
        // The buffer is interpreted as [1, C, H, W], then permuted to [1, H, W, C] (NHWC).
        // Presumably this is the layout the Keras-imported model expects — TODO confirm.
        var features = Nd4j.create(result, 1, IMAGE_CHANNELS, IMAGE_SIZE, IMAGE_SIZE);
        return features.permute(0, 2, 3, 1);
    }

    /**
     * Creates a one-hot label row of length {@link #OUTPUT_COUNT} for the given example.
     *
     * @param number non-negative example index
     * @return a [1, OUTPUT_COUNT] array with a single 1 at {@link #labelIndex(int)}
     * @throws IllegalArgumentException if {@code number} is negative
     */
    public static INDArray modelOutput(int number) {
        var labels = Nd4j.zeros(1, OUTPUT_COUNT);
        return labels.putScalar(labelIndex(number), 1);
    }

    /**
     * Maps an example index to a class index in {@code [0, OUTPUT_COUNT)}.
     *
     * <p>Indices count up through the classes, then back down (0..9, 9..0 for 10 classes), and
     * the pattern repeats. This matches the original mapping for the current constants, but
     * never goes out of range when more examples than {@code 2 * OUTPUT_COUNT} are requested.
     */
    private static int labelIndex(int number) {
        if (number < 0) {
            throw new IllegalArgumentException("Example index must be non-negative: " + number);
        }
        int period = 2 * OUTPUT_COUNT;
        int pos = number % period;
        return pos < OUTPUT_COUNT ? pos : period - 1 - pos;
    }

    /**
     * Returns the pretrained EfficientNetB0 graph, downloading it only when no intact cached
     * copy exists in {@code java.io.tmpdir}.
     *
     * @return the restored pretrained graph
     * @throws IllegalStateException if the download, checksum verification, or restore fails
     */
    public static ComputationGraph modelLoad() {
        var cachedFile = new File(System.getProperty("java.io.tmpdir"), MODEL_NAME);
        try {
            // Reuse a previous download when it is present and intact; otherwise (re)download.
            if (!cachedFile.isFile() || checksumOf(cachedFile) != MODEL_CHECK_SUM) {
                FileUtils.copyURLToFile(new URL(MODEL_URL), cachedFile);
                if (checksumOf(cachedFile) != MODEL_CHECK_SUM) {
                    throw new IllegalStateException("Pretrained model file for model " + MODEL_NAME + " failed checksum!");
                }
            }
            return ModelSerializer.restoreComputationGraph(cachedFile);
        } catch (IOException ex) {
            throw new IllegalStateException("Error loading model!", ex);
        }
    }

    /** Computes the Adler-32 checksum of a file's contents. */
    private static long checksumOf(File file) throws IOException {
        var adler = new Adler32();
        FileUtils.checksum(file, adler);
        return adler.getValue();
    }
}

For the EfficientNet model to load, copy the CustomBroadcast.java class (or use the Weka class) with its package into your project.

Error

The error most likely comes from the conversion between the NCHW and NHWC data formats, i.e. a mismatch in NDArray dimension order.

20:26:18.315 [main] INFO  - Loaded [CpuBackend] backend
20:26:18.681 [main] INFO  - Number of threads used for linear algebra: 8
20:26:18.683 [main] INFO  - Binary level AVX/AVX2 optimization level AVX/AVX2
20:26:18.687 [main] INFO  - Number of threads used for OpenMP BLAS: 8
20:26:18.690 [main] INFO  - Backend used: [CPU]; OS: [Linux]
20:26:18.690 [main] INFO  - Cores: [16]; Memory: [30.0GB];
20:26:18.690 [main] INFO  - Blas vendor: [OPENBLAS]
20:26:18.692 [main] INFO  - Backend build information:
 GCC: "7.5.0"
STD version: 201103L
DEFAULT_ENGINE: samediff::ENGINE_CPU
HAVE_FLATBUFFERS
HAVE_MKLDNN
HAVE_OPENBLAS
20:26:19.026 [main] INFO  - Starting ComputationGraph with WorkspaceModes set to [training: ENABLED; inference: ENABLED], cacheMode set to [NONE]
Error at [/home/runner/work/deeplearning4j/deeplearning4j/libnd4j/include/ops/declarable/generic/broadcastable/multiply.cpp:41:0]:
MULTIPLY OP: the shapes of x [20, 320, 7, 7] and y [20, 7, 7, 320] are not suitable for broadcast !
20:26:21.165 [main] ERROR - Failed to execute op multiply. Attempted to execute with 2 inputs, 1 outputs, 0 targs,0 bargs and 0 iargs. Inputs: [(FLOAT,[20,320,7,7],c), (FLOAT,[20,7,7,320],c)]. Outputs: [(FLOAT,[20,320,7,7],c)]. tArgs: -. iArgs: -. bArgs: -. Op own name: "70e10b8b-4b30-4d9b-af46-b18a19daca03" - Please see above message (printed out from c++) for a possible cause of error.
Exception in thread "main" java.lang.RuntimeException: Op [multiply] execution failed
    at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.exec(NativeOpExecutioner.java:1583)
    at org.deeplearning4j.nn.conf.dropout.Dropout.backprop(Dropout.java:202)
    at org.deeplearning4j.nn.layers.AbstractLayer.backpropDropOutIfPresent(AbstractLayer.java:307)
    at org.deeplearning4j.nn.layers.convolution.ConvolutionLayer.backpropGradient(ConvolutionLayer.java:224)
    at org.deeplearning4j.nn.graph.vertex.impl.LayerVertex.doBackward(LayerVertex.java:148)
    at org.deeplearning4j.nn.graph.ComputationGraph.calcBackpropGradients(ComputationGraph.java:2772)
    at org.deeplearning4j.nn.graph.ComputationGraph.computeGradientAndScore(ComputationGraph.java:1381)
    at org.deeplearning4j.nn.graph.ComputationGraph.computeGradientAndScore(ComputationGraph.java:1341)
    at org.deeplearning4j.optimize.solvers.BaseOptimizer.gradientAndScore(BaseOptimizer.java:174)
    at org.deeplearning4j.optimize.solvers.StochasticGradientDescent.optimize(StochasticGradientDescent.java:61)
    at org.deeplearning4j.optimize.Solver.optimize(Solver.java:52)
    at org.deeplearning4j.nn.graph.ComputationGraph.fitHelper(ComputationGraph.java:1165)
    at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1115)
    at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1082)
    at com.tim4it.ai.EfficientNetExample.main(EfficientNetExample.java:51)
Caused by: java.lang.RuntimeException: Op validation failed
    at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.exec(NativeOpExecutioner.java:1918)
    at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.exec(NativeOpExecutioner.java:1562)
    ... 14 more

Versions: