Install the latest Nightly build of PyTorch. Then, build this repo:

```shell
# Make sure the right llvm-config is in your PATH
python setup.py install
python setup.py test
```
This package transparently hooks into PyTorch's JIT, so the same tooling is applicable (see `@torch.jit.script`, `torch.jit.trace` and `graph_for`). See below for an example.
```python
import torch
import torch_tvm

torch_tvm.enable()

# The following function will be compiled with TVM
@torch.jit.script
def my_func(a, b, c):
    return a * b + c
```
To disable the JIT hooks, use `torch_tvm.disable()`.
- `register.cpp`: Sets up pybind bindings and invokes the registration of a TVM backend.
- `compiler.{h,cpp}`: Main logic to compile a PyTorch JIT graph with TVM.
- `operators.{h,cpp}`: Location of mapping from JIT IR to TVM operators.

All options are available as keyword arguments in the `enable` function exposed by `torch_tvm`.
The optimization level, device type, device and host compilation targets are all exposed directly from TVM.
```python
torch_tvm.enable(
    opt_level=3,
    device_type="cpu",
    device="llvm",
    host="llvm")
```
First, ensure the operator is registered with Relay. Then, register a map from PyTorch symbols to a Relay `CallNode` with `RegisterTVMOperator`. This can be done in any compilation unit, provided it is linked into the final `torch_tvm` library. See `torch_tvm/operators.cpp` for examples.
```cpp
RegisterTVMOperator reg_relu({
    {Symbol::fromQualString("aten::relu"),
     [](Node* node, tvm::Array<tvm::relay::Expr> inputs) {
       auto op = tvm::relay::Op::Get("nn.relu");
       return tvm::relay::CallNode::make(op, inputs, tvm::Attrs(), {});
     }},
});
```
If the PyTorch function can be fully converted to Relay, it is possible to extract the expression itself using `torch_tvm.to_relay(func, inputs)`. Example inputs must be passed in to calculate type information.
```python
def add(a, b, c):
    return a + b + c

# via tracing
relay_graph = torch_tvm.to_relay(add, inputs)

@torch.jit.script
def mul(a, b, c):
    return a * b * c

# via script
relay_graph = torch_tvm.to_relay(mul, inputs)
```
Note that not all functions can be converted to Relay in their entirety; attempting expression extraction on such a function will raise an exception. To solve this issue, simply refactor the function.
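As a minimal, hypothetical sketch of such a refactor (plain Python; the function names are illustrative and the `to_relay` call is only indicated in a comment), the usual fix is to split the convertible tensor arithmetic into its own helper and keep unsupported Python-level control flow outside it:

```python
# Hypothetical example: names are illustrative, not part of torch_tvm.

def fused_math(a, b, c):
    # Pure arithmetic on the inputs: a candidate for extraction via
    # torch_tvm.to_relay(fused_math, example_inputs).
    return a * b + c

def my_func(a, b, c):
    # Data-dependent Python branching prevents whole-function conversion,
    # so it stays outside the helper that gets converted.
    if c is None:
        return a * b
    return fused_math(a, b, c)
```

Only `fused_math` would then be passed to `to_relay`, while `my_func` remains ordinary Python.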
Below, in order, is a prioritized list of tasks for this repository.

- `torch.ops.tvm.*`
- `set_input`
- `tvm/include/tvm/runtime/device_api.h`