bytedeco / javacpp-presets

The missing Java distribution of native C++ libraries

[Pytorch] how to create & use AnyModule object in new version ? #1399

Closed mullerhai closed 1 year ago

mullerhai commented 1 year ago

Hi, I am trying to use AnyModule in the new pytorch-javacpp version, but I cannot get it to work smoothly; maybe it needs some ptr parameter to create it. @HGuillemet @saudet

 class AnyNet() extends AnyModule { // Construct and register two Linear submodules.

    var fc1 = new LinearImpl(784, 64)
//    register_module("fc1", fc1)
    var fc2 = new LinearImpl(64, 32)
//    register_module("fc2", fc2)
    var fc3 = new LinearImpl(32, 10)
//    register_module("fc3", fc3)

    // Implement the Net's algorithm.
    override def forward(xl: Tensor): Tensor = { // Use one of many tensor manipulation functions.
      var x = xl
      x = relu(fc1.forward(x.reshape(x.size(0), 784)))
      x = dropout(x, 0.5, true)
      x = relu(fc2.forward(x))
      x = log_softmax(fc3.forward(x), 1)
      x
    }
  }
  class SequentialAnyModuleNow() extends Module {
    var seq = new SequentialImpl()
    val anyNet = new AnyNet()
    seq.push_back(anyNet)
    register_module("seqs", seq)
    def forward(xl: Tensor): Tensor = {
      var x = xl.reshape(xl.size(0), 784)
      var cnt = 1
      x = seq.forward(x)
      x
    }
  }

Console error:


Exception in thread "main" java.lang.RuntimeException: Cannot call ptr() on an empty AnyModule
Exception raised from ptr at /Users/runner/work/javacpp-presets/javacpp-presets/pytorch/cppbuild/macosx-x86_64/pytorch/torch/include/torch/csrc/api/include/torch/nn/modules/container/any.h:304 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>) + 81 (0x108baa481 in libc10.dylib)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 197 (0x108ba8c75 in libc10.dylib)
frame #2: torch::nn::SequentialImpl::push_back(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, torch::nn::AnyModule) + 884 (0x151986de4 in libjnitorch.dylib)
frame #3: torch::nn::SequentialImpl::push_back(torch::nn::AnyModule) + 72 (0x1519873e8 in libjnitorch.dylib)
frame #4: Java_org_bytedeco_pytorch_SequentialImpl_push_1back__Lorg_bytedeco_pytorch_AnyModule_2 + 267 (0x15198730b in libjnitorch.dylib)
frame #5: 0x0 + 4598301680 (0x1121483f0 in ???)

    at org.bytedeco.pytorch.SequentialImpl.push_back(Native Method)
    at SimpleMNIST$SequentialAnyModuleNow.<init>(hell.scala:261)
    at SimpleMNIST$.main(hell.scala:386)
    at SimpleMNIST.main(hell.scala)

Similar issues:
https://github.com/pytorch/pytorch/pull/34208 https://github.com/pytorch/pytorch/pull/17552 https://discuss.pytorch.org/t/modify-nn-anymodule-embedded-in-nn-sequential/100002

mullerhai commented 1 year ago

SequentialImpl can push_back an AnyModule object, but when I push my AnyModule object I get the error above.

  /** Adds a new {@code Module} to the {@code Sequential} container, moving or copying it
   *  into a {@code shared_ptr} internally. This method allows passing value types,
   *  and letting the container deal with the boxing. This means you can write
   *  {@code Sequential(Module(3, 4))} instead of
   *  {@code Sequential(std::make_shared<Module>(3, 4))}. */

  /** Adds a new named {@code Module} to the {@code Sequential} container, moving or copying
   *  it into a {@code shared_ptr} internally. This method allows passing value types,
   *  and letting the container deal with the boxing. */

  /** Unwraps the contained module of a {@code ModuleHolder} and adds it to the
   *  {@code Sequential}. */

  /** Unwraps the contained named module of a {@code ModuleHolder} and adds it to the
   *  {@code Sequential}. */

  /** Iterates over the container and calls {@code push_back()} on each value. */

  /** Adds a type-erased {@code AnyModule} to the {@code Sequential}. */
  public native void push_back(@ByVal AnyModule any_module);

  public native void push_back(@StdString BytePointer name, @ByVal AnyModule any_module);
  public native void push_back(@StdString String name, @ByVal AnyModule any_module);
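
For reference, a minimal sketch of what push_back(AnyModule) expects: the AnyModule must already wrap a concrete C++ module before it is pushed. The no-argument AnyModule() constructor produces an empty AnyModule, which is exactly what raises "Cannot call ptr() on an empty AnyModule". The AnyModule(LinearImpl) constructor overload used here is an assumption; check the generated AnyModule class for the overloads your presets version actually provides.

import org.bytedeco.pytorch.AnyModule;
import org.bytedeco.pytorch.LinearImpl;
import org.bytedeco.pytorch.SequentialImpl;
import org.bytedeco.pytorch.Tensor;

public class AnyModuleSketch {
  // 'input' is assumed to be a Tensor of shape [batch, 784].
  static Tensor run(Tensor input) {
    SequentialImpl seq = new SequentialImpl();

    // Wrap built-in C++ modules in AnyModule before pushing them.
    // AnyModule(LinearImpl) is assumed to be one of the generated overloads.
    seq.push_back(new AnyModule(new LinearImpl(784, 64)));
    seq.push_back(new AnyModule(new LinearImpl(64, 10)));

    // By contrast, new AnyModule() alone stays empty (is_empty() == true),
    // and pushing it fails with the error shown above.
    return seq.forward(input);
  }
}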
mullerhai commented 1 year ago

  class SequentialAnyModuleNow() extends Module {
    val anyNet = new AnyNet()
    println(s"anyNet . ${anyNet.is_empty()}")
    ...

Console log:

anyNet . true

Why is the AnyModule object empty (is_empty() returns true)? I have filled it with three Linear layers.

HGuillemet commented 1 year ago

AnyModule is not meant to be subclassed. It's not a solution for adding modules defined in Java to a Sequential. AnyModule is the C++ trick that allows Sequential to work: it is a container for an unknown module, and you can call forward with any type and number of arguments and it will dynamically call the forward function of the contained module. But it cannot contain a module with a Java forward method. If you have custom modules in Java/Scala, you must call their forward methods from Java/Scala. What are you trying to do exactly? Why don't you use something like this instead of trying to use Sequential? (In Java here, I don't know Scala.)

abstract class AnyNet extends Module {
  public abstract Tensor forward(Tensor x);
}

class MyNet extends AnyNet {
    final LinearImpl fc1, fc2, fc3;
    MyNet() {
        fc1 = new LinearImpl(784, 64);
        register_module("fc1", fc1);
        fc2 = new LinearImpl(64, 32);
        register_module("fc2", fc2);
        fc3 = new LinearImpl(32, 10);
        register_module("fc3", fc3);
    }

    @Override
    public Tensor forward(Tensor x) {
        x = relu(fc1.forward(x.reshape(x.size(0), 784)));
        x = dropout(x, 0.5, true);
        x = relu(fc2.forward(x));
        x = log_softmax(fc3.forward(x), 1);
        return x;
    }
}

class AnyModuleNow extends Module {    
    final AnyNet net;
    AnyModuleNow(AnyNet net) {
        this.net = net;
        register_module("net", net);
    }
    Tensor forward(Tensor x) {
        x = x.reshape(x.size(0), 784);
        return net.forward(x);
    }
}
mullerhai commented 1 year ago


I want to embed some layer blocks in a Sequential, but Sequential cannot accept a block that extends nn.Module; it can only accept AnyModule instances. So I created an AnyModule subclass, instantiated it, tried to push it into the Sequential, and then hit the error.

mullerhai commented 1 year ago

So, do you know how to embed layer blocks in a Sequential?

HGuillemet commented 1 year ago

Please tell me why you want to use Sequential instead of the other solutions I suggested several times.

mullerhai commented 1 year ago

Sequential is the most popular layer and block container in Python PyTorch. You have successfully implemented SequentialImpl in the new version, so I had to try whether Sequential can be used normally in javacpp or not. The other solutions you suggest require converting the way the layers are organized, which is not the best choice for me. I just want to organize layer blocks with Sequential, the same as in Python PyTorch.

Do you remember you told me to create the ConvBnReluBlock block? I need ten ConvBnReluBlocks in one container to train, and Sequential is the best container for that in Python PyTorch, but here Sequential does not accept a block that extends nn.Module.

import org.bytedeco.javacpp._
import org.bytedeco.pytorch.global.torch
import org.bytedeco.pytorch.global.torch.ScalarType
import org.bytedeco.pytorch.{BatchNorm2dImpl, BatchNormOptions, Conv1dImpl, Conv1dOptions, Conv2dImpl, Conv2dOptions, Module, ReLUImpl, ReLUOptions, Tensor, _}

import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.language.postfixOps

/** *
 *
 * @param inChannels
 * @param outChannels
 * @param kernelSize
 * @param stride
 */
class ConvBnReluBlock(val inChannels: Int, val outChannels: Int, val kernelSize: Int, val stride: Int) extends Module {

  val convOpt = new Conv2dOptions(inChannels, outChannels, new LongPointer(kernelSize))
  convOpt.stride.put(Array[Long](stride, stride): _*)
  convOpt.padding.put(new LongPointer(kernelSize / 2))
  convOpt.bias.put(false)
  var conv = new Conv2dImpl(convOpt)
  register_module("conv", conv)
  val bnOpt = new BatchNormOptions(outChannels)
  var bn = new BatchNorm2dImpl(bnOpt)
  register_module("bn", bn)
  val reluOpt = new ReLUOptions
  reluOpt.inplace.put(true)
  var relu = new ReLUImpl(reluOpt)
  register_module("relu", relu)

  def forward(x: Tensor): Tensor = relu.forward(bn.forward(conv.forward(x)))
}
HGuillemet commented 1 year ago

You can still use Sequential for blocks like ConvBnRelu, since they only contain C++ modules.

I need ten ConvBnReluBlocks in one container to train

Do you mean 10 conv-bn-relu blocks, or 10 blocks with different types of layers each time? If you need 10 conv-bn-relu blocks, then define your ConvBnRelu class once and instantiate it 10 times. I find this more practical than instantiating 10 Sequentials and filling them 10 times.
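
For instance, a minimal sketch of that approach in Java (channel count and kernel size are made up here; the options calls mirror the usage example further below): the block is defined once with only C++ modules, then instantiated ten times and chained in a plain Java module.

import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.pytorch.BatchNorm2dImpl;
import org.bytedeco.pytorch.BatchNormOptions;
import org.bytedeco.pytorch.Conv2dImpl;
import org.bytedeco.pytorch.Conv2dOptions;
import org.bytedeco.pytorch.Module;
import org.bytedeco.pytorch.ReLUImpl;
import org.bytedeco.pytorch.ReLUOptions;
import org.bytedeco.pytorch.Tensor;

// One conv-bn-relu block built only from C++ modules (assumed sizes).
class ConvBnRelu extends Module {
  final Conv2dImpl conv;
  final BatchNorm2dImpl bn;
  final ReLUImpl relu;

  ConvBnRelu(int channels) {
    Conv2dOptions convOpt = new Conv2dOptions(channels, channels, new LongPointer(3, 3));
    convOpt.padding().put(new LongPointer(1, 1));
    convOpt.bias().put(false);
    conv = new Conv2dImpl(convOpt);
    register_module("conv", conv);
    bn = new BatchNorm2dImpl(new BatchNormOptions(channels));
    register_module("bn", bn);
    relu = new ReLUImpl(new ReLUOptions());
    register_module("relu", relu);
  }

  Tensor forward(Tensor x) {
    return relu.forward(bn.forward(conv.forward(x)));
  }
}

// The block class instantiated ten times instead of filling ten Sequentials.
class TenConvBnRelu extends Module {
  final ConvBnRelu[] blocks = new ConvBnRelu[10];

  TenConvBnRelu(int channels) {
    for (int i = 0; i < blocks.length; i++) {
      blocks[i] = new ConvBnRelu(channels);
      register_module("block" + i, blocks[i]);
    }
  }

  Tensor forward(Tensor x) {
    for (ConvBnRelu block : blocks) x = block.forward(x);
    return x;
  }
}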

I remind you my other suggestion: define your own Sequential in Java:

package your.package;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import org.bytedeco.pytorch.Module;

public class Sequential<O> extends Module {

  private final Module[] modules;
  private final Method[] forwardMethods;

  public Sequential(Module... modules) {
    this.modules = modules;
    forwardMethods = new Method[modules.length];
    MODULE:
    for (int i = 0; i < modules.length; i++) {
      for (Method method : modules[i].getClass().getMethods()) {
        if (method.getName().equals("forward")) {
          forwardMethods[i] = method;
          register_module(Integer.toString(i), modules[i]);
          continue MODULE;
        }
      }
      throw new IllegalArgumentException("No forward method found for module " + i);
    }
  }

  public O forward(java.lang.Object... inputs) {
    java.lang.Object output;
    for (int i = 0; ; i++) {
      try {
        output = forwardMethods[i].invoke(modules[i], inputs);
      } catch (IllegalAccessException | InvocationTargetException e) {
        throw new RuntimeException(e);
      }
      if (i == modules.length - 1) break;
      inputs = new java.lang.Object[]{output};
    }
    return (O) output;
  }
}

that you can use like this:

Conv2dOptions convOpt = new Conv2dOptions(8, 16, new LongPointer(3, 3));
convOpt.stride().put(new long[]{2, 2});
convOpt.padding().put(new LongPointer(1, 1));
convOpt.bias().put(false);
Conv2dImpl conv = new Conv2dImpl(convOpt);

BatchNormOptions bnOpt = new BatchNormOptions(16);
BatchNorm2dImpl bn = new BatchNorm2dImpl(bnOpt);

ReLUOptions reluOpt = new ReLUOptions();
reluOpt.inplace().put(true);
ReLUImpl relu = new ReLUImpl(reluOpt);

Sequential<Tensor> seq = new Sequential<>(conv, bn, relu); 
Tensor out = seq.forward(in);

To use a module in the C++ Sequential, that module needs to be defined at C++ compile time. So the only way to add Java modules to the C++ Sequential would be to define in the presets a subclass of Module with a virtual forward function. But we would need one such class for every possible signature of forward, and the number of possible forward signatures is infinite. We could only add those for some classical signatures, like one taking 1 tensor and returning 1 tensor, another taking 2 tensors and returning 1 tensor, and a couple of others... but you'll eventually hit a case that isn't covered. So I'm not sure it's a good idea, since there are alternatives like those discussed above.

mullerhai commented 1 year ago


I think this solution may be good for javacpp. The SequentialImpl in javacpp does not provide the same functionality as the Python version, so the only way is to do it in your coding style. Thanks.