bytedeco / javacpp-presets

The missing Java distribution of native C++ libraries

[pytorch] [java/scala] New version removes the Module/layerImpl conversions, so a Module can no longer be cast to a layerImpl #1393

Closed mullerhai closed 1 year ago

mullerhai commented 1 year ago

Hi, when I use the new version, PyTorch 2.0.1 with Scala 2.12.10, I get a class-cast error. I rewrote the example Java code in Scala, but it no longer works, because in the new javacpp-pytorch version the Module class has dropped all of the specialized register_module overloads that returned the concrete layerImpl types. Why were they removed? The error now is:

Exception in thread "main" java.lang.ClassCastException: class org.bytedeco.pytorch.Module cannot be cast to class org.bytedeco.pytorch.LinearImpl (org.bytedeco.pytorch.Module and org.bytedeco.pytorch.LinearImpl are in unnamed module of loader 'app')
    at SimpleMNIST$Net.<init>(hell.scala:23)
    at SimpleMNIST$.main(hell.scala:52)
    at SimpleMNIST.main(hell.scala)

How can I solve this error? Do I need to import some dependency that provides method sugar in the Scala code? If I remove the asInstanceOf[LinearImpl] casts, the Scala code does not compile. Please help me, thanks. Dependencies:

ThisBuild / version := "0.1.0-SNAPSHOT"

ThisBuild / scalaVersion := "2.12.10"

lazy val root = (project in file("."))
  .settings(
    name := "torchSa"
  )

scalaVersion := "2.12.10"

//idePackagePrefix := Some("org.example")
resolvers += "Sonatype OSS Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots"

val sparkVersion = "3.1.1"

//libraryDependencies ++= Seq(
//  "org.apache.spark" %% "spark-core" % sparkVersion,
//  "org.apache.spark" %% "spark-sql" % sparkVersion,
//  "org.apache.spark" %% "spark-mllib" % sparkVersion,
//  "org.apache.spark" %% "spark-streaming" % sparkVersion
//)
// https://mvnrepository.com/artifact/org.apache.parquet/parquet-common
libraryDependencies += "org.apache.parquet" % "parquet-common" % "1.12.3"

libraryDependencies += "org.bytedeco" % "pytorch" %  "2.0.1-1.5.10-SNAPSHOT" // "1.12.1-1.5.8" // "1.10.2-1.5.7"
// https://mvnrepository.com/artifact/org.bytedeco/pytorch-platform
libraryDependencies += "org.bytedeco" % "pytorch-platform" % "2.0.1-1.5.10-SNAPSHOT"  //  "1.12.1-1.5.8" //"1.10.2-1.5.7"
//libraryDependencies += "org.bytedeco" % "pytorch-platform-gpu" %  "2.0.1-1.5.10-SNAPSHOT" // "1.12.1-1.5.8" // "1.10.2-1.5.7"
//// https://mvnrepository.com/artifact/org.bytedeco/pytorch-platform

libraryDependencies += "org.bytedeco" % "mkl-platform-redist" % "2023.1-1.5.10-SNAPSHOT"  //  "1.12.1-1.5.8" //"1.10.2-1.5.7"
//

Code (the example Java code converted to Scala):

import org.bytedeco.javacpp._
import org.bytedeco.pytorch._
import org.bytedeco.pytorch.Module
import org.bytedeco.pytorch.global.torch._
import java.io.File
import scala.collection.mutable.ListBuffer
import scala.io.Source
object SimpleMNIST { // Define a new Module.
  class Net() extends Module { // Construct and register two Linear submodules.
    //fc1 = register_module("fc1", new LinearImpl(784, 64));
    var fc1 = register_module("fc1", new LinearImpl(784, 64)).asInstanceOf[LinearImpl]
    var  fc2 = register_module("fc2", new LinearImpl(64, 32)).asInstanceOf[LinearImpl]
    var  fc3 = register_module("fc3", new LinearImpl(32, 10)).asInstanceOf[LinearImpl]

    // Implement the Net's algorithm.
    def forward(xl: Tensor): Tensor = { // Use one of many tensor manipulation functions.
      var x = xl
      x = relu(fc1.forward(x.reshape(x.size(0), 784)))
      x = dropout(x,  0.5,  is_training)
      x = relu(fc2.asInstanceOf[LinearImpl].forward(x))
      x = log_softmax(fc3.asInstanceOf[LinearImpl].forward(x),  1)
      x
    }

//     Use one of many "standard library" modules.
//    var fc1: LinearImpl = null
//    var fc2: LinearImpl = null
//    var fc3: LinearImpl = null
  }

  @throws[Exception]
  def main(args: Array[String]): Unit = {
    /* try to use MKL when available */
    System.setProperty("org.bytedeco.openblas.load", "mkl")

    // Create a multi-threaded data loader for the MNIST dataset.
    val data_set = new MNIST("./data").map(new ExampleStack)
    val data_loader = new MNISTRandomDataLoader(data_set, new RandomSampler(data_set.size.get), new DataLoaderOptions(/*batch_size=*/ 64))
    // Create a new Net.
    val net = new SimpleMNIST.Net
    // Instantiate an SGD optimization algorithm to update our Net's parameters.
    val optimizer = new SGD(net.parameters, new SGDOptions(/*lr=*/ 0.01))
    for (epoch <- 1 to 10) {
      var batch_index = 0
      // Iterate the data loader to yield batches from the dataset.
      var it = data_loader.begin
      while ( {
        !(it == data_loader.end)
      }) {
        val batch = it.access
        // Reset gradients.
        optimizer.zero_grad()
        // Execute the model on the input data.
        val prediction = net.forward(batch.data)
        // Compute a loss value to judge the prediction of our model.
        val loss = nll_loss(prediction, batch.target)
        // Compute gradients of the loss w.r.t. the parameters of our model.
        loss.backward()
        // Update the parameters based on the calculated gradients.
        optimizer.step
        // Output the loss and checkpoint every 100 batches.
        if ( {
          batch_index += 1; batch_index
        } % 100 == 0) {
          System.out.println("Epoch: " + epoch + " | Batch: " + batch_index + " | Loss: " + loss.item_float)
          // Serialize your model periodically as a checkpoint.
          val archive = new OutputArchive
          net.save(archive)
          archive.save_to("net.pt")
        }

        it = it.increment
      }
    }
  }
}
HGuillemet commented 1 year ago

The prototype of the register_module method is:

 public Module register_module(String name, Module module);

In previous versions, there were, in addition, specialized methods like:

 public LinearImpl register_module(String name, LinearImpl module);

but these were a workaround for a bug that has been fixed and they are not needed anymore. You can now write:

LinearImpl fc1 = new LinearImpl(784, 64);
register_module("fc1", fc1);
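For illustration, here is a minimal Scala sketch of how the Net class from the original post could be adapted to this pattern (same layer sizes and the same org.bytedeco.pytorch imports as above; a sketch only, not code from the presets):

class Net() extends Module {
  // Construct the submodules with their concrete types first...
  val fc1 = new LinearImpl(784, 64)
  val fc2 = new LinearImpl(64, 32)
  val fc3 = new LinearImpl(32, 10)
  // ...then register them; the generic register_module(String, Module) is enough,
  // and no asInstanceOf cast is needed because the typed references are kept.
  register_module("fc1", fc1)
  register_module("fc2", fc2)
  register_module("fc3", fc3)

  def forward(xl: Tensor): Tensor = {
    var x = relu(fc1.forward(xl.reshape(xl.size(0), 784)))
    x = dropout(x, 0.5, is_training)
    x = relu(fc2.forward(x))
    log_softmax(fc3.forward(x), 1)
  }
}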
mullerhai commented 1 year ago

The prototype of the register_module method is:

 public Module register_module(String name, Module module);

In previous versions, there were, in addition, specialized methods like:

 public LinearImpl register_module(String name, LinearImpl module);

but these were a workaround for a bug that has been fixed and they are not needed anymore. You can now write:

LinearImpl fc1 = new LinearImpl(784, 64);
register_module("fc1", fc1);

It works fine, but now I have another question. In the new version we have the SequentialImpl class, and for the standard PyTorch layers we can easily push_back them into a SequentialImpl; that works. But now I want to add a user-defined layer or model to the SequentialImpl, and that does not work. For example:

    val seqs = new SequentialImpl()
    var fc4= new LinearImpl(784, 64)
    var fc5 = new LinearImpl(64, 32)
    var fc6 = new LinearImpl(32, 10)
    seqs.push_back(fc4)
    seqs.push_back(fc5)
    seqs.push_back(fc6)

This works.

But adding a user-defined model to the SequentialImpl does not work:

  class HelenLayer() extends Module { // Construct and register two Linear submodules.

    var fc1= new LinearImpl(784, 64)
    register_module("fc1",fc1)
    var fc2 = new LinearImpl(64, 32)
    register_module("fc2",fc2)
    var fc3 = new LinearImpl(32, 10)
    register_module("fc3",fc3)

    // Implement the Net's algorithm.
    def forward(xl: Tensor): Tensor = { // Use one of many tensor manipulation functions.
      var x = xl
      x = relu(fc1.forward(x.reshape(x.size(0), 784)))
      x = dropout(x,  0.5,  is_training)
      x = relu(fc2.forward(x))
      x = log_softmax(fc3.forward(x),  1)
      x
    }

  }

val  layerSeqs = new SequentialImpl()

val helenLayer = new HelenLayer()
layerSeqs.push_back(helenLayer)  // compiler error
layerSeqs.put(helenLayer)  //  meet  jvm error

So how can I use SequentialImpl with a list of user-defined layers or models? Thanks.

mullerhai commented 1 year ago

Now I have implemented a user-defined layer using SequentialImpl, named SeqNow, modeled on the example Net model for simple MNIST. The code runs, but the loss does not decrease, and I don't know why.

  class SeqNow() extends Module {
    var seqs = new SequentialImpl()
    var fc4 = new LinearImpl(784, 64)
    var relu = new ReLUImpl()
    val dropOpt = new DropoutOptions()
    var drop = new DropoutImpl(0.5)
    var fc5 = new LinearImpl(64, 32)
    var relu2 = new ReLUImpl()
    var fc6 = new LinearImpl(32, 10)
    val log_softmax = new LogSoftmaxImpl(1)
    seqs.push_back(fc4)
    seqs.push_back(relu)
    seqs.push_back(drop)
    seqs.push_back(fc5)
    seqs.push_back(relu2)
    seqs.push_back(fc6)
    seqs.push_back(log_softmax)
    def forward(xl: Tensor): Tensor = {
      var x = xl.reshape(xl.size(0), 784)
      x = seqs.forward(x)
      x
    }
  }
  class Net() extends Module { // Construct and register two Linear submodules.

    var fc1 = new LinearImpl(784, 64)
    register_module("fc1", fc1)
    var fc2 = new LinearImpl(64, 32)
    register_module("fc2", fc2)
    var fc3 = new LinearImpl(32, 10)
    register_module("fc3", fc3)

    // Implement the Net's algorithm.
    def forward(xl: Tensor): Tensor = { // Use one of many tensor manipulation functions.
      var x = xl
      x = relu(fc1.forward(x.reshape(x.size(0), 784)))
      x = dropout(x, 0.5, is_training)
      x = relu(fc2.forward(x))
      x = log_softmax(fc3.forward(x), 1)
      x
    }
  }
  @throws[Exception]
  def main(args: Array[String]): Unit = {
    /* try to use MKL when available */
    System.setProperty("org.bytedeco.openblas.load", "mkl")
    // Create a new Net.
    val net = new SimpleMNIST.Net
    val seqs = new SequentialImpl()
    val seqNow = new SimpleMNIST.SeqNow()
    // Create a multi-threaded data loader for the MNIST dataset.
    val data_set = new MNIST("./data").map(new ExampleStack)
    val data_loader = new MNISTRandomDataLoader(data_set, new RandomSampler(data_set.size.get), new DataLoaderOptions(/*batch_size=*/ 32))

    // Instantiate an SGD optimization algorithm to update our Net's parameters.
    val optimizer = new SGD(seqNow.parameters, new SGDOptions(/*lr=*/ 0.01))
    //    val optimizer = new SGD(net.parameters, new SGDOptions(/*lr=*/ 0.01))
    for (epoch <- 1 to 10) {
      var batch_index = 0
      // Iterate the data loader to yield batches from the dataset.
      var it = data_loader.begin
      while ( {
        !it.equals(data_loader.end)
      }){
      //        while ( {
      //        !(it == data_loader.end)
      //      }) {
        val batch = it.access
        // Reset gradients.
        optimizer.zero_grad()
        // Execute the model on the input data.
        //        val prediction = net.forward(batch.data)
        val prediction = seqNow.forward(batch.data)
        // Compute a loss value to judge the prediction of our model.
        val loss = nll_loss(prediction, batch.target)
        // Compute gradients of the loss w.r.t. the parameters of our model.
        loss.backward()
        // Update the parameters based on the calculated gradients.
        optimizer.step
        // Output the loss and checkpoint every 100 batches.
        if ( {
          batch_index += 1;
          batch_index
        } % 100 == 0) {
          System.out.println("Epoch: " + epoch + " | Batch: " + batch_index + " | Loss: " + loss.item_float)
          // Serialize your model periodically as a checkpoint.
          val archive = new OutputArchive
          //          net.save(archive)
          archive.save_to("net.pt")
        }

        it = it.increment
      }
    }
  }

the console seqNow ··· /Library/Java/JavaVirtualMachines/adoptopenjdk-15.jdk/Contents/Home/bin/java -javaagent:/Applications/IntelliJ IDEA.app/Contents/lib/idea_rt.jar=59214:/Applications/IntelliJ IDEA.app/Contents/bin -Dfile.encoding=UTF-8 -classpath /Users/zhanghaining/Documents/codeWorld/untitled/target/scala-2.12/classes:/Users/zhanghaining/Library/Caches/Coursier/v1/https/repo1.maven.org/maven2/org/apache/parquet/parquet-common/1.12.3/parquet-common-1.12.3.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/repo1.maven.org/maven2/org/apache/parquet/parquet-format-structures/1.12.3/parquet-format-structures-1.12.3.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/javacpp-platform/1.5.10-SNAPSHOT/javacpp-platform-1.5.10-20230726.101042-118.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/javacpp/1.5.10-SNAPSHOT/javacpp-1.5.10-20230720.003413-82-android-arm64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/javacpp/1.5.10-SNAPSHOT/javacpp-1.5.10-20230720.003413-82-android-x86_64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/javacpp/1.5.10-SNAPSHOT/javacpp-1.5.10-20230720.003413-82-ios-arm64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/javacpp/1.5.10-SNAPSHOT/javacpp-1.5.10-20230720.003413-82-ios-x86_64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/javacpp/1.5.10-SNAPSHOT/javacpp-1.5.10-20230720.003413-82.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/javacpp/1.5.10-SNAPSHOT/javacpp-1.5.10-20230720.003413-82-linux-arm64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/javacpp/1.5.10-SNAPSHOT/javacpp-1.5.10-20230720.003413-82-linux-ppc64le.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/javacpp/1.5.10-SNAPSHOT/javacpp-1.5.10-20230720.003413-82-linux-x86_64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/javacpp/1.5.10-SNAPSHOT/javacpp-1.5.10-20230720.003413-82-macosx-arm64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/javacpp/1.5.10-SNAPSHOT/javacpp-1.5.10-20230720.003413-82-macosx-x86_64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/javacpp/1.5.10-SNAPSHOT/javacpp-1.5.10-20230720.003413-82-windows-x86_64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/mkl-platform-redist/2023.1-1.5.10-SNAPSHOT/mkl-platform-redist-2023.1-1.5.10-20230718.125906-19.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/mkl-platform/2023.1-1.5.10-SNAPSHOT/mkl-platform-2023.1-1.5.10-20230718.130143-27.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/mkl/2023.1-1.5.10-SNAPSHOT/mkl-2023.1-1.5.10-20230718.130152-23.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonat
ype.org/content/repositories/snapshots/org/bytedeco/mkl/2023.1-1.5.10-SNAPSHOT/mkl-2023.1-1.5.10-20230718.130152-23-linux-x86_64-redist.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/mkl/2023.1-1.5.10-SNAPSHOT/mkl-2023.1-1.5.10-20230718.130152-23-linux-x86_64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/mkl/2023.1-1.5.10-SNAPSHOT/mkl-2023.1-1.5.10-20230718.130152-23-macosx-x86_64-redist.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/mkl/2023.1-1.5.10-SNAPSHOT/mkl-2023.1-1.5.10-20230718.130152-23-macosx-x86_64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/mkl/2023.1-1.5.10-SNAPSHOT/mkl-2023.1-1.5.10-20230718.130152-23-windows-x86_64-redist.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/mkl/2023.1-1.5.10-SNAPSHOT/mkl-2023.1-1.5.10-20230718.130152-23-windows-x86_64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/openblas-platform/0.3.23-1.5.10-SNAPSHOT/openblas-platform-0.3.23-1.5.10-20230726.101049-27.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/openblas/0.3.23-1.5.10-SNAPSHOT/openblas-0.3.23-1.5.10-20230608.114324-11-android-arm64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/openblas/0.3.23-1.5.10-SNAPSHOT/openblas-0.3.23-1.5.10-20230608.114324-11-android-x86_64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/openblas/0.3.23-1.5.10-SNAPSHOT/openblas-0.3.23-1.5.10-20230608.114324-11-ios-arm64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/openblas/0.3.23-1.5.10-SNAPSHOT/openblas-0.3.23-1.5.10-20230608.114324-11-ios-x86_64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/openblas/0.3.23-1.5.10-SNAPSHOT/openblas-0.3.23-1.5.10-20230608.114324-11.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/openblas/0.3.23-1.5.10-SNAPSHOT/openblas-0.3.23-1.5.10-20230608.114324-11-linux-arm64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/openblas/0.3.23-1.5.10-SNAPSHOT/openblas-0.3.23-1.5.10-20230608.114324-11-linux-ppc64le.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/openblas/0.3.23-1.5.10-SNAPSHOT/openblas-0.3.23-1.5.10-20230608.114324-11-linux-x86_64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/openblas/0.3.23-1.5.10-SNAPSHOT/openblas-0.3.23-1.5.10-20230608.114324-11-macosx-arm64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/openblas/0.3.23-1.5.10-SNAPSHOT/openblas-0.3.23-1.5.10-20230608.114324-11-macosx-x86_64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/openblas/0.3.23-1.5.10-SNAPSHOT/openblas-0.3.23-1.5.10-20230
608.114324-11-windows-x86_64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/pytorch-platform/2.0.1-1.5.10-SNAPSHOT/pytorch-platform-2.0.1-1.5.10-20230725.183239-19.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/pytorch/2.0.1-1.5.10-SNAPSHOT/pytorch-2.0.1-1.5.10-20230726.101052-36.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/pytorch/2.0.1-1.5.10-SNAPSHOT/pytorch-2.0.1-1.5.10-20230726.101052-36-linux-x86_64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/pytorch/2.0.1-1.5.10-SNAPSHOT/pytorch-2.0.1-1.5.10-20230726.101052-36-macosx-x86_64.jar:/Users/zhanghaining/Library/Caches/Coursier/v1/https/oss.sonatype.org/content/repositories/snapshots/org/bytedeco/pytorch/2.0.1-1.5.10-SNAPSHOT/pytorch-2.0.1-1.5.10-20230726.101052-36-windows-x86_64.jar:/Users/zhanghaining/.m2/repository/org/scala-lang/scala-library/2.12.10/scala-library-2.12.10.jar:/Users/zhanghaining/.m2/repository/org/slf4j/slf4j-api/1.7.22/slf4j-api-1.7.22.jar SimpleMNIST Epoch: 1 | Batch: 100 | Loss: 2.311179 Epoch: 1 | Batch: 200 | Loss: 2.3133073 Epoch: 1 | Batch: 300 | Loss: 2.3120406 Epoch: 1 | Batch: 400 | Loss: 2.2905695 Epoch: 1 | Batch: 500 | Loss: 2.327793 Epoch: 1 | Batch: 600 | Loss: 2.2965953 Epoch: 1 | Batch: 700 | Loss: 2.3046 Epoch: 1 | Batch: 800 | Loss: 2.2857423 Epoch: 1 | Batch: 900 | Loss: 2.2794228 Epoch: 1 | Batch: 1000 | Loss: 2.2959578 Epoch: 1 | Batch: 1100 | Loss: 2.3031638 Epoch: 1 | Batch: 1200 | Loss: 2.253742 Epoch: 1 | Batch: 1300 | Loss: 2.2744474 Epoch: 1 | Batch: 1400 | Loss: 2.299021 Epoch: 1 | Batch: 1500 | Loss: 2.3229873 Epoch: 1 | Batch: 1600 | Loss: 2.29115 Epoch: 1 | Batch: 1700 | Loss: 2.3283224 Epoch: 1 | Batch: 1800 | Loss: 2.3274283 Epoch: 2 | Batch: 100 | Loss: 2.296998 Epoch: 2 | Batch: 200 | Loss: 2.3092 Epoch: 2 | Batch: 300 | Loss: 2.3077636 Epoch: 2 | Batch: 400 | Loss: 2.3423984 Epoch: 2 | Batch: 500 | Loss: 2.3174124 Epoch: 2 | Batch: 600 | Loss: 2.2949598 Epoch: 2 | Batch: 700 | Loss: 2.286696 Epoch: 2 | Batch: 800 | Loss: 2.3047907 Epoch: 2 | Batch: 900 | Loss: 2.3180234 Epoch: 2 | Batch: 1000 | Loss: 2.3170154 Epoch: 2 | Batch: 1100 | Loss: 2.290569 Epoch: 2 | Batch: 1200 | Loss: 2.3148155 Epoch: 2 | Batch: 1300 | Loss: 2.278207 Epoch: 2 | Batch: 1400 | Loss: 2.3145657 Epoch: 2 | Batch: 1500 | Loss: 2.2870746 Epoch: 2 | Batch: 1600 | Loss: 2.3499207 Epoch: 2 | Batch: 1700 | Loss: 2.2855344 Epoch: 2 | Batch: 1800 | Loss: 2.2798202 Epoch: 3 | Batch: 100 | Loss: 2.3229082 Epoch: 3 | Batch: 200 | Loss: 2.3037894 Epoch: 3 | Batch: 300 | Loss: 2.2876005 Epoch: 3 | Batch: 400 | Loss: 2.3212254 Epoch: 3 | Batch: 500 | Loss: 2.3015542 Epoch: 3 | Batch: 600 | Loss: 2.2900429 Epoch: 3 | Batch: 700 | Loss: 2.303884 Epoch: 3 | Batch: 800 | Loss: 2.3447232 Epoch: 3 | Batch: 900 | Loss: 2.3251038 Epoch: 3 | Batch: 1000 | Loss: 2.320521 Epoch: 3 | Batch: 1100 | Loss: 2.301875 Epoch: 3 | Batch: 1200 | Loss: 2.3062909 Epoch: 3 | Batch: 1300 | Loss: 2.3227658 Epoch: 3 | Batch: 1400 | Loss: 2.3148131 Epoch: 3 | Batch: 1500 | Loss: 2.3026087 Epoch: 3 | Batch: 1600 | Loss: 2.3217309 Epoch: 3 | Batch: 1700 | Loss: 2.3559098 Epoch: 3 | Batch: 1800 | Loss: 2.2937589 Epoch: 4 | Batch: 100 | Loss: 2.301482 Epoch: 4 | Batch: 200 | Loss: 2.3086126 Epoch: 4 | Batch: 300 | Loss: 2.298618 Epoch: 4 | Batch: 400 | Loss: 2.3058872 
Epoch: 4 | Batch: 500 | Loss: 2.2999983 Epoch: 4 | Batch: 600 | Loss: 2.3193781 Epoch: 4 | Batch: 700 | Loss: 2.295127 Epoch: 4 | Batch: 800 | Loss: 2.2815807 Epoch: 4 | Batch: 900 | Loss: 2.3085556 Epoch: 4 | Batch: 1000 | Loss: 2.3251822 Epoch: 4 | Batch: 1100 | Loss: 2.2811594 Epoch: 4 | Batch: 1200 | Loss: 2.2763584 Epoch: 4 | Batch: 1300 | Loss: 2.291853 Epoch: 4 | Batch: 1400 | Loss: 2.323418 Epoch: 4 | Batch: 1500 | Loss: 2.320117 Epoch: 4 | Batch: 1600 | Loss: 2.2972112 Epoch: 4 | Batch: 1700 | Loss: 2.2927501 Epoch: 4 | Batch: 1800 | Loss: 2.260505 Epoch: 5 | Batch: 100 | Loss: 2.3131657 Epoch: 5 | Batch: 200 | Loss: 2.309602 Epoch: 5 | Batch: 300 | Loss: 2.2837446 Epoch: 5 | Batch: 400 | Loss: 2.3252475 Epoch: 5 | Batch: 500 | Loss: 2.3113067 Epoch: 5 | Batch: 600 | Loss: 2.2943766 Epoch: 5 | Batch: 700 | Loss: 2.3224854 Epoch: 5 | Batch: 800 | Loss: 2.29428 Epoch: 5 | Batch: 900 | Loss: 2.3289096 Epoch: 5 | Batch: 1000 | Loss: 2.3024058 Epoch: 5 | Batch: 1100 | Loss: 2.3023934 Epoch: 5 | Batch: 1200 | Loss: 2.3290997 Epoch: 5 | Batch: 1300 | Loss: 2.3295288 Epoch: 5 | Batch: 1400 | Loss: 2.2765558 Epoch: 5 | Batch: 1500 | Loss: 2.2912512 Epoch: 5 | Batch: 1600 | Loss: 2.2961147 Epoch: 5 | Batch: 1700 | Loss: 2.2827473 Epoch: 5 | Batch: 1800 | Loss: 2.2663298 Epoch: 6 | Batch: 100 | Loss: 2.3150187 Epoch: 6 | Batch: 200 | Loss: 2.3091505 Epoch: 6 | Batch: 300 | Loss: 2.2821596 Epoch: 6 | Batch: 400 | Loss: 2.2877693 Epoch: 6 | Batch: 500 | Loss: 2.281046 Epoch: 6 | Batch: 600 | Loss: 2.3209 Epoch: 6 | Batch: 700 | Loss: 2.3175645 Epoch: 6 | Batch: 800 | Loss: 2.3180046 Epoch: 6 | Batch: 900 | Loss: 2.328904 Epoch: 6 | Batch: 1000 | Loss: 2.3322976 Epoch: 6 | Batch: 1100 | Loss: 2.3013334 Epoch: 6 | Batch: 1200 | Loss: 2.3073165 Epoch: 6 | Batch: 1300 | Loss: 2.3061116 Epoch: 6 | Batch: 1400 | Loss: 2.3281763 Epoch: 6 | Batch: 1500 | Loss: 2.2985666 Epoch: 6 | Batch: 1600 | Loss: 2.3172383 Epoch: 6 | Batch: 1700 | Loss: 2.2991989 Epoch: 6 | Batch: 1800 | Loss: 2.3242373 Epoch: 7 | Batch: 100 | Loss: 2.31733 Epoch: 7 | Batch: 200 | Loss: 2.305778 Epoch: 7 | Batch: 300 | Loss: 2.2901695 Epoch: 7 | Batch: 400 | Loss: 2.354087 Epoch: 7 | Batch: 500 | Loss: 2.2955165 Epoch: 7 | Batch: 600 | Loss: 2.298453 Epoch: 7 | Batch: 700 | Loss: 2.3135612 Epoch: 7 | Batch: 800 | Loss: 2.3128998 Epoch: 7 | Batch: 900 | Loss: 2.315315 Epoch: 7 | Batch: 1000 | Loss: 2.2852345 Epoch: 7 | Batch: 1100 | Loss: 2.2933066 Epoch: 7 | Batch: 1200 | Loss: 2.3040879 Epoch: 7 | Batch: 1300 | Loss: 2.3110313 Epoch: 7 | Batch: 1400 | Loss: 2.3072937 Epoch: 7 | Batch: 1500 | Loss: 2.2954926 Epoch: 7 | Batch: 1600 | Loss: 2.330746 Epoch: 7 | Batch: 1700 | Loss: 2.2816267 Epoch: 7 | Batch: 1800 | Loss: 2.330859 Epoch: 8 | Batch: 100 | Loss: 2.3279943 Epoch: 8 | Batch: 200 | Loss: 2.304054 Epoch: 8 | Batch: 300 | Loss: 2.3247418 Epoch: 8 | Batch: 400 | Loss: 2.2978754 Epoch: 8 | Batch: 500 | Loss: 2.3031363 Epoch: 8 | Batch: 600 | Loss: 2.3402176 Epoch: 8 | Batch: 700 | Loss: 2.3024223 Epoch: 8 | Batch: 800 | Loss: 2.3355234 Epoch: 8 | Batch: 900 | Loss: 2.2986102 Epoch: 8 | Batch: 1000 | Loss: 2.3087595 Epoch: 8 | Batch: 1100 | Loss: 2.28528 Epoch: 8 | Batch: 1200 | Loss: 2.3398473 Epoch: 8 | Batch: 1300 | Loss: 2.3271775 Epoch: 8 | Batch: 1400 | Loss: 2.3006303 Epoch: 8 | Batch: 1500 | Loss: 2.284148 Epoch: 8 | Batch: 1600 | Loss: 2.2964175 Epoch: 8 | Batch: 1700 | Loss: 2.293785 Epoch: 8 | Batch: 1800 | Loss: 2.3146398 Epoch: 9 | Batch: 100 | Loss: 2.3038425 Epoch: 9 | Batch: 200 | Loss: 2.2924495 Epoch: 9 | 
Batch: 300 | Loss: 2.3044071 Epoch: 9 | Batch: 400 | Loss: 2.3272884 Epoch: 9 | Batch: 500 | Loss: 2.275878 Epoch: 9 | Batch: 600 | Loss: 2.3423223 Epoch: 9 | Batch: 700 | Loss: 2.2765942 Epoch: 9 | Batch: 800 | Loss: 2.3106685 Epoch: 9 | Batch: 900 | Loss: 2.3071628 Epoch: 9 | Batch: 1000 | Loss: 2.3144343 Epoch: 9 | Batch: 1100 | Loss: 2.289462 Epoch: 9 | Batch: 1200 | Loss: 2.2881138 Epoch: 9 | Batch: 1300 | Loss: 2.3021023 Epoch: 9 | Batch: 1400 | Loss: 2.304129 Epoch: 9 | Batch: 1500 | Loss: 2.3375525 Epoch: 9 | Batch: 1600 | Loss: 2.289328 Epoch: 9 | Batch: 1700 | Loss: 2.2969732 Epoch: 9 | Batch: 1800 | Loss: 2.3206847 Epoch: 10 | Batch: 100 | Loss: 2.3322062 Epoch: 10 | Batch: 200 | Loss: 2.3208215 Epoch: 10 | Batch: 300 | Loss: 2.3034637 Epoch: 10 | Batch: 400 | Loss: 2.3149908 Epoch: 10 | Batch: 500 | Loss: 2.3124666 Epoch: 10 | Batch: 600 | Loss: 2.2874281 Epoch: 10 | Batch: 700 | Loss: 2.2848132 Epoch: 10 | Batch: 800 | Loss: 2.3126278 Epoch: 10 | Batch: 900 | Loss: 2.3002985 Epoch: 10 | Batch: 1000 | Loss: 2.2808762 Epoch: 10 | Batch: 1100 | Loss: 2.2735288 Epoch: 10 | Batch: 1200 | Loss: 2.3159194 Epoch: 10 | Batch: 1300 | Loss: 2.2830477 Epoch: 10 | Batch: 1400 | Loss: 2.3007078 Epoch: 10 | Batch: 1500 | Loss: 2.2847142 Epoch: 10 | Batch: 1600 | Loss: 2.3029075 Epoch: 10 | Batch: 1700 | Loss: 2.2893977 Epoch: 10 | Batch: 1800 | Loss: 2.3192947

···

mullerhai commented 1 year ago

for the net console ···

Epoch: 1 | Batch: 100 | Loss: 2.2619615 Epoch: 1 | Batch: 200 | Loss: 2.2245278 Epoch: 1 | Batch: 300 | Loss: 2.2066588 Epoch: 1 | Batch: 400 | Loss: 2.1437888 Epoch: 1 | Batch: 500 | Loss: 2.0107937 Epoch: 1 | Batch: 600 | Loss: 1.7830594 Epoch: 1 | Batch: 700 | Loss: 1.6636932 Epoch: 1 | Batch: 800 | Loss: 1.624912 Epoch: 1 | Batch: 900 | Loss: 1.3674096 Epoch: 1 | Batch: 1000 | Loss: 1.3397406 Epoch: 1 | Batch: 1100 | Loss: 1.119798 Epoch: 1 | Batch: 1200 | Loss: 0.8827966 Epoch: 1 | Batch: 1300 | Loss: 1.1612543 Epoch: 1 | Batch: 1400 | Loss: 0.6524454 Epoch: 1 | Batch: 1500 | Loss: 0.94122744 Epoch: 1 | Batch: 1600 | Loss: 0.75724584 Epoch: 1 | Batch: 1700 | Loss: 1.0215834 Epoch: 1 | Batch: 1800 | Loss: 0.7179247 Epoch: 2 | Batch: 100 | Loss: 1.1135714 Epoch: 2 | Batch: 200 | Loss: 0.8856457 Epoch: 2 | Batch: 300 | Loss: 0.8903471 Epoch: 2 | Batch: 400 | Loss: 0.6036425 Epoch: 2 | Batch: 500 | Loss: 0.47791797 Epoch: 2 | Batch: 600 | Loss: 0.7230946 Epoch: 2 | Batch: 700 | Loss: 0.79854256 Epoch: 2 | Batch: 800 | Loss: 0.39642116 Epoch: 2 | Batch: 900 | Loss: 0.3968962 Epoch: 2 | Batch: 1000 | Loss: 0.5763106 Epoch: 2 | Batch: 1100 | Loss: 0.6598583 Epoch: 2 | Batch: 1200 | Loss: 0.43888167 Epoch: 2 | Batch: 1300 | Loss: 0.6808061 Epoch: 2 | Batch: 1400 | Loss: 0.4455229 Epoch: 2 | Batch: 1500 | Loss: 0.5489536 Epoch: 2 | Batch: 1600 | Loss: 0.7090426 Epoch: 2 | Batch: 1700 | Loss: 0.28528082 Epoch: 2 | Batch: 1800 | Loss: 0.90619284 Epoch: 3 | Batch: 100 | Loss: 0.63242817 Epoch: 3 | Batch: 200 | Loss: 0.43466067 Epoch: 3 | Batch: 300 | Loss: 0.85634965 Epoch: 3 | Batch: 400 | Loss: 0.581641 Epoch: 3 | Batch: 500 | Loss: 0.44870692 Epoch: 3 | Batch: 600 | Loss: 0.7503278 Epoch: 3 | Batch: 700 | Loss: 0.44492963 Epoch: 3 | Batch: 800 | Loss: 0.40738276 Epoch: 3 | Batch: 900 | Loss: 0.3926798 Epoch: 3 | Batch: 1000 | Loss: 0.42284447 Epoch: 3 | Batch: 1100 | Loss: 0.2804313 Epoch: 3 | Batch: 1200 | Loss: 0.5176494 Epoch: 3 | Batch: 1300 | Loss: 0.20760004 Epoch: 3 | Batch: 1400 | Loss: 0.65139306 Epoch: 3 | Batch: 1500 | Loss: 0.29561076 Epoch: 3 | Batch: 1600 | Loss: 0.20892558 Epoch: 3 | Batch: 1700 | Loss: 0.45360973 Epoch: 3 | Batch: 1800 | Loss: 0.43093768 Epoch: 4 | Batch: 100 | Loss: 0.3583884 Epoch: 4 | Batch: 200 | Loss: 0.28905255 Epoch: 4 | Batch: 300 | Loss: 0.37375352 Epoch: 4 | Batch: 400 | Loss: 0.43621093 Epoch: 4 | Batch: 500 | Loss: 0.4721342 Epoch: 4 | Batch: 600 | Loss: 0.34404323 Epoch: 4 | Batch: 700 | Loss: 0.15679762 Epoch: 4 | Batch: 800 | Loss: 0.29845145 Epoch: 4 | Batch: 900 | Loss: 0.47624302 Epoch: 4 | Batch: 1000 | Loss: 0.4375582 Epoch: 4 | Batch: 1100 | Loss: 0.41988328 Epoch: 4 | Batch: 1200 | Loss: 0.6287129 Epoch: 4 | Batch: 1300 | Loss: 0.3444804 Epoch: 4 | Batch: 1400 | Loss: 0.3052929 Epoch: 4 | Batch: 1500 | Loss: 0.24971539 Epoch: 4 | Batch: 1600 | Loss: 0.95640814 Epoch: 4 | Batch: 1700 | Loss: 0.32827777 Epoch: 4 | Batch: 1800 | Loss: 0.45374364 Epoch: 5 | Batch: 100 | Loss: 0.17481458 Epoch: 5 | Batch: 200 | Loss: 0.15380035 Epoch: 5 | Batch: 300 | Loss: 0.5491959 Epoch: 5 | Batch: 400 | Loss: 0.40911514 Epoch: 5 | Batch: 500 | Loss: 0.4145059 Epoch: 5 | Batch: 600 | Loss: 0.28929818 Epoch: 5 | Batch: 700 | Loss: 0.38863832 Epoch: 5 | Batch: 800 | Loss: 0.77304405 Epoch: 5 | Batch: 900 | Loss: 0.2299521 Epoch: 5 | Batch: 1000 | Loss: 0.357515 Epoch: 5 | Batch: 1100 | Loss: 0.29654604 Epoch: 5 | Batch: 1200 | Loss: 0.31101316 Epoch: 5 | Batch: 1300 | Loss: 0.3234934 Epoch: 5 | Batch: 1400 | Loss: 0.5002061 Epoch: 5 | Batch: 1500 | Loss: 
0.19751851 Epoch: 5 | Batch: 1600 | Loss: 0.6368086 Epoch: 5 | Batch: 1700 | Loss: 0.5130822 Epoch: 5 | Batch: 1800 | Loss: 0.41312528 Epoch: 6 | Batch: 100 | Loss: 0.295482 Epoch: 6 | Batch: 200 | Loss: 0.5069757 Epoch: 6 | Batch: 300 | Loss: 0.4230825 Epoch: 6 | Batch: 400 | Loss: 0.39062622 Epoch: 6 | Batch: 500 | Loss: 0.2635874 Epoch: 6 | Batch: 600 | Loss: 0.13531111 Epoch: 6 | Batch: 700 | Loss: 0.56886154 Epoch: 6 | Batch: 800 | Loss: 0.45735866 Epoch: 6 | Batch: 900 | Loss: 0.4042319 Epoch: 6 | Batch: 1000 | Loss: 0.21596207 Epoch: 6 | Batch: 1100 | Loss: 0.4255061 Epoch: 6 | Batch: 1200 | Loss: 0.40704936 Epoch: 6 | Batch: 1300 | Loss: 0.39681357 Epoch: 6 | Batch: 1400 | Loss: 0.413392 Epoch: 6 | Batch: 1500 | Loss: 0.2764989 Epoch: 6 | Batch: 1600 | Loss: 0.14937843 Epoch: 6 | Batch: 1700 | Loss: 0.16362853 Epoch: 6 | Batch: 1800 | Loss: 0.18117441 Epoch: 7 | Batch: 100 | Loss: 0.27730733 Epoch: 7 | Batch: 200 | Loss: 0.250608 Epoch: 7 | Batch: 300 | Loss: 0.28178045 Epoch: 7 | Batch: 400 | Loss: 0.33486652 Epoch: 7 | Batch: 500 | Loss: 0.45808753 Epoch: 7 | Batch: 600 | Loss: 0.4377606 Epoch: 7 | Batch: 700 | Loss: 0.4404745 Epoch: 7 | Batch: 800 | Loss: 0.32960096 Epoch: 7 | Batch: 900 | Loss: 0.22964111 Epoch: 7 | Batch: 1000 | Loss: 0.088504046 Epoch: 7 | Batch: 1100 | Loss: 0.40441728 Epoch: 7 | Batch: 1200 | Loss: 0.34234202 Epoch: 7 | Batch: 1300 | Loss: 0.071227714 Epoch: 7 | Batch: 1400 | Loss: 0.30678958 Epoch: 7 | Batch: 1500 | Loss: 0.12579474 Epoch: 7 | Batch: 1600 | Loss: 0.2306481 Epoch: 7 | Batch: 1700 | Loss: 0.4120247 Epoch: 7 | Batch: 1800 | Loss: 0.5681459 Epoch: 8 | Batch: 100 | Loss: 0.09772281 Epoch: 8 | Batch: 200 | Loss: 0.4902591 Epoch: 8 | Batch: 300 | Loss: 0.2741972 Epoch: 8 | Batch: 400 | Loss: 1.0104656 Epoch: 8 | Batch: 500 | Loss: 0.29213688 Epoch: 8 | Batch: 600 | Loss: 0.1541148 Epoch: 8 | Batch: 700 | Loss: 0.10639417 Epoch: 8 | Batch: 800 | Loss: 0.20356439 Epoch: 8 | Batch: 900 | Loss: 0.33703053 Epoch: 8 | Batch: 1000 | Loss: 0.10546577 Epoch: 8 | Batch: 1100 | Loss: 0.23580535 Epoch: 8 | Batch: 1200 | Loss: 0.47704467 Epoch: 8 | Batch: 1300 | Loss: 0.24450986 Epoch: 8 | Batch: 1400 | Loss: 0.11596918 Epoch: 8 | Batch: 1500 | Loss: 0.2624431 Epoch: 8 | Batch: 1600 | Loss: 0.25802615 Epoch: 8 | Batch: 1700 | Loss: 0.31364802 Epoch: 8 | Batch: 1800 | Loss: 0.51298916 Epoch: 9 | Batch: 100 | Loss: 0.29961824 Epoch: 9 | Batch: 200 | Loss: 0.3079228 Epoch: 9 | Batch: 300 | Loss: 0.30956823 Epoch: 9 | Batch: 400 | Loss: 0.27802593 Epoch: 9 | Batch: 500 | Loss: 0.8724224 Epoch: 9 | Batch: 600 | Loss: 0.32512844 Epoch: 9 | Batch: 700 | Loss: 0.21468031 Epoch: 9 | Batch: 800 | Loss: 0.1498065 Epoch: 9 | Batch: 900 | Loss: 0.24437377 Epoch: 9 | Batch: 1000 | Loss: 0.2752638 Epoch: 9 | Batch: 1100 | Loss: 0.35067338 Epoch: 9 | Batch: 1200 | Loss: 0.51140827 Epoch: 9 | Batch: 1300 | Loss: 0.0720738 Epoch: 9 | Batch: 1400 | Loss: 0.13806337 Epoch: 9 | Batch: 1500 | Loss: 0.15848812 Epoch: 9 | Batch: 1600 | Loss: 0.24589156 Epoch: 9 | Batch: 1700 | Loss: 0.18231271 Epoch: 9 | Batch: 1800 | Loss: 0.18981127 Epoch: 10 | Batch: 100 | Loss: 0.43327972 Epoch: 10 | Batch: 200 | Loss: 0.47614577 Epoch: 10 | Batch: 300 | Loss: 0.2775543 Epoch: 10 | Batch: 400 | Loss: 0.2792448 Epoch: 10 | Batch: 500 | Loss: 0.21499425 Epoch: 10 | Batch: 600 | Loss: 0.30326745 Epoch: 10 | Batch: 700 | Loss: 0.4551992 Epoch: 10 | Batch: 800 | Loss: 0.42125818 Epoch: 10 | Batch: 900 | Loss: 0.11093692 Epoch: 10 | Batch: 1000 | Loss: 0.33124807 Epoch: 10 | Batch: 1100 | Loss: 
0.29050675 Epoch: 10 | Batch: 1200 | Loss: 0.24535269 Epoch: 10 | Batch: 1300 | Loss: 0.1671666 Epoch: 10 | Batch: 1400 | Loss: 0.2022452 Epoch: 10 | Batch: 1500 | Loss: 0.52488315 Epoch: 10 | Batch: 1600 | Loss: 0.16817103 Epoch: 10 | Batch: 1700 | Loss: 0.2748699 Epoch: 10 | Batch: 1800 | Loss: 0.5236169

···

HGuillemet commented 1 year ago

so how to use SequentialImpl for a list of user_defined layer or model?

You cannot. There is no way to have libtorch call a forward method implemented in Java. That's why I have already advised you several times to use a Java alternative to Sequential, like this or this.
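Roughly, such a Java-side alternative keeps typed references to the layers (so their forward methods stay callable from the JVM) and only uses the native Module for parameter registration. A minimal Scala sketch, assuming the imports from the earlier example; the ScalaSequential name and its add method are hypothetical helpers, not part of the presets:

class ScalaSequential() extends Module {
  // Chain of JVM-side forward functions; libtorch never calls these,
  // they are invoked from Scala/Java code only.
  private val steps = scala.collection.mutable.ListBuffer.empty[Tensor => Tensor]

  // Register the native module so parameters() sees its weights,
  // and remember a typed forward function for the JVM-side chain.
  def add(name: String, layer: Module, fwd: Tensor => Tensor): this.type = {
    register_module(name, layer)
    steps += fwd
    this
  }

  def forward(x: Tensor): Tensor = steps.foldLeft(x)((t, f) => f(t))
}

// possible usage (sizes follow the earlier example):
// val fc1 = new LinearImpl(784, 64)
// val relu1 = new ReLUImpl()
// val model = new ScalaSequential()
//   .add("fc1", fc1, t => fc1.forward(t))
//   .add("relu1", relu1, t => relu1.forward(t))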

HGuillemet commented 1 year ago

Your SeqNow module has no parameters to optimize and no sub-modules. You should either pass the SequentialImpl module directly to the optimizer, or register_module("seqs", seqs) in the SeqNow constructor.
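A rough sketch of the two fixes, reusing the layer sizes from the previous comment (illustration only, not tested):

class SeqNow() extends Module {
  val seqs = new SequentialImpl()
  seqs.push_back(new LinearImpl(784, 64))
  seqs.push_back(new ReLUImpl())
  seqs.push_back(new DropoutImpl(0.5))
  seqs.push_back(new LinearImpl(64, 32))
  seqs.push_back(new ReLUImpl())
  seqs.push_back(new LinearImpl(32, 10))
  seqs.push_back(new LogSoftmaxImpl(1))
  // Registering the Sequential is what makes its weights visible
  // through SeqNow.parameters, so the optimizer can update them.
  register_module("seqs", seqs)
  def forward(xl: Tensor): Tensor = seqs.forward(xl.reshape(xl.size(0), 784))
}

// Alternative fix: optimize the Sequential's parameters directly, e.g.
// val optimizer = new SGD(seqNow.seqs.parameters, new SGDOptions(/*lr=*/ 0.01))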

mullerhai commented 1 year ago

Thanks, and I solved it the same way:

class SeqNow() extends Module {
var seqs = new SequentialImpl()
var fc4 = new LinearImpl(784, 64)

var relu = new ReLUImpl()
val dropOpt = new DropoutOptions()
var drop = new DropoutImpl(0.55)
drop.train(true)
var fc5 = new LinearImpl(64, 32)
var relu2 = new ReLUImpl()
var fc6 = new LinearImpl(32, 10)
val log_softmax = new LogSoftmaxImpl(1)
register_module("fc4",fc4 )
register_module("relu",relu )
register_module("drop",drop )
register_module("fc5",fc5 )
register_module("relu2",relu2 )
register_module("fc6",fc6 )
register_module("log_softmax",log_softmax )
seqs.push_back(fc4)
seqs.push_back(relu)
seqs.push_back(drop)
seqs.push_back(fc5)
seqs.push_back(relu2)
seqs.push_back(fc6)
seqs.push_back(log_softmax)
def forward(xl: Tensor): Tensor = {
  var x = xl.reshape(xl.size(0), 784)
  x = seqs.forward(x)
  x
}

}


HGuillemet commented 1 year ago

You just need to register seqs. fc4, relu, ... are already submodules of seqs after the push_back.

mullerhai commented 1 year ago

register_module("seqs", seqs)

I want to use a reshape layer or view layer inside SequentialImpl, but I could not find one. Is there a way to add a reshape layer in the middle of the model?

mullerhai commented 1 year ago

You just need to register seqs. fc4, relu, ... are already submodules of seqs after the push_back.

Yes, you are correct.

mullerhai commented 1 year ago

By the way, in the new version, are ModuleListImpl and ModuleDictImpl fully implemented and really usable in code? Thanks.

HGuillemet commented 1 year ago

I want to use a reshape layer or view layer inside SequentialImpl, but I could not find one. Is there a way to add a reshape layer in the middle of the model?

There are no such modules in libtorch. You'll have to do the reshape in a forward method of your own, like you did above.

HGuillemet commented 1 year ago

By the way, in the new version, are ModuleListImpl and ModuleDictImpl fully implemented and really usable in code? Thanks.

Not tested, but I think it's usable.

mullerhai commented 1 year ago

By the way, in the new version, are ModuleListImpl and ModuleDictImpl fully implemented and really usable in code? Thanks.

Not tested, but I think it's usable.

I will test them in the next day or so, then give you an answer.

mullerhai commented 1 year ago

By the way, in the new version, are ModuleListImpl and ModuleDictImpl fully implemented and really usable in code? Thanks.

Not tested, but I think it's usable.

Hi @HGuillemet, in my opinion the ModuleListImpl and ModuleDictImpl classes are only partly usable, not fully. First, they can organize the layers; I just printed their layer structure:

JavaCPP_torch_0003a_0003ann_0003a_0003aModule(
  (fc4): torch::nn::Linear(in_features=784, out_features=64, bias=true)
  (relu): torch::nn::ReLU()
  (drop): torch::nn::Dropout(p=0.55, inplace=false)
  (fc5): torch::nn::Linear(in_features=64, out_features=32, bias=true)
  (relu2): torch::nn::ReLU()
  (fc6): torch::nn::Linear(in_features=32, out_features=10, bias=true)
  (log_softmax): torch::nn::LogSoftmax(dim=1)
  (seqs): torch::nn::ModuleList(
    (0): torch::nn::Linear(in_features=784, out_features=64, bias=true)
    (1): torch::nn::ReLU()
    (2): torch::nn::Dropout(p=0.55, inplace=false)
    (3): torch::nn::Linear(in_features=64, out_features=32, bias=true)
    (4): torch::nn::ReLU()
    (5): torch::nn::Linear(in_features=32, out_features=10, bias=true)
    (6): torch::nn::LogSoftmax(dim=1)
  )
)

JavaCPP_torch_0003a_0003ann_0003a_0003aModule(
  (seqs): torch::nn::ModuleDict(
    (fc4): torch::nn::Linear(in_features=784, out_features=64, bias=true)
    (relu): torch::nn::ReLU()
    (drop): torch::nn::Dropout(p=0.55, inplace=false)
    (fc5): torch::nn::Linear(in_features=64, out_features=32, bias=true)
    (relu2): torch::nn::ReLU()
    (fc6): torch::nn::Linear(in_features=32, out_features=10, bias=true)
    (log_softmax): torch::nn::LogSoftmax(dim=1)
  )
)

But we cannot invoke the forward method, because in the ModuleListImpl and ModuleDictImpl containers every element is a Module, and the Module class no longer has a forward method. If I try to convert an element back to its original layer type, I also get an error. I don't know how to call forward on the element layers of ModuleListImpl and ModuleDictImpl; if convenient, please give me an example of how to do this. Thanks.

class DictNow() extends Module {
var fc4 = new LinearImpl(784, 64)
var relu = new ReLUImpl()
val dropOpt = new DropoutOptions()
var drop = new DropoutImpl(0.55)
drop.train(true)
var fc5 = new LinearImpl(64, 32)
var relu2 = new ReLUImpl()
var fc6 = new LinearImpl(32, 10)
val log_softmax = new LogSoftmaxImpl(1)
// var vector = new StringSharedModuleVector()
var arrayName = Array[String]("fc4","relu","drop","fc5","relu2","fc6","log_softmax")
var arrayModule = Array[Module](fc4,relu,drop,fc5,relu2,fc6,log_softmax)
var subDict = new StringSharedModuleDict()
subDict.insert("fc4",fc4)
subDict.insert("relu",relu)
subDict.insert("drop",drop)
subDict.insert("fc5",fc5)
subDict.insert("relu2",relu2)
subDict.insert("fc6",fc6)
subDict.insert("log_softmax",log_softmax)

var seqs = new ModuleDictImpl(subDict)
register_module("seqs", seqs)
import org.bytedeco.pytorch.functions._
def forward(xl: Tensor): Tensor = {
  var x = xl.reshape(xl.size(0), 784)
  //      x = seqs.forward(x)
  arrayName.foreach(ele =>{
   x= seqs.get(ele).asInstanceOf[AnyModule].forward(x)
  })

// var count = 0
// var it = seqs.begin
// while ( {
//   !it.equals(seqs.end)
// }){
//// seqs.get(1).apply(NamedModuleApplyFunction)
//   x = seqs.get(count).asInstanceOf[AnyModule].forward(x)
//   count +=1
// }
  x
}
}

class ListNow() extends Module {
var seqs = new ModuleListImpl()
var fc4 = new LinearImpl(784, 64)
var relu = new ReLUImpl()
val dropOpt = new DropoutOptions()
var drop = new DropoutImpl(0.55)
drop.train(true)
var fc5 = new LinearImpl(64, 32)
var relu2 = new ReLUImpl()
var fc6 = new LinearImpl(32, 10)
val log_softmax = new LogSoftmaxImpl(1)

    register_module("fc4",fc4 )
    register_module("relu",relu )
    register_module("drop",drop )
    register_module("fc5",fc5 )
    register_module("relu2",relu2 )
    register_module("fc6",fc6 )
    register_module("log_softmax",log_softmax )
register_module("seqs", seqs)
seqs.push_back(fc4)
seqs.push_back(relu)
seqs.push_back(drop)
seqs.push_back(fc5)
seqs.push_back(relu2)
seqs.push_back(fc6)
seqs.push_back(log_softmax)
var arrayName= Array[String]("fc4","relu","drop","fc5","relu2","fc6","log_softmax")
var arrayModule= Array[Module](fc4,relu,drop,fc5,relu2,fc6,log_softmax)
def forward(xl: Tensor): Tensor = {
  var x = xl.reshape(xl.size(0), 784)
  var count = 0
  arrayModule.foreach(ele =>{
    val cla = ele.asInstanceOf[AnyRef].getClass

    println(s"count ${count}")
    val module = seqs.get(count)
    x =module.asInstanceOf[LinearImpl].forward(x)
    count +=1
    println(s"count2 ${count}")
  })

// x = seqs.forward(x)
// var it = seqs.begin
// while ( {
//   !it.equals(seqs.end)
// }){
//// seqs.get(1).apply(NamedModuleApplyFunction)
// }
  x
}
}


mullerhai commented 1 year ago

I also found a conversion that confuses me: for example, LinearImpl can be converted to Module, but Module cannot be converted back to LinearImpl:

Exception in thread "main" java.lang.ClassCastException: class org.bytedeco.pytorch.Module cannot be cast to class org.bytedeco.pytorch.LinearImpl (org.bytedeco.pytorch.Module and org.bytedeco.pytorch.LinearImpl are in unnamed module of loader 'app')
    at SimpleMNIST$ListNow.$anonfun$forward$2(hell.scala:121)

mullerhai commented 1 year ago

But if I write the forward method of the ModuleList/ModuleDict class like this, it runs perfectly:


    def forward(xl: Tensor): Tensor = {
      var x = xl.reshape(xl.size(0), 784)
      x = relu.forward(fc4.forward(x.reshape(x.size(0), 784)))
      x = drop.forward(x)
      x = relu2.forward(fc5.forward(x))
      x = log_softmax.forward(fc6.forward(x))
      x
    }

But, as you know, we want to iterate over the ModuleList or ModuleDict layer elements, as Python PyTorch does, and invoke each layer's forward method; that would be very easy to write.

mullerhai commented 1 year ago

So, most importantly: how can we iterate over the layer elements of ModuleList and ModuleDict and invoke each layer's forward method, like Python PyTorch does?

mullerhai commented 1 year ago

As you see, in Python we use ModuleList for a range of layers, like this:

class AutomaticFeatureInteractionModel(torch.nn.Module):
    """
    A pytorch implementation of AutoInt.

    Reference:
        W Song, et al. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks, 2018.
    """

    def __init__(self, field_dims, embed_dim, atten_embed_dim, num_heads, num_layers, mlp_dims, dropouts, has_residual=True):
        super().__init__()
        self.self_attns = torch.nn.ModuleList([
            torch.nn.MultiheadAttention(atten_embed_dim, num_heads, dropout=dropouts[0]) for _ in range(num_layers)
        ])

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """

        for self_attn in self.self_attns:
            cross_term, _ = self_attn(cross_term, cross_term, cross_term)

How would I code that with javacpp pytorch?

HGuillemet commented 1 year ago

I also found a conversion that confuses me: for example, LinearImpl can be converted to Module, but Module cannot be converted back to LinearImpl: Exception in thread "main" java.lang.ClassCastException: class org.bytedeco.pytorch.Module cannot be cast to class

There is no reason for this. I suspect this is a problem related to Scala. Is it possible that you hit this bug?

HGuillemet commented 1 year ago

Note that if you have a Java module, for instance:

LinearImpl linear = new LinearImpl(3,3);

you pass it to libtorch and get it back:

  ModuleListImpl list = new ModuleListImpl();
  list.push_back(linear);
  Module m = list.get(0);

The Module m you are getting back is NOT linear. It's not even an instance of LinearImpl. It is a new instance of the Java Module class that wraps the C++ Module instance returned by get:

      System.err.println(linear == m); // false
      System.err.println(linear.equals(m)); // false
      System.err.println(m instanceof LinearImpl); // false
      System.err.println(linear.address() == m.address()); // false
      System.err.println(linear.asModule() == m); // false
      System.err.println(linear.asModule().equals(m)); // true
      System.err.println(linear.asModule().address() == m.address()); // true

There is nothing we can do about this.

HGuillemet commented 1 year ago

Also, even in C++, you cannot directly call forward on the items of a ModuleList.

mullerhai commented 1 year ago

Note that if you have a Java module, for instance:

LinearImpl linear = new LinearImpl(3,3);

you pass it to libtorch and get it back:

  ModuleListImpl list = new ModuleListImpl();
  list.push_back(linear);
  Module m = list.get(0);

The Module m you are getting back is NOT linear. It's not even an instance of LinearImpl. It is a new instance of the Java Module class that wraps the C++ Module instance returned by get:

      System.err.println(linear == m); // false
      System.err.println(linear.equals(m)); // false
      System.err.println(m instanceof LinearImpl); // false
      System.err.println(linear.address() == m.address()); // false
      System.err.println(linear.asModule() == m); // false
      System.err.println(linear.asModule().equals(m)); // true
      System.err.println(linear.asModule().address() == m.address()); // true

There is nothing we can do about this.

Oh, I see. If the Module class is just a Java wrapper object and it cannot be converted back to the real layer object to invoke forward, then what is the point of ModuleDict and ModuleList in pytorch? Module either needs a forward method, or there must be some way to invoke an element layer's forward method. What do you think?

mullerhai commented 1 year ago

I also found a conversion that confuses me: for example, LinearImpl can be converted to Module, but Module cannot be converted back to LinearImpl: Exception in thread "main" java.lang.ClassCastException: class org.bytedeco.pytorch.Module cannot be cast to class

There is no reason for this. I suspect this is a problem related to Scala. Is it possible that you hit this bug?

I don't think so; maybe in Java the Module simply cannot be converted back to the layer object.

mullerhai commented 1 year ago

Let me give you one important example: when layers are organized in an array, they share the parent class Module, and then they lose the forward method. I can't believe such a bad thing! Do you think this behavior is reasonable?


      var fc4 = new LinearImpl(784, 64)
      var relu = new ReLUImpl()
      val dropOpt = new DropoutOptions()
      var drop = new DropoutImpl(0.55)
      drop.train(true)
      var fc5 = new LinearImpl(64, 32)
      var relu2 = new ReLUImpl()
      var fc6 = new LinearImpl(32, 10)
      val log_softmax = new LogSoftmaxImpl(1)

 var arrayModule= Array(fc4,relu,drop,fc5,relu2,fc6,log_softmax)   // Module Array
      arrayModule.foreach(ele=> {
           ele.forward() // no  ,we need it !!!
          println(ele.asModule().equals(fc4.asModule()))
//          ele.asModule().asInstanceOf[LinearImpl]
        })
HGuillemet commented 1 year ago

We could probably add a constructor for native modules to downcast from Module (similarly to what must be done in C++):

  Module m = list.get(0);
  LinearImpl linear = new LinearImpl(m);
  Tensor output = linear.forward(input);

What do you think ?

mullerhai commented 1 year ago

We could probably add a constructor for native modules to downcast from Module (similarly to what must be done in C++):

  Module m = list.get(0);
  LinearImpl linear = new LinearImpl(m);
  Tensor output = linear.forward(input);

What do you think ?

I think we need only one step: just add a forward method to the Module class, like the Sequential class has. Only then could we iterate over a Module array and call each layer's forward method cleanly:

Module m = list.get(0);
Tensor output = m.forward(input);

If we do not, I can only use a switch/case to judge the layer type and then process it, which is not graceful. Like this:

class ModuleListYieldNow() extends Module {
  var seqs = new ModuleListImpl()
  var fc4 = new LinearImpl(784, 64)
  var relu = new ReLUImpl()
  val dropOpt = new DropoutOptions()
  var drop = new DropoutImpl(0.55)
  drop.train(true)
  var fc5 = new LinearImpl(64, 32)
  var relu2 = new ReLUImpl()
  var fc6 = new LinearImpl(32, 10)
  val log_softmax = new LogSoftmaxImpl(1)
  register_module("seqs", seqs)
  seqs.push_back(fc4)
  seqs.push_back(relu)
  seqs.push_back(drop)
  seqs.push_back(fc5)
  seqs.push_back(relu2)
  seqs.push_back(fc6)
  seqs.push_back(log_softmax)
  var arrayName = Array[String]("fc4","relu","drop","fc5","relu2","fc6","log_softmax")
  var arrayModule = Array(fc4,relu,drop,fc5,relu2,fc6,log_softmax)
  var tupleModule = Tuple7(fc4,relu,drop,fc5,relu2,fc6,log_softmax)
  var arrayLayerClass = Tuple7(classOf[LinearImpl],classOf[ReLUImpl],classOf[DropoutImpl],classOf[LinearImpl],classOf[ReLUImpl],classOf[LinearImpl],classOf[LogSoftmaxImpl])
  def forward(xl: Tensor): Tensor = {
    var x = xl.reshape(xl.size(0), 784)
    var cnt = 1
    val iter = arrayModule.iterator
    // val iter = tupleModule.productIterator
    // val iterClass = arrayLayerClass.productIterator
    while (iter.hasNext) {
      val layer = iter.next()
      layer match {
        case layer: LinearImpl => x = layer.asInstanceOf[LinearImpl].forward(x)
        case layer: ReLUImpl => x = layer.asInstanceOf[ReLUImpl].forward(x)
        case layer: DropoutImpl => x = layer.asInstanceOf[DropoutImpl].forward(x)
        case layer: LogSoftmaxImpl => x = layer.asInstanceOf[LogSoftmaxImpl].forward(x)
      }
      // if(layer.isInstanceOf[LinearImpl]){
      //   x = layer.asInstanceOf[LinearImpl].forward(x)
      // }else if(layer.isInstanceOf[ReLUImpl]){
      //   x = layer.asInstanceOf[ReLUImpl].forward(x)
      // }else if(layer.isInstanceOf[DropoutImpl]){
      //   x = layer.asInstanceOf[DropoutImpl].forward(x)
      // }else{
      //   x = layer.asInstanceOf[LogSoftmaxImpl].forward(x)
      // }
      cnt += 1
    }
    x
  }
}

HGuillemet commented 1 year ago

Yes, either you use some Java list or array of Module, and you can cast normally to subclasses, like you do here with the switch. Or you want to use a libtorch structure like ModuleList, but then we cannot cast as normal, we need this extra step (new LinearImpl(m)). Sequential is different since the C++ implementation knows about the classes of modules it contains and is able to chain the forward calls dynamically.
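To make the first option a bit less verbose than the switch above, the typed dispatch can be factored into a small helper. This is only a sketch (forwardThrough is a hypothetical name) and it covers just the layer types used in this thread:

// Dispatch on the concrete layer type so the right typed forward is called.
// This works because the sequence holds the original Java objects, not the
// generic Module wrappers returned by ModuleList.get.
def forwardThrough(layers: Seq[Module], input: Tensor): Tensor =
  layers.foldLeft(input) { (x, layer) =>
    layer match {
      case l: LinearImpl     => l.forward(x)
      case r: ReLUImpl       => r.forward(x)
      case d: DropoutImpl    => d.forward(x)
      case s: LogSoftmaxImpl => s.forward(x)
      case other => sys.error(s"unhandled module type: ${other.getClass}")
    }
  }

// possible usage, with the fields from ModuleListYieldNow above:
// val out = forwardThrough(Seq(fc4, relu, drop, fc5, relu2, fc6, log_softmax), x)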

mullerhai commented 1 year ago

Yes, either you use some Java list or array of Module, and you can cast normally to subclasses, like you do here with the switch. Or you want to use a libtorch structure like ModuleList, but then we cannot cast as normal, we need this extra step (new LinearImpl(m)). Sequential is different since the C++ implementation knows about the classes of modules it contains and is able to chain the forward calls dynamically.

Maybe it is hard for you to pick the best way to solve it, so for now I will just follow your approach with new LinearImpl(m). Thanks.

HGuillemet commented 1 year ago

The first option already works, so if it's enough please use it. The second option, with new LinearImpl(m), needs development.

mullerhai commented 1 year ago

The first option already works, so if it's enough please use it. The second option, with new LinearImpl(m), needs development.

Thanks, I need the second option.

HGuillemet commented 1 year ago

This issue is finally addressed by PR bytedeco/javacpp#700. Once merged, you will be able to do:

  t2 = new LinearImpl(m).forward(t)

where m is an instance of Module returned by libtorch, for instance by:

m = list.get(0);

where list is a ModuleList. Of course m must be a Linear module.
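Once that is available, iterating a ModuleList from Scala should look roughly like the sketch below. It assumes the downcast constructors from that PR, a list that contains only Linear modules, and that ModuleListImpl exposes size() alongside the get(i) shown above:

// Walk a ModuleList of Linear modules and chain their forwards by
// downcasting each generic Module wrapper with the new constructor.
def forwardLinearList(list: ModuleListImpl, input: Tensor): Tensor = {
  var x = input
  var i = 0
  while (i < list.size()) {          // size() assumed to mirror ModuleList::size()
    val m: Module = list.get(i)      // generic Module wrapper, as discussed above
    x = new LinearImpl(m).forward(x) // downcast, then call the typed forward
    i += 1
  }
  x
}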

mullerhai commented 1 year ago

This issue is finally addressed by PR bytedeco/javacpp#700. Once merged, you will be able to do:

  t2 = new LinearImpl(m).forward(t)

where m is an instance of Module returned by libtorch, for instance by:

m = list.get(0);

where list is a ModuleList. Of course m must be a Linear module.

Perfect. Now javacpp pytorch is moving toward being a truly usable tool for Java/Scala/JVM. Waiting for the PR merge, thanks @HGuillemet.

saudet commented 1 year ago

That pull request has been merged, so this should be working now!

mullerhai commented 1 year ago

That pull request has been merged, so this should be working now!

Perfect, thanks. Waiting for the 1.5.10 release to be published to the Maven repos.