deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

java.lang.UnsupportedOperationException: Dimension mismatch or high dimensional dot operation is not supported. Please use .matMul instead. #2685

Closed: juliangamble closed this issue 1 year ago

juliangamble commented 1 year ago

Description

Running exercise 3.2 from the Dive into Deep Learning Java book (d2l.djl.ai) on an M1 Mac with Java 11 and the PyTorch engine leads to the exception: java.lang.UnsupportedOperationException: Dimension mismatch or high dimensional dot operation is not supported. Please use .matMul instead.

Expected Behavior

The code completes and shows a scatter plot of the synthetic data.

Error Message

[main] INFO ai.djl.pytorch.engine.PtEngine - PyTorch graph executor optimizer is enabled, this may impact your inference latency and throughput. See: https://docs.djl.ai/docs/development/inference_performance_optimization.html#graph-executor-optimization
[main] INFO ai.djl.pytorch.engine.PtEngine - Number of inter-op threads is 8
[main] INFO ai.djl.pytorch.engine.PtEngine - Number of intra-op threads is 8
Exception in thread "main" java.lang.UnsupportedOperationException: Dimension mismatch or high dimensional dot operation is not supported. Please use .matMul instead.
    at ai.djl.pytorch.engine.PtNDArray.dot(PtNDArray.java:1353)
    at ai.djl.pytorch.engine.PtNDArray.dot(PtNDArray.java:39)
    at org.example.Exercise3_2_LinearRegressionFromScratch.syntheticData(Exercise3_2_LinearRegressionFromScratch.java:45)
    at org.example.Exercise3_2_LinearRegressionFromScratch.main(Exercise3_2_LinearRegressionFromScratch.java:22)

Process finished with exit code 1

How to Reproduce?

(If you developed your own code, please provide a short script that reproduces the error. For existing examples, please provide a link.)

Maven pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>DeepLearningJavaExercises</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>11</maven.compiler.source>
        <maven.compiler.target>11</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>

    <dependencies>
        <!-- https://mvnrepository.com/artifact/ai.djl/api -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.22.1</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/ai.djl.pytorch/pytorch-engine -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>0.22.1</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
            <version>0.22.1</version>
        </dependency>

        <!-- https://mvnrepository.com/artifact/org.apache.commons/commons-lang3 -->
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-lang3</artifactId>
            <version>3.12.0</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/tech.tablesaw/tablesaw-core -->
        <dependency>
            <groupId>tech.tablesaw</groupId>
            <artifactId>tablesaw-core</artifactId>
            <version>0.43.1</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/tech.tablesaw/tablesaw-jsplot -->
        <dependency>
            <groupId>tech.tablesaw</groupId>
            <artifactId>tablesaw-jsplot</artifactId>
            <version>0.43.1</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>2.0.7</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-simple</artifactId>
            <version>2.0.7</version>
        </dependency>

    </dependencies>

</project>

Java class Exercise3_2_LinearRegressionFromScratch

package org.example;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import tech.tablesaw.api.FloatColumn;
import tech.tablesaw.api.Table;
import tech.tablesaw.plotly.Plot;
import tech.tablesaw.plotly.api.ScatterPlot;
import tech.tablesaw.plotly.components.Figure;

public class Exercise3_2_LinearRegressionFromScratch {
    public static void main(String[] args) {
        NDManager manager = NDManager.newBaseManager();

        NDArray trueW = manager.create(new float[]{2, -3.4f});
        float trueB = 4.2f;

        DataPoints dp = syntheticData(manager, trueW, trueB, 1000);
        NDArray features = dp.getX();
        NDArray labels = dp.getY();

        System.out.printf("features: [%f, %f]\n", features.get(0).getFloat(0), features.get(0).getFloat(1));
        System.out.println("label: " + labels.getFloat(0));

        float[] X = features.get(new NDIndex(":, 1")).toFloatArray();
        float[] y = labels.toFloatArray();

        Table data = Table.create("Data")
                .addColumns(
                        FloatColumn.create("X", X),
                        FloatColumn.create("y", y)
                );

        Figure figure = ScatterPlot.create("Synthetic Data", data, "X", "y");
        Plot.show(figure);
    }

    // Generate y = X w + b + noise
    public static DataPoints syntheticData(NDManager manager, NDArray w, float b, int numExamples) {
        NDArray X = manager.randomNormal(new Shape(numExamples, w.size()));
        NDArray y = X.dot(w).add(b);
        //java.lang.UnsupportedOperationException: Dimension mismatch or high dimensional dot operation is not supported. Please use .matMul instead.
        // Add noise
        y = y.add(manager.randomNormal(0, 0.01f, y.getShape(), DataType.FLOAT32));
        return new DataPoints(X, y);
    }
}
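
Note: the example also references a DataPoints class that is not shown above. It is the small holder class from the d2l-java utilities; a minimal sketch that would make the snippet compile (the exact field names here are an assumption) is:

package org.example;

import ai.djl.ndarray.NDArray;

// Minimal holder for the generated features and labels, matching the
// getX()/getY() calls in the example. Field names are illustrative only.
public class DataPoints {
    private final NDArray x;
    private final NDArray y;

    public DataPoints(NDArray x, NDArray y) {
        this.x = x;
        this.y = y;
    }

    public NDArray getX() {
        return x;
    }

    public NDArray getY() {
        return y;
    }
}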

Steps to reproduce

(Paste the commands you ran that produced the error.)

  1. export DJL_DEFAULT_ENGINE=PyTorch (via IntelliJ run configuration)
  2. Run the class Exercise3_2_LinearRegressionFromScratch

What have you tried to solve it?

  1. Check the environment variables.
  2. Check the code in the book at 3.2 Implement Linear Regression From Scratch https://d2l.djl.ai/chapter_linear-networks/linear-regression-scratch.html

Environment Info

Please run the command ./gradlew debugEnv from the root directory of DJL (if necessary, clone DJL first). It will output information about your system, environment, and installation that can help us debug your issue. Paste the output of the command below:

Via IntelliJ configuration.

DJL_DEFAULT_ENGINE=PyTorch
JAVA_HOME=/Library/Java/JavaVirtualMachines/temurin-11.jdk/Contents/Home
KexinFeng commented 1 year ago

Did this specifically happen at this computation?

NDArray y = X.dot(w).add(b);

Have you tried the matMul as suggested in the error message?

I think I spot one possible error that causes a dimension mismatch when doing X.dot(w):

    // Generate y = X w + b + noise
    public static DataPoints syntheticData(NDManager manager, NDArray w, float b, int numExamples) {
        NDArray X = manager.randomNormal(new Shape(numExamples, w.size()));
        NDArray y = X.dot(w).add(b);
        //java.lang.UnsupportedOperationException: Dimension mismatch or high dimensional dot operation is not supported. Please use .matMul instead.
        // Add noise
        y = y.add(manager.randomNormal(0, 0.01f, y.getShape(), DataType.FLOAT32));
        return new DataPoints(X, y);
    }

Here, in NDArray X = manager.randomNormal(new Shape(numExamples, w.size()));, change Shape(numExamples, w.size()) to Shape(numExamples, w.getShape().get(0)). Then the dimensions would match in the X.dot(w) computation.

juliangamble commented 1 year ago

hi @KexinFeng,

Thanks for your response.

Did this specifically happen at this computation?

Yes.

Have you tried the matMul as suggested in the error message?

There is not enough information for me to go ahead with this. But what I'm using is based on the textbook here: https://d2l.djl.ai/chapter_linear-networks/linear-regression-scratch.html#generating-the-dataset, so I'm interested in the reason to do it differently.

I think I spot one possible error that causes a dimension mismatch when doing X.dot(w). Here, in NDArray X = manager.randomNormal(new Shape(numExamples, w.size()));, change Shape(numExamples, w.size()) to Shape(numExamples, w.getShape().get(0)). Then the dimensions would match in the X.dot(w) computation.

Here is what I changed it to:

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import tech.tablesaw.api.FloatColumn;
import tech.tablesaw.api.Table;
import tech.tablesaw.plotly.Plot;
import tech.tablesaw.plotly.api.ScatterPlot;
import tech.tablesaw.plotly.components.Figure;

public class Exercise3_2_LinearRegressionFromScratch {
    public static void main(String[] args) {
        NDManager manager = NDManager.newBaseManager();

        NDArray trueW = manager.create(new float[]{2, -3.4f});
        float trueB = 4.2f;

        DataPoints dp = syntheticData(manager, trueW, trueB, 1000);
        NDArray features = dp.getX();
        NDArray labels = dp.getY();

        System.out.printf("features: [%f, %f]\n", features.get(0).getFloat(0), features.get(0).getFloat(1));
        System.out.println("label: " + labels.getFloat(0));

        float[] X = features.get(new NDIndex(":, 1")).toFloatArray();
        float[] y = labels.toFloatArray();

        Table data = Table.create("Data")
                .addColumns(
                        FloatColumn.create("X", X),
                        FloatColumn.create("y", y)
                );

        Figure figure = ScatterPlot.create("Synthetic Data", data, "X", "y");
        Plot.show(figure);
    }

    // Generate y = X w + b + noise
    public static DataPoints syntheticData(NDManager manager, NDArray w, float b, int numExamples) {
        //NDArray X = manager.randomNormal(new Shape(numExamples, w.size()));
        NDArray X = manager.randomNormal(new Shape(numExamples, w.getShape().get(0)));
        NDArray y = X.dot(w).add(b);
        //java.lang.UnsupportedOperationException: Dimension mismatch or high dimensional dot operation is not supported. Please use .matMul instead.
        // Add noise
        y = y.add(manager.randomNormal(0, 0.01f, y.getShape(), DataType.FLOAT32));
        return new DataPoints(X, y);
    }
}

It still fails with the error:

[main] INFO ai.djl.pytorch.engine.PtEngine - PyTorch graph executor optimizer is enabled, this may impact your inference latency and throughput. See: https://docs.djl.ai/docs/development/inference_performance_optimization.html#graph-executor-optimization
[main] INFO ai.djl.pytorch.engine.PtEngine - Number of inter-op threads is 8
[main] INFO ai.djl.pytorch.engine.PtEngine - Number of intra-op threads is 8
Exception in thread "main" java.lang.UnsupportedOperationException: Dimension mismatch or high dimensional dot operation is not supported. Please use .matMul instead.
    at ai.djl.pytorch.engine.PtNDArray.dot(PtNDArray.java:1353)
    at ai.djl.pytorch.engine.PtNDArray.dot(PtNDArray.java:39)
    at org.example.Exercise3_2_LinearRegressionFromScratch.syntheticData(Exercise3_2_LinearRegressionFromScratch.java:46)
    at org.example.Exercise3_2_LinearRegressionFromScratch.main(Exercise3_2_LinearRegressionFromScratch.java:22)

Process finished with exit code 1

Have you got a suggestion?

KexinFeng commented 1 year ago

The reason for changing it is to make sure that in y = X w + b, where X has shape [s1, s2] and w has shape [s3, s4], s2 is the same as s3. This follows from the rule of matrix multiplication.

KexinFeng commented 1 year ago

I found the error. The dot operation only applies to 1D vectors. Please make one further change: replace .dot with .matMul.
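
Applied to the syntheticData method above, the change would look roughly like this (a sketch of the suggested fix; everything else stays the same):

    // Generate y = X w + b + noise
    public static DataPoints syntheticData(NDManager manager, NDArray w, float b, int numExamples) {
        NDArray X = manager.randomNormal(new Shape(numExamples, w.getShape().get(0)));
        // X has shape (numExamples, 2) and w has shape (2); matMul performs the
        // matrix-vector product, while dot on the PyTorch engine only supports 1D-1D.
        NDArray y = X.matMul(w).add(b);
        // Add noise
        y = y.add(manager.randomNormal(0, 0.01f, y.getShape(), DataType.FLOAT32));
        return new DataPoints(X, y);
    }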

KexinFeng commented 1 year ago

See also https://pytorch.org/docs/stable/generated/torch.dot.html

juliangamble commented 1 year ago

Thank you, that fixed it.