bytedeco / javacpp-presets

The missing Java distribution of native C++ libraries

[pytorch] How to use the SequentialImpl, StringAnyModuleDict, and AnyModuleVector classes? #1283

Closed mullerhai closed 10 months ago

mullerhai commented 1 year ago

Hi, I write PyTorch code in Python using torch.nn.Sequential, but in the JavaCPP PyTorch presets Sequential is complicated to use: the object has no forward() method, and its constructor needs a StringAnyModuleDict. When I iterate over it, I cannot get the real type of each Module element, only AnyModule. How can I get the real module type and, following the insertion order, call each module element's forward() method?

class MultiLayerPerceptron(torch.nn.Module):

    def __init__(self, input_dim, embed_dims, dropout, output_layer=True):
        super().__init__()
        layers = list()
        for embed_dim in embed_dims:
            layers.append(torch.nn.Linear(input_dim, embed_dim))
            layers.append(torch.nn.BatchNorm1d(embed_dim))
            layers.append(torch.nn.ReLU())
            layers.append(torch.nn.Dropout(p=dropout))
            input_dim = embed_dim
        if output_layer:
            layers.append(torch.nn.Linear(input_dim, 1))
        self.mlp = torch.nn.Sequential(*layers)

    def forward(self, x):
        """
        :param x: Float tensor of size ``(batch_size, embed_dim)``
        """
        return self.mlp(x)

In Scala:

class MultiLayerPerceptron(input_dim:Int, embed_dims:Seq[Int], dropoutPer:Double, output_layer:Boolean=true) extends Module{

  val layers = new ListBuffer[Module]()
  val moduleDict = new StringAnyModuleDict()
  // Current input size; updated after each block, as in the Python version
  var inDim = input_dim

  for (embed_dim <- embed_dims) {
    val line = register_module(s"line_${embed_dim}", new LinearImpl(inDim, embed_dim))
    val normOpt = new BatchNormOptions(embed_dim)
    val norm = register_module(s"norm_${embed_dim}", new BatchNorm1dImpl(normOpt))
    val relu = register_module(s"relu_${embed_dim}", new ReLUImpl())
    val dropOpt = new DropoutOptions(dropoutPer)
    val dropout = register_module(s"dropout_${embed_dim}", new DropoutImpl(dropOpt))
    moduleDict.insert(s"line_${embed_dim}", line.asInstanceOf[AnyModule])
    moduleDict.insert(s"norm_${embed_dim}", norm.asInstanceOf[AnyModule])
    moduleDict.insert(s"relu_${embed_dim}", relu.asInstanceOf[AnyModule])
    moduleDict.insert(s"dropout_${embed_dim}", dropout.asInstanceOf[AnyModule])
    layers += line
    layers += norm
    layers += relu
    layers += dropout
    inDim = embed_dim
  }
//  Sequential
//  SequentialImpl Sequential
  val mlp = new SequentialImpl(moduleDict)
  def forward(xs:Tensor):Tensor ={
    var x = xs

    val moduleIter: AnyModuleVector.Iterator =  this.mlp.begin()
    val moduleIterEnd: AnyModuleVector.Iterator = this.mlp.end()
    var index = 0
    while(!moduleIter.equals(moduleIterEnd)){
      moduleIter.get().asInstanceOf[BatchNorm1dImpl].forward(x)
      x = mlp.get(index).forward(x)
      moduleIter.increment()
    }
    x
  }
}
saudet commented 1 year ago

Yes, unfortunately, Sequential is not currently usable, see issue https://github.com/bytedeco/javacpp-presets/issues/623#issuecomment-814107543.

HGuillemet commented 1 year ago

Right. Don't try to use Sequential. Here is an example in Java for a Convolution/Batch Normalization/Relu block:

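import org.bytedeco.pytorch.*;
// Explicit import so that Module does not clash with java.lang.Module (Java 9+)
import org.bytedeco.pytorch.Module;

public class ConvBnRelu extends Module {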
  final Conv2dImpl conv;
  final BatchNorm2dImpl bn;
  final ReLUImpl relu;

  ConvBnRelu(int inChannels, int outChannels, int kernelSize, int stride) {
    Conv2dOptions convOpt = new Conv2dOptions(inChannels, outChannels, new ExpandingArray2(kernelSize));
    convOpt.stride().put(new long[]{stride, stride});
    convOpt.padding().put(new ExpandingArray2(kernelSize/2));
    convOpt.bias().put(false);
    conv = new Conv2dImpl(convOpt);
    register_module("conv", conv);

    BatchNormOptions bnOpt = new BatchNormOptions(outChannels);
    bn = new BatchNorm2dImpl(bnOpt);
    register_module("bn", bn);

    ReLUOptions reluOpt = new ReLUOptions();
    reluOpt.inplace().put(true);
    relu = new ReLUImpl(reluOpt);
    register_module("relu", relu);
  }

  Tensor forward(Tensor x) {
    return relu.forward(bn.forward(conv.forward(x)));
  }
}
mullerhai commented 1 year ago

@HGuillemet Thank you very much. I just want to ask whether there is another way to build something like a module collection. Thanks.

HGuillemet commented 1 year ago

Sorry I don't understand your question.

mullerhai commented 1 year ago

Sorry I don't understand your question.

Just something to use instead of Sequential.

mullerhai commented 1 year ago

If I create a ConvBnRelu block, my Net model then needs a list of ConvBnRelu modules. With Sequential or ModuleList this might be easy, but in JavaCPP PyTorch I don't know how to write the code and need your help. By the way, Sequential, ModuleList, and ModuleDict are the most important tools for building complex multi-layer networks; why can't we support them? Thanks.

HGuillemet commented 1 year ago

By the way, Sequential, ModuleList, and ModuleDict are the most important tools for building complex multi-layer networks; why can't we support them?

Sequential and AnyModule use mechanisms that are difficult to map to Java, because they have to cope with the fact that a module's forward method can have any signature.

It also hasn't been done because it's easy, and more Java-like, to do without Sequential. For instance, to chain two ConvBnRelu blocks:

public class Chain extends Module {
  final ConvBnRelu block1 = new ConvBnRelu(3, 10, 3, 1);
  final ConvBnRelu block2 = new ConvBnRelu(10, 20, 3, 3);

  Chain() {
    register_module("block1", block1);
    register_module("block2", block2);
  }

  Tensor forward(Tensor x) {
    return block2.forward(block1.forward(x));
  }
}
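Generalizing the Chain idea to any number of blocks, which is what ModuleList was wanted for, only needs an ordinary java.util.List and a loop in forward(). The sketch below is deliberately library-free so it runs on its own: the hypothetical Block interface and ScaleBlock class stand in for Module subclasses such as ConvBnRelu, and double[] stands in for Tensor. With the presets, add() would also call register_module() for each block, as Chain does in its constructor.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a module such as ConvBnRelu (double[] replaces Tensor).
interface Block {
    double[] forward(double[] x);
}

// Hypothetical block that scales every element, standing in for a real layer.
class ScaleBlock implements Block {
    private final double factor;

    ScaleBlock(double factor) { this.factor = factor; }

    public double[] forward(double[] x) {
        double[] y = new double[x.length];
        for (int i = 0; i < x.length; i++) y[i] = x[i] * factor;
        return y;
    }
}

// Hand-written "module list": blocks are kept in insertion order and
// forward() feeds each block's output to the next one, like Sequential.
class BlockChain implements Block {
    private final List<Block> blocks = new ArrayList<>();

    BlockChain add(Block b) {
        // With the presets, this is where register_module("block" + index, b) would go.
        blocks.add(b);
        return this;
    }

    public double[] forward(double[] x) {
        for (Block b : blocks) x = b.forward(x);
        return x;
    }
}
```

Because the chain is itself a Block, chains can be nested like any other module.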
saudet commented 1 year ago

Sequential and AnyModule use mechanisms that are difficult to map to Java, because they have to cope with the fact that a module's forward method can have any signature.

If that's the only issue, I think simply mapping the forward() function to a few overloads taking Tensor arguments, as follows, should work and would cover most use cases:

public native @ByVal Tensor forward(@Const @ByRef Tensor input0);
public native @ByVal Tensor forward(@Const @ByRef Tensor input0, @Const @ByRef Tensor input1);
public native @ByVal Tensor forward(@Const @ByRef Tensor input0, @Const @ByRef Tensor input1, @Const @ByRef Tensor input2);
public native @ByVal Tensor forward(@Const @ByRef Tensor input0, @Const @ByRef Tensor input1, @Const @ByRef Tensor input2, @Const @ByRef Tensor input3);
...
HGuillemet commented 1 year ago

The first issue is mapping the variadic Sequential constructor, which seems difficult. Instead, we could map push_back by instantiating the template for every standard module, as you do for register_module, for instance. Alternatively, or in addition, and more generically, we could support AnyModule by instantiating its constructor for every standard module too, and then use the Sequential methods that take an AnyModule.

As a side note, I don't think we need register_module() for every standard concrete subclass of Module, since register_module(Module) exists. This line could be removed. But I doubt that would work for Sequential or AnyModule.

Issue 2 is more problematic: we should be able to use Sequential or AnyModule with custom Java modules, like ConvBnRelu above. But I don't see how the C++ library could find our forward method.

saudet commented 1 year ago

The first issue is mapping the variadic Sequential constructor, which seems difficult. Instead, we could map push_back by instantiating the template for every standard module, as you do for register_module, for instance. Alternatively, or in addition, and more generically, we could support AnyModule by instantiating its constructor for every standard module too, and then use the Sequential methods that take an AnyModule.

Now that you mention it, we should be able to map all instances of the AnyModule constructor for all concrete Module types, just like with register_module(). That should work, right?

As a side note, I don't think we need register_module() for every standard concrete subclass of Module, since register_module(Module) exists. This line could be removed. But I doubt that would work for Sequential or AnyModule.

I'd have to look at that bit again, but there's a reason it's there: it doesn't work if we don't map all instances.

Issue 2 is more problematic: we should be able to use Sequential or AnyModule with custom Java modules, like ConvBnRelu above. But I don't see how the C++ library could find our forward method.

Looking at that again, from what I understand we don't need to implement forward(). At least the default one from Sequential should work fine as it is, but as I said, that needs to be tested.

HGuillemet commented 1 year ago

Now that you mention it, we should be able to map all instances of the AnyModule constructor for all concrete Module types, just like with register_module(). That should work, right?

I think it will, but I'm not sure it's worth doing if AnyModule cannot include our Java modules (see below).

I'd have to look at that bit again, but there's a reason it's there: it doesn't work if we don't map all instances.

You're right: a quick test gives me a SIGSEGV when calling register_module if I remove these mappings. Any idea why?

Looking at that again, from what I understand we don't need to implement forward(). At least the default one from Sequential should work fine as it is, but as I said, that needs to be tested.

The aim of Sequential is to call the forward functions in a chain. Even if we somehow manage to add our Java class extending Module to the Sequential, how could it find our Java forward method? It only sees a reference to a Module.

saudet commented 1 year ago

It obviously doesn't work for modules implemented in Java, it's only going to work for the ones implemented in C++, but it's better than nothing :)

mullerhai commented 1 year ago

It obviously doesn't work for modules implemented in Java, it's only going to work for the ones implemented in C++, but it's better than nothing :)

I also know Sequential and AnyModule are difficult to implement in Java. Thanks to the contributors for doing a lot of background work.

HGuillemet commented 1 year ago

I have tried to map AnyModule constructors for all standard modules. That works, and I can create a Sequential and push_back standard modules. But then there is no way to call the forward method of the Sequential from Java! I think we are left with reproducing Sequential in Java, either case by case by the user, as shown above with ConvBnRelu (my preference), or with a helper subclass of Module that uses Java introspection to chain the forward methods. Such a module would either implement an ugly Object[] forward(Object...) or be generic: Sequential<I,O>, where I and O are the input and output types of the forward method, that is, the input type of the first module and the output type of the last.

The presets could provide this helper class, but as @saudet often points out, this is rather the role of higher-level software. The lack of a mapping for the native Sequential module is not functionally blocking.
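The generic Sequential<I,O> option can be approximated without reflection by composing java.util.function.Function stages, so that the compiler checks that each stage's input type matches the previous stage's output type. This is a minimal sketch, not an API of the presets: TypedChain, of(), then(), and forward() are hypothetical names, and the stages are plain functions rather than Module subclasses. With real modules one would presumably pass method references such as block::forward and still register each module separately.

```java
import java.util.function.Function;

// Hypothetical type-safe chain: I is the input type of the first stage,
// O the output type of the last stage, as suggested for Sequential<I, O>.
final class TypedChain<I, O> {
    private final Function<I, O> f;

    private TypedChain(Function<I, O> f) { this.f = f; }

    // Start a chain from a single stage.
    static <A, B> TypedChain<A, B> of(Function<A, B> first) {
        return new TypedChain<>(first);
    }

    // Append a stage whose input must accept this chain's output type;
    // a mismatched stage is rejected at compile time, not at run time.
    <R> TypedChain<I, R> then(Function<? super O, ? extends R> next) {
        Function<I, R> composed = f.andThen(next);
        return new TypedChain<>(composed);
    }

    // Run all stages in order.
    O forward(I input) { return f.apply(input); }
}
```

For example, TypedChain.of((String s) -> s.trim()).then(String::length) has type TypedChain<String, Integer>, and appending a stage that expects a String after it does not compile. The trade-off against the reflection-based approach is that every stage must have exactly one input and one output.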

saudet commented 1 year ago

I have tried to map AnyModule constructors for all standard modules. That works, and I can create a Sequential and push_back standard modules. But then there is no way to call the forward method of the Sequential from Java! I think we are left with reproducing Sequential in Java, either case by case by the user, as shown above with ConvBnRelu (my preference), or with a helper subclass of Module that uses Java introspection to chain the forward methods. Such a module would either implement an ugly Object[] forward(Object...) or be generic: Sequential<I,O>, where I and O are the input and output types of the forward method, that is, the input type of the first module and the output type of the last.

Could you provide the compiler errors that you get when you try what I wrote above https://github.com/bytedeco/javacpp-presets/issues/1283#issuecomment-1405912098?

HGuillemet commented 1 year ago

I don't know how to express this as an Info, but I just tried adding

public native @ByVal Tensor forward(@Const @ByRef Tensor input0);

to Sequential.java after the Parser runs and before the Generator, and it compiles. I could run this example:

    Conv2dImpl impl = new Conv2dImpl(1, 1, new LongPointer(3, 3));
    AnyModule any = new AnyModule(impl);
    SequentialImpl seq = new SequentialImpl();
    seq.push_back("conv", any);
    Tensor in = torch.rand(1, 1, 10, 10);
    seq.forward(in);

However, I'm still unsure whether we'd better map the native Sequential or write a Java counterpart that does the same thing but accepts Java modules, as well as inputs and outputs of any number and type. Something like this:

package org.bytedeco.pytorch;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

// Chains the forward() methods of arbitrary modules, native or written in Java.
public class Sequential<O> extends Module {

    private final Module[] modules;
    private final Method[] forwardMethods;

    public Sequential(Module... modules) {
        if (modules.length == 0) {
            throw new IllegalArgumentException("At least one module is required");
        }
        this.modules = modules;
        forwardMethods = new Method[modules.length];
        MODULE:
        for (int i = 0; i < modules.length; i++) {
            // getMethods() only returns public methods, so forward() must be public.
            // They come in no particular order: with overloads, an arbitrary one is used.
            for (Method method: modules[i].getClass().getMethods()) {
                if (method.getName().equals("forward")) {
                    forwardMethods[i] = method;
                    register_module(Integer.toString(i), modules[i]);
                    continue MODULE;
                }
            }
            throw new IllegalArgumentException("No forward method found for module "+i);
        }
    }

    // Passes the inputs to the first module, then each output to the next module.
    @SuppressWarnings("unchecked")
    public O forward(java.lang.Object... inputs) {
        java.lang.Object output;
        for (int i = 0; ; i++) {
            try {
                output = forwardMethods[i].invoke(modules[i], inputs);
            } catch (IllegalAccessException | InvocationTargetException e) {
                throw new RuntimeException(e);
            }
            if (i == modules.length - 1) break;
            inputs = new java.lang.Object[] { output };
        }
        return (O) output;
    }
}
saudet commented 1 year ago

I don't know how to express this as an Info, but I just tried adding

We can probably just pass that as Info.javaText for the forward() function.

However, I'm still unsure whether we'd better map the native Sequential or write a Java counterpart that does the same thing but accepts Java modules, as well as inputs and outputs of any number and type. Something like this:

We can add code like that as part of a helper class, sure. We could probably add everything we want in a subclass like org.bytedeco.pytorch.AbstractSequential with an overloaded constructor, though. If possible, I wouldn't add another Sequential class just for that.

HGuillemet commented 1 year ago

I would provide either the native Sequential or a Java one, not both, and not one as a subclass of the other, since they do the same thing. My preference is the Java one, since it's compatible with custom Java modules. However, we'd better first understand why register_module(Module) sometimes causes segmentation faults. We need it in the Java Sequential.

saudet commented 1 year ago

I would provide either the native Sequential or a Java one, not both, and not one as a subclass of the other, since they do the same thing. My preference is the Java one, since it's compatible with custom Java modules.

If it's not meant to support the C++ API, but provide a higher-level Java API on top of the C++ API, that should probably be done as part of another module, in another repository. We can add classes like that as part of JavaCV if you like?

However, we'd better first understand why register_module(Module) sometimes causes segmentation faults. We need it in the Java Sequential.

I wasn't aware of any issues with that. You'll need to provide more details to go about it. I can't fix what isn't broken :)

HGuillemet commented 1 year ago

If it's not meant to support the C++ API, but provide a higher-level Java API on top of the C++ API, that should probably be done as part of another module, in another repository. We can add classes like that as part of JavaCV if you like?

Subclassing Module is how the API is supposed to be used, but the native Sequential cannot chain Java modules. So the Java Sequential is a way to fix the native Sequential. But I agree that it could belong to higher-level software, since it's nothing essential, as said before. Does JavaCV depend on PyTorch? Anyway, it would be better in a thinner "jTorch" package dedicated to providing a clean Java API on top of the presets. But I think we should wait for PyTorch 2 and see what the future of the C++ API is before working on such a package.

HGuillemet commented 1 year ago

I wasn't aware of any issues with that. You'll need to provide more details to go about it. I can't fix what isn't broken :)

I dug a bit further, and here is what I understood: a Java class like Conv2dImpl is really a C++ shared_ptr<Conv2dImpl>. But it extends the Java class Module, which maps the C++ Module. So classes related by inheritance in one language are not related in the other. This is why, for instance, you had to define register_module specializations for all standard modules. But this causes problems like this:

Module m = new MyModule(); // some Java subclass of Module
Conv2dImpl c = new Conv2dImpl(1, 1, new LongPointer(3, 3));

register_module("x", m); // => ok
register_module("y", c); // => ok
m = c; 
register_module("z", m); // => SIGSEGV

And that's an issue for any method taking a Module as an argument, particularly if the specializations were not generated. It probably hasn't been reported yet because there aren't many such methods.

torch.shiftLeft(torch.cout(), new ReLUImpl()); // => SIGSEGV
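Part of why `m = c; register_module("z", m)` behaves differently from `register_module("y", c)` is plain Java semantics: the compiler picks an overload from the static (declared) type of the argument, not from the object's runtime class. Once the Conv2dImpl is held in a variable of type Module, the call binds to the generic Module overload and bypasses the generated specialization, which, per the analysis later in this thread, is where the cast to the native Module crashes. A minimal library-free sketch with hypothetical stub classes (StubModule, StubConv2d, and Registry are illustrative names, not preset classes):

```java
// Hypothetical stand-ins: StubModule plays the role of Module,
// StubConv2d the role of Conv2dImpl, which has its own generated overload.
class StubModule {}

class StubConv2d extends StubModule {}

class Registry {
    // Analogous to the generic register_module(String, Module)
    static String register(String name, StubModule m) {
        return name + " -> StubModule overload";
    }

    // Analogous to a generated specialization such as register_module(String, Conv2dImpl)
    static String register(String name, StubConv2d c) {
        return name + " -> StubConv2d overload";
    }
}
```

Calling Registry.register("z", m) with `StubModule m = new StubConv2d()` selects the StubModule overload even though the object is a StubConv2d; only an explicit cast back to StubConv2d, or a variable of that type, reaches the specialized overload.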

Currently, the presets also define classes like Conv2dImplModuleHolder, which I personally do not use, meant to represent a shared_ptr<Conv2dImpl>.

My impression is that we should be able to find a more consistent mapping and hide the shared_ptr tricks of the C++ API. Any comments before I try to propose some fixes?

saudet commented 1 year ago

I dug a bit further, and here is what I understood: a Java class like Conv2dImpl is really a C++ shared_ptr<Conv2dImpl>. But it extends the Java class Module, which maps the C++ Module. So classes related by inheritance in one language are not related in the other. This is why, for instance, you had to define register_module specializations for all standard modules. But this causes problems like this:

Conv2dImpl isn't a shared_ptr<Conv2dImpl>, although it can refer to one, yes. In any case, you'd encounter the same "problem" trying to do that in C++. That's just how the C++ API of PyTorch is. The only way to make this friendlier is to provide a high-level API on top of that, like DJL or Storch. /cc @frankfliu @sbrunk

Currently, the presets also define classes like Conv2dImplModuleHolder, which I personally do not use, meant to represent a shared_ptr<Conv2dImpl>.

I'm not entirely sure what the purpose of ModuleHolder is supposed to be, but you can try to map register_module() for those and see if it behaves more like we'd expect it to when used from Java.

HGuillemet commented 1 year ago

Conv2dImpl isn't a shared_ptr<Conv2dImpl>

OK, I misunderstood indeed.

Each time a C++ function takes a shared_ptr<torch::nn::Module>, the presets create a method taking a Module and make a shared_ptr<torch::nn::Module> from its pointer. The problem is that if we pass a subclass, like Conv2dImpl, we get a SIGSEGV when it attempts to create the shared_ptr. I guess this is because torch::nn::Module is virtual. We would need to create a shared_ptr<torch::nn::Conv2dImpl>.

Did I get it right this time ?

HGuillemet commented 1 year ago

Not quite. The SIGSEGV is not due to the shared_ptr creation per se, but to the C-style cast from Conv2dImpl to Module, which crosses a virtual inheritance (Cloneable: virtual Module). Does JavaCPP have a mechanism to deal with virtual inheritance?

saudet commented 1 year ago

Does JavaCPP have a mechanism to deal with virtual inheritance?

Not really, no. We need to add methods somewhere to call static_cast and/or dynamic_cast, manually. That's what the asModule() methods are for.

mullerhai commented 1 year ago

Hi, does the new version of the PyTorch presets now support SequentialImpl, ModuleDict, and ModuleList in a friendly way? It's not only complex models that need them: for recommender-system algorithms we want to port Python PyTorch code to Java, not only load Python-trained PyTorch models in Java.

HGuillemet commented 1 year ago

Not in version 2.0.1, which is currently online, but I'm working on a big overhaul of the presets and I'll try to add support for those. However, they will only be able to chain native modules, not your custom Java modules. Why not use the solutions suggested above, here or here, which allow chaining Java and native modules together?