Tiramisu-Compiler / tiramisu_pytorch

Integration of Tiramisu (Compiler) into PyTorch

pytorch_tiramisu

pytorch_tiramisu is a Python package that adds the Tiramisu compiler as a compiler backend to the PyTorch deep learning framework.

pytorch_tiramisu offers two modes of utilization:

  1. Mode 1: The first mode is intended for non-expert users who want to benefit directly from the compiler stack. It relies on pre-compiled operators, is only available for CPUs, and does not require installing Tiramisu.

  2. Mode 2: If you want to register more operators or apply optimizations that the package does not yet support, you can install Tiramisu and run the full compiler stack, as the figure above illustrates.

Build & Install

  1. Install the latest nightly build of PyTorch. For better acceleration, you can instead build it from source by following these instructions.
  2. Install pytorch_tiramisu: We recommend installing the package from source, since it is still a research project in its infancy.

Using PyPI

pip install pytorch_tiramisu

Install from Source

Tests & Tutorials

You can test the installation by running the following code:

import torch
import pytorch_tiramisu as pt
pt.enable(jit=True)  # enable the Tiramisu backend in JIT (TorchScript) mode
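If the snippet above fails to import, a useful first step is to confirm that both distributions are visible to the interpreter. The following is a minimal, stdlib-only sketch and not part of pytorch_tiramisu: it looks up installed versions with importlib.metadata and flags whether the installed torch looks like a nightly build. It assumes PyTorch's usual nightly version scheme, which carries a ".dev" segment (for example 2.1.0.dev20230601+cpu). The helper names are hypothetical, and the script does not exercise the Tiramisu backend itself.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(dists: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    report: Dict[str, Optional[str]] = {}
    for name in dists:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


def looks_like_nightly(torch_version: Optional[str]) -> bool:
    """Heuristic: PyTorch nightly wheels carry a '.dev' segment in their version."""
    return torch_version is not None and ".dev" in torch_version


if __name__ == "__main__":
    # "pytorch_tiramisu" is the distribution name used in the pip command above.
    report = installed_versions(["torch", "pytorch_tiramisu"])
    for name, ver in report.items():
        print(f"{name}: {ver or 'not installed'}")
    if not looks_like_nightly(report["torch"]):
        print("warning: torch does not appear to be a nightly build")
```

The nightly check is only a heuristic on the version string; a source build of PyTorch may report a different suffix.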

Take a look at one of our Jupyter notebooks to quickly try different features and deep learning models:

Usage

With pt.enable(jit=True), the following TorchScript function will be compiled with Tiramisu:

import torch
import torch.nn.functional as F

@torch.jit.script
def relu_(a):
    return F.relu(a)
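One simple way to sanity-check a backend-compiled operator is to compare its output with an independent reference. The sketch below is plain Python with no PyTorch or Tiramisu involved, and the helper names relu_reference and matches are hypothetical. It implements ReLU's element-wise definition, max(0, x), and a tolerance-based comparison.

```python
from typing import List, Sequence


def relu_reference(values: Sequence[float]) -> List[float]:
    """Element-wise ReLU: max(0, x) for every input value."""
    return [max(0.0, v) for v in values]


def matches(actual: Sequence[float], expected: Sequence[float], atol: float = 1e-6) -> bool:
    """True if both sequences have the same length and agree within atol."""
    return len(actual) == len(expected) and all(
        abs(a - e) <= atol for a, e in zip(actual, expected)
    )
```

With PyTorch available, one could, for instance, check matches(relu_(x).flatten().tolist(), relu_reference(x.flatten().tolist())) for a random tensor x.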

* Otherwise (with jit=False), pytorch_tiramisu.compile(model) can be used to compile the deep learning model before the final graph is executed.

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_tiramisu as pt

pt.enable(jit=False)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(256, 10)

    def forward(self, x):
        x = self.fc(x)
        return x

model = Net()
a = torch.randn(1, 256)
# Execute an optimization pass and generate the operators.
generated = pt.compile(model(a))
pt.execute(generated)
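The model above is a single fully connected layer. In PyTorch, nn.Linear(256, 10) computes y = x W^T + b with a weight W of shape (10, 256) and a bias b of length 10, so a (1, 256) input produces a (1, 10) output. As a hedged, framework-free illustration of what the generated operator has to compute (not how Tiramisu implements it), here is a plain-Python reference; linear_reference is a hypothetical helper.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def linear_reference(x: Sequence[Sequence[float]],
                     weight: Sequence[Sequence[float]],
                     bias: Sequence[float]) -> Matrix:
    """Compute y = x @ weight^T + bias, following nn.Linear's convention.

    x:      (batch, in_features)
    weight: (out_features, in_features)
    bias:   (out_features,)
    """
    in_features = len(weight[0])
    if any(len(row) != in_features for row in x):
        raise ValueError("input rows must have in_features elements")
    # One dot product per (input row, weight row) pair, plus that output's bias.
    return [
        [sum(xi * wi for xi, wi in zip(row, w_row)) + b
         for w_row, b in zip(weight, bias)]
        for row in x
    ]


if __name__ == "__main__":
    # Same shapes as Net: in_features=256, out_features=10, batch of 1.
    x = [[1.0] * 256]
    weight = [[0.01] * 256 for _ in range(10)]
    bias = [0.5] * 10
    y = linear_reference(x, weight, bias)
    print(len(y), len(y[0]))  # 1 10
```

To compare against the real layer, the parameters can be extracted with model.fc.weight.detach().tolist() and model.fc.bias.detach().tolist().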