apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[JVM] Align Java GraphModule Initialization with Python API #17464

Closed OleehyO closed 1 month ago

OleehyO commented 1 month ago

The Java API still uses the outdated initialization method for GraphModule, and that method no longer works as expected.

This PR updates the Java API for GraphModule initialization to match the simplified method used in the Python API.

Background

In the Python API, GraphModule can be initialized in a more concise way:

gm = graph_executor.GraphModule(lib["default"](dev))

However, the Java API still uses the older approach:

gm = graph_executor.create(graph_json, lib, dev);
gm.load_params(params);

The old API is more verbose (the graph JSON and the parameters have to be saved and loaded as two additional files), and it also appears to no longer work as expected.


Here is an example of deploying DepthAnything, where Relay's ONNX frontend is used to create the IRModule before compiling and exporting. During deployment, the old initialization method no longer seems to work (top: new method, bottom: old method):

[Screenshots, 2024-10-14 16:38:57 and 16:40:46: deployment output with the new method (top) and the old method (bottom)]

To address this, this PR introduces a new initialization method for GraphModule in the Java API, aligning it with the simplified Python API.

Usage Example

Java:

Device dev = Device.cpu();
Module lib = Module.load(libPath);
// The current Java API does not support passing a Device as a Function argument
Module mod = lib.getFunction("default").call(dev).asModule();
GraphModule gm = new GraphModule(mod, dev);

Equivalent Python:

dev = tvm.device('cpu')
lib = tvm.runtime.load_module(libPath)
gm = graph_executor.GraphModule(lib["default"](dev))
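
To illustrate why the factory-style initialization needs no extra files, here is a toy model in plain Java. It is only a sketch: Device, GraphModule, createOld and loadLibrary below are hypothetical stand-ins written for this example, not the tvm4j classes, and the graph JSON and parameter values are placeholders. The idea it shows is that the exported library carries a named factory ("default") that already holds the graph and parameters, so the caller only supplies a device.

```java
import java.util.Map;
import java.util.function.Function;

class FactoryInitSketch {
    // Hypothetical stand-in for a device handle (not org.apache.tvm.Device).
    static final class Device {
        final String kind;
        final int id;
        Device(String kind, int id) { this.kind = kind; this.id = id; }
    }

    // Hypothetical stand-in for a graph module (not the tvm4j GraphModule).
    static final class GraphModule {
        final String graphJson;
        final Map<String, float[]> params;
        final Device dev;
        GraphModule(String graphJson, Map<String, float[]> params, Device dev) {
            this.graphJson = graphJson;
            this.params = params;
            this.dev = dev;
        }
    }

    // Old style: the caller must load the graph JSON and the params itself
    // (two extra artifacts next to the library) and pass them in.
    static GraphModule createOld(String graphJson, Map<String, float[]> params, Device dev) {
        return new GraphModule(graphJson, params, dev);
    }

    // New style: the "library" exposes named factories that already capture
    // the graph and params; only the device is supplied at call time.
    static Map<String, Function<Device, GraphModule>> loadLibrary() {
        String graphJson = "{\"nodes\": []}";              // placeholder graph
        Map<String, float[]> params = Map.of("w", new float[] {1.0f, 2.0f});
        return Map.of("default", dev -> new GraphModule(graphJson, params, dev));
    }

    public static void main(String[] args) {
        Device dev = new Device("cpu", 0);
        // Mirrors lib["default"](dev) from the Python snippet.
        GraphModule gm = loadLibrary().get("default").apply(dev);
        System.out.println(gm.dev.kind + " " + gm.params.keySet());
    }
}
```

In this model the old path forces the caller to manage graphJson and params separately, while the new path reduces initialization to a single lookup and call.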