tenstorrent / pytorch2.0_ttnn

⭐️ TTNN Compiler for PyTorch 2.0 ⭐️ It enables running PyTorch2.0 models on Tenstorrent hardware
https://tenstorrent.github.io/tt-metal/latest/ttnn/

PyTorch 2.0 TTNN Compiler

This project lets you run PyTorch models on Tenstorrent hardware.

Supported Models

The table below summarizes the results of running various ML models through our TTNN compiler. For each model, we track whether the run was successful, the number of operations before and after conversion, the number of to_device and from_device operations, performance metrics, and accuracy.

| Model | Run Success | Torch Ops Before (Unique Ops) | Torch Ops Remain (Unique Ops) | To/From Device Ops | Original Run Time (ms) | Compiled Run Time (ms) | Accuracy (%) |
|---|---|---|---|---|---|---|---|
| [Autoencoder (conv)](<docs/models/Autoencoder (conv)>) | 🚧 | 9 (3) | 6 (2) | 3 | 1337 | 314.93 | 100.0 |
| [Autoencoder (conv)-train](<docs/models/Autoencoder (conv)-train>) | 🚧 | 24 (7) | 21 (6) | 3 | 1988.33 | 1321.0 | 100.0 |
| [Autoencoder (linear)](<docs/models/Autoencoder (linear)>) | 🚧 | 22 (3) | 1 (1) | 1 | 1325.56 | 812.88 | 100.0 |
| [Autoencoder (linear)-train](<docs/models/Autoencoder (linear)-train>) | 🚧 | 104 (8) | 31 (4) | 40 | 1798.33 | 6112.67 | 100.0 |
| BERT | | 1393 (21) | 0 (0) | 28 | 71575.9 | 49045.6 | 98.88 |
| Bloom | 🚧 | 1407 (29) | 105 (9) | 78 | 78687.4 | 58629.53 | 41.09 |
| CLIP | 🚧 | 1396 (30) | 632 (20) | 357 | 4798.48 | 33033.74 | 94.11 |
| CLIP-train | | 3943 (44) | N/A | N/A | 25963.6 | N/A | N/A |
| DETR | | 1668 (41) | N/A | N/A | 89729 | N/A | N/A |
| DPR | | 720 (22) | N/A | N/A | 5577.48 | N/A | N/A |
| FLAN-T5 | | 20106 (38) | N/A | N/A | 12582.4 | N/A | N/A |
| Falcon | 🚧 | 2694 (29) | 583 (17) | 576 | 133002 | 191097.31 | 95.02 |
| GLPN-KITTI | | 3074 (30) | N/A | N/A | 126610 | N/A | N/A |
| GPT-2 | | 748 (31) | N/A | N/A | 7089.46 | N/A | N/A |
| GPTNeo | | 2761 (36) | N/A | N/A | 14691.9 | N/A | N/A |
| [Hand Landmark](<docs/models/Hand Landmark>) | | N/A | N/A | N/A | 6115.49 | N/A | N/A |
| HardNet | 🚧 | 245 (10) | 240 (5) | 1 | 4511.35 | 12918.2 | 99.99 |
| HardNet-train | 🚧 | 867 (21) | 650 (15) | 71 | 12115.8 | 45096.38 | 100.0 |
| Llama | 🚧 | 104 (5) | 33 (2) | 35 | 242747 | 167580.53 | 100.0 |
| MLPMixer | | 253 (11) | N/A | N/A | 5441.92 | N/A | N/A |
| MLPMixer-train | | 616 (19) | N/A | N/A | 17382.3 | N/A | N/A |
| Mnist | 🚧 | 14 (8) | 3 (2) | 3 | 3711.39 | 4745.71 | 98.88 |
| Mnist-train | 🚧 | 46 (15) | 20 (8) | 14 | 3501.9 | 6950.09 | 100.0 |
| MobileNetSSD | | 575 (34) | N/A | N/A | 1142.43 | N/A | N/A |
| MobileNetV2 | 🚧 | 154 (9) | 139 (3) | 11 | 843.1 | 11656.39 | 99.98 |
| OPT | | 4073 (32) | N/A | N/A | 31291.9 | N/A | N/A |
| [OpenPose V2](<docs/models/OpenPose V2>) | 🚧 | 155 (7) | 98 (4) | 50 | 2817.68 | 7916.12 | 93.11 |
| [OpenPose V2-train](<docs/models/OpenPose V2-train>) | 🚧 | 523 (14) | 450 (12) | 69 | 9860.72 | 27751.17 | 100.0 |
| [Perceiver IO](<docs/models/Perceiver IO>) | 🚧 | 1532 (21) | 4 (3) | 30 | 66782 | 64844.16 | 99.94 |
| ResNet18 | 🚧 | 70 (9) | 41 (3) | 17 | 2184.85 | 5046.11 | 99.99 |
| ResNet18-train | 🚧 | 241 (19) | 197 (14) | 31 | 6564.42 | 19230.47 | 100.0 |
| ResNet50 | 🚧 | 176 (9) | 107 (3) | 49 | 4149.54 | 17436.06 | 99.98 |
| ResNet50-train | 🚧 | 616 (19) | 524 (14) | 71 | 16070.8 | 42933.82 | 100.0 |
| RoBERTa | | 719 (21) | N/A | N/A | 17300.3 | N/A | N/A |
| SegFormer | | 768 (27) | N/A | N/A | 26060.7 | N/A | N/A |
| SegFormer-train | | 1872 (40) | N/A | N/A | 54648.9 | N/A | N/A |
| SqueezeBERT | 🚧 | 16 (9) | 4 (2) | 4 | 7448.21 | 5310.57 | 100.0 |
| [Stable Diffusion V2](<docs/models/Stable Diffusion V2>) | | 1883 (32) | N/A | N/A | 2.09198e+06 | N/A | N/A |
| U-Net | 🚧 | 68 (6) | 49 (4) | 19 | 40201.9 | 41711.71 | 100.0 |
| U-Net-train | 🚧 | 236 (15) | 205 (12) | 27 | 78688.1 | 90950.49 | 100.0 |
| Unet-brain | 🚧 | 68 (6) | 49 (4) | 19 | 40387.7 | 42396.33 | N/A |
| Unet-brain-train | 🚧 | 236 (15) | 205 (12) | 27 | 79632.8 | 85364.46 | 100.0 |
| Unet-carvana | 🚧 | 67 (5) | 49 (4) | 18 | 62921.6 | 64848.64 | 100.0 |
| Unet-carvana-train | 🚧 | 232 (13) | 202 (11) | 26 | 142637 | 152138.14 | 100.0 |
| ViLT | 🚧 | 55 (18) | 31 (14) | 8 | 16538.5 | 19866.4 | 88.03 |
| Whisper | | 4294 (19) | N/A | N/A | 243685 | N/A | N/A |
| XGLM | 🚧 | 1459 (30) | 94 (18) | 78 | 45533 | 59245.47 | 87.99 |
| YOLOS | 🚧 | 966 (28) | 50 (8) | 32 | 20144.5 | 48544.56 | 42.43 |
| YOLOv3 | | 268 (10) | N/A | N/A | 214728 | N/A | N/A |
| YOLOv5 | 🚧 | 3 (3) | 2 (2) | 1 | 31609.8 | 29510.87 | 100.0 |
| albert/albert-base-v2 | 🚧 | 791 (21) | 41 (5) | 62 | 3401.37 | 32049.73 | 66.33 |
| albert/albert-base-v2-classification | 🚧 | 779 (21) | 28 (3) | 63 | 3249.45 | 17310.19 | -99.53 |
| albert/albert-large-v2 | 🚧 | 1547 (21) | 101 (5) | 170 | 5327.23 | 39768.73 | 17.62 |
| albert/albert-xlarge-v2 | 🚧 | 1547 (21) | 77 (5) | 122 | 16889.7 | 42096.27 | 51.12 |
| albert/albert-xxlarge-v2 | 🚧 | 791 (21) | 65 (6) | 39 | 42072 | 38159.16 | 15.46 |
| codegen | | 9237 (37) | N/A | N/A | 11740.1 | N/A | N/A |
| densenet121 | 🚧 | 432 (10) | 307 (5) | 121 | 2823.78 | 20091.98 | 99.99 |
| densenet161 | 🚧 | 572 (10) | 407 (5) | 161 | 7373.44 | 38499.01 | 99.99 |
| densenet169 | 🚧 | 600 (10) | 427 (5) | 169 | 3310.82 | 26175.12 | 99.99 |
| densenet201 | 🚧 | 712 (10) | 507 (5) | 201 | 4847.3 | 55142.83 | 99.99 |
| distilbert-base-uncased | | 367 (17) | N/A | N/A | 8361.94 | N/A | N/A |
| dla34.in1k | 🚧 | 135 (9) | 85 (4) | 34 | 6069.93 | 8262.23 | 100.0 |
| dla34.in1k-train | 🚧 | 469 (18) | 378 (13) | 57 | 12031.7 | 32889.43 | 100.0 |
| ese_vovnet19b_dw.ra_in1k | 🚧 | 111 (12) | 75 (5) | 27 | 2065.5 | 16223.18 | 99.98 |
| ese_vovnet19b_dw.ra_in1k-train | 🚧 | 360 (25) | 277 (16) | 60 | 5283.96 | 25003.03 | 100.0 |
| facebook/deit-base-patch16-224 | 🚧 | 686 (18) | 257 (6) | 195 | 20954 | 21980.32 | 97.18 |
| facebook/deit-base-patch16-224-train | 🚧 | 1856 (28) | 684 (15) | 658 | 72669.1 | 50721.55 | 100.0 |
| ghostnet_100.in1k | 🚧 | 515 (14) | 262 (5) | 72 | 752.2 | 23403.22 | 99.97 |
| ghostnet_100.in1k-train | 🚧 | 1468 (33) | 1044 (24) | 190 | 1627.17 | 62035.8 | 100.0 |
| ghostnetv2_100.in1k | 🚧 | 809 (20) | 394 (12) | 95 | 1615.6 | 37644.62 | 99.93 |
| ghostnetv2_100.in1k-train | | 2126 (41) | N/A | N/A | 2579.98 | N/A | N/A |
| googlenet | 🚧 | 214 (15) | 140 (5) | 61 | 1773.25 | 16773.12 | 99.97 |
| hrnet_w18.ms_aug_in1k | 🚧 | 1488 (14) | 705 (8) | 281 | 5566.64 | 62911.17 | 99.94 |
| hrnet_w18.ms_aug_in1k-train | | 4277 (24) | N/A | N/A | 16857.6 | N/A | N/A |
| inception_v4.tf_in1k | 🚧 | 495 (11) | 341 (5) | 150 | 13749.2 | 33889.79 | 99.98 |
| inception_v4.tf_in1k-train | 🚧 | 1702 (24) | 1406 (16) | 236 | 37924.7 | 122028.81 | 100.0 |
| microsoft/beit-base-patch16-224 | 🚧 | 793 (21) | 292 (6) | 195 | 15748.7 | 20194.94 | 98.95 |
| microsoft/beit-base-patch16-224-train | 🚧 | 2229 (34) | 768 (19) | 809 | 76034.4 | 64984.99 | 100.0 |
| microsoft/beit-large-patch16-224 | 🚧 | 1573 (21) | 580 (6) | 387 | 47573.4 | 44856.53 | 99.47 |
| microsoft/beit-large-patch16-224-train | 🚧 | 4437 (34) | 1524 (19) | 1613 | 491441 | 128378.66 | 100.0 |
| mixer_b16_224.goog_in21k | 🚧 | 356 (11) | 27 (4) | 27 | 17721.9 | 13685.3 | 58.85 |
| mixer_b16_224.goog_in21k-train | 🚧 | 959 (18) | 155 (11) | 357 | 56765.7 | 38446.16 | 100.0 |
| mobilenet_v2 | 🚧 | 154 (9) | 139 (3) | 11 | 824.02 | 10407.55 | 99.98 |
| mobilenet_v3_large | 🚧 | 188 (11) | 137 (4) | 44 | 761.33 | 14060.74 | 99.94 |
| mobilenet_v3_small | 🚧 | 158 (11) | 114 (4) | 39 | 505.8 | 16522.52 | 99.93 |
| mobilenetv1_100.ra4_e3600_r224_in1k | 🚧 | 85 (7) | 81 (3) | 1 | 1312.87 | 6179.8 | 99.19 |
| mobilenetv1_100.ra4_e3600_r224_in1k-train | 🚧 | 231 (15) | 220 (11) | 6 | 4085.98 | 15887.8 | 100.0 |
| regnet_x_16gf | 🚧 | 235 (8) | 142 (2) | 67 | 17821.5 | 34501.8 | 99.96 |
| regnet_x_1_6gf | 🚧 | 195 (8) | 118 (2) | 55 | 1742.44 | 12344.93 | 99.97 |
| regnet_x_32gf | 🚧 | 245 (8) | 148 (2) | 70 | 35194.5 | 51785.7 | 99.97 |
| regnet_x_3_2gf | 🚧 | 265 (8) | 160 (2) | 76 | 2976.18 | 14834.75 | 99.95 |
| regnet_x_400mf | 🚧 | 235 (8) | 142 (2) | 67 | 872.66 | 12990.26 | 99.96 |
| regnet_x_800mf | 🚧 | 175 (8) | 106 (2) | 49 | 1195.19 | 13594.22 | 99.96 |
| regnet_x_8gf | 🚧 | 245 (8) | 148 (2) | 70 | 7216.56 | 18505.15 | 99.95 |
| regnet_y_128gf | 🚧 | 447 (10) | 226 (2) | 136 | 516541 | 549899.82 | 99.86 |
| regnet_y_16gf | 🚧 | 303 (10) | 154 (2) | 91 | 14877.1 | 44061.93 | 99.94 |
| regnet_y_1_6gf | 🚧 | 447 (10) | 226 (2) | 136 | 2381.61 | 26895.32 | 99.9 |
| regnet_y_32gf | 🚧 | 335 (10) | 170 (2) | 101 | 32296 | 57080.89 | 99.93 |
| regnet_y_3_2gf | 🚧 | 351 (10) | 178 (2) | 106 | 3985.46 | 26390.28 | 99.93 |
| regnet_y_400mf | 🚧 | 271 (10) | 138 (2) | 81 | 867.4 | 25439.88 | 99.77 |
| regnet_y_800mf | 🚧 | 239 (10) | 122 (2) | 71 | 1425.48 | 27316.8 | 99.88 |
| regnet_y_8gf | 🚧 | 287 (10) | 146 (2) | 86 | 9803.33 | 29674.72 | 99.96 |
| resnet101 | 🚧 | 346 (9) | 209 (3) | 100 | 7526.94 | 19056.74 | 99.96 |
| resnet152 | 🚧 | 516 (9) | 311 (3) | 151 | 10215.1 | 36087.0 | 99.95 |
| resnet18 | 🚧 | 70 (9) | 41 (3) | 17 | 2087.8 | 10660.24 | 99.99 |
| resnet34 | 🚧 | 126 (9) | 73 (3) | 33 | 3927.67 | 7546.99 | 99.99 |
| resnet50 | 🚧 | 176 (9) | 107 (3) | 49 | 4071.94 | 12739.83 | 99.98 |
| resnext101_32x8d | 🚧 | 346 (9) | 209 (3) | 100 | 14908.7 | 27322.34 | 99.97 |
| resnext101_64x4d | 🚧 | 346 (9) | 209 (3) | 100 | 14004.3 | 26953.24 | 99.97 |
| resnext50_32x4d | 🚧 | 176 (9) | 107 (3) | 49 | 4060.45 | 9983.94 | 99.98 |
| retinanet_resnet50_fpn | | 1107 (32) | N/A | N/A | 2609.91 | N/A | N/A |
| retinanet_resnet50_fpn_v2 | | 617 (33) | N/A | N/A | 2187.05 | N/A | N/A |
| speecht5-tts | 🚧 | 862 (21) | 51 (9) | 66 | 53375.6 | 94771.96 | -0.99 |
| ssd300_vgg16 | | 387 (32) | N/A | N/A | 3246.89 | N/A | N/A |
| ssdlite320_mobilenet_v3_large | | 575 (34) | N/A | N/A | 581.4 | N/A | N/A |
| swin_b | 🚧 | 1898 (30) | 336 (16) | 192 | 14116.5 | 87463.01 | 4.7 |
| swin_s | 🚧 | 1898 (30) | 338 (16) | 194 | 10363.3 | 45186.43 | 12.53 |
| swin_t | 🚧 | 968 (30) | 176 (16) | 98 | 5503.04 | 57937.4 | 15.54 |
| swin_v2_b | 🚧 | 2474 (37) | 475 (18) | 243 | 17847.8 | 73365.17 | -0.22 |
| swin_v2_s | 🚧 | 2474 (37) | 477 (18) | 245 | 10426.6 | 54625.29 | 1.73 |
| swin_v2_t | 🚧 | 1256 (37) | 249 (18) | 119 | 5733.39 | 68914.5 | 10.99 |
| t5-base | | 14731 (38) | N/A | N/A | 9095.16 | N/A | N/A |
| t5-large | | 22738 (38) | N/A | N/A | 78308.8 | N/A | N/A |
| t5-small | | 6160 (38) | N/A | N/A | 8292.52 | N/A | N/A |
| textattack/albert-base-v2-imdb | 🚧 | 782 (22) | 42 (6) | 63 | 3066.9 | 17579.66 | 100.0 |
| tf_efficientnet_lite0.in1k | 🚧 | 149 (9) | 136 (4) | 10 | 1765.16 | 10359.33 | 99.98 |
| tf_efficientnet_lite0.in1k-train | 🚧 | 403 (17) | 374 (12) | 24 | 2768.62 | 22706.99 | 100.0 |
| tf_efficientnet_lite1.in1k | 🚧 | 194 (9) | 176 (4) | 15 | 2020.45 | 13381.87 | 99.98 |
| tf_efficientnet_lite1.in1k-train | 🚧 | 523 (17) | 484 (12) | 34 | 3276.86 | 24412.53 | 100.0 |
| tf_efficientnet_lite2.in1k | 🚧 | 194 (9) | 176 (4) | 15 | 2060.03 | 12891.94 | 99.98 |
| tf_efficientnet_lite2.in1k-train | 🚧 | 523 (17) | 484 (12) | 34 | 4588.13 | 25859.37 | 100.0 |
| tf_efficientnet_lite3.in1k | 🚧 | 221 (9) | 200 (4) | 18 | 2463.76 | 11809.62 | 99.98 |
| tf_efficientnet_lite3.in1k-train | 🚧 | 595 (17) | 550 (12) | 40 | 7127.07 | 30883.27 | 100.0 |
| tf_efficientnet_lite4.in1k | 🚧 | 275 (9) | 248 (4) | 24 | 4759.59 | 16560.78 | 99.98 |
| tf_efficientnet_lite4.in1k-train | 🚧 | 739 (17) | 682 (12) | 52 | 16333.4 | 59532.86 | 100.0 |
| twmkn9/albert-base-v2-squad2 | 🚧 | 783 (23) | 28 (3) | 65 | 3510.35 | 26854.39 | 97.58 |
| vgg11 | 🚧 | 33 (8) | 15 (4) | 11 | 11615.5 | 20310.32 | 99.97 |
| vgg11_bn | 🚧 | 41 (9) | 23 (5) | 11 | 11707.1 | 15362.81 | 99.98 |
| vgg13 | 🚧 | 37 (8) | 17 (4) | 13 | 19567.2 | 21567.41 | 99.98 |
| vgg13_bn | 🚧 | 47 (9) | 27 (5) | 13 | 18124.7 | 22348.68 | 99.98 |
| vgg16 | 🚧 | 43 (8) | 20 (4) | 16 | 25346.9 | 35202.22 | 99.98 |
| vgg16_bn | 🚧 | 56 (9) | 33 (5) | 16 | 23502 | 27642.26 | 99.98 |
| vgg19 | 🚧 | 49 (8) | 23 (4) | 19 | 31330.9 | 32977.1 | 99.98 |
| vgg19_bn | 🚧 | 65 (9) | 39 (5) | 19 | 31725 | 36309.45 | 99.97 |
| vit_b_16 | 🚧 | 552 (17) | 148 (6) | 111 | 12258.5 | 29569.52 | 99.12 |
| vit_b_32 | 🚧 | 552 (17) | 149 (7) | 110 | 4265.64 | 15591.78 | 98.08 |
| vit_h_14 | 🚧 | 1452 (17) | 164 (6) | 163 | 763516 | 821420.18 | 99.29 |
| vit_l_16 | 🚧 | 1092 (17) | 292 (6) | 219 | 42210.7 | 59423.71 | 99.62 |
| vit_l_32 | 🚧 | 1092 (17) | 124 (6) | 123 | 13802.6 | 34313.69 | 98.86 |
| wide_resnet101_2 | 🚧 | 346 (9) | 209 (3) | 100 | 22240.9 | 36777.9 | 99.97 |
| wide_resnet50_2 | 🚧 | 176 (9) | 107 (3) | 49 | 12236.3 | 23852.43 | 99.98 |
| xception71.tf_in1k | 🚧 | 393 (9) | 292 (2) | 80 | 18169 | 46931.77 | 99.98 |
| xception71.tf_in1k-train | 🚧 | 1370 (18) | 1240 (12) | 108 | 61299 | 134629.6 | 100.0 |

Explanation of Metrics

- **Model**: Name of the model.
- **Run Success**: Indicates whether the model runs successfully after conversion.
- **Torch Ops Before (Unique Ops)**: The total number of operations used by the model in the original Torch implementation. The number in parentheses is the count of unique ops.
- **Torch Ops Remain (Unique Ops)**: The total number of Torch operations remaining after conversion to TTNN. The number in parentheses is the count of unique ops.
- **To/From Device Ops**: The number of to/from_device operations (data transfers to/from the device).
- **Original Run Time (ms)**: Execution time (in milliseconds) of the model before conversion.
- **Compiled Run Time (ms)**: Execution time (in milliseconds) of the model after conversion.
- **Accuracy (%)**: Model accuracy on a predefined test dataset after conversion.


Quickstart

The torch_ttnn module provides a backend function that can be used with torch.compile().

```python
import torch
import torch_ttnn
import ttnn

# A torch Module
class FooModule(torch.nn.Module):
    ...

# Create a module
module = FooModule()

# Compile the module with the ttnn backend
device = ttnn.open_device(device_id=0)
option = torch_ttnn.TorchTtnnOption(device=device)
ttnn_module = torch.compile(module, backend=torch_ttnn.backend, options=option)

# Run inference / training
ttnn_module(input_data)
```
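
For intuition about what a torch.compile backend is, here is a minimal toy backend that just runs the traced graph unchanged on CPU. This is not torch_ttnn's implementation, only an illustration of the hook that torch_ttnn.backend plugs into:

```python
import torch

# A torch.compile backend receives the traced FX GraphModule plus example
# inputs and returns a callable that runs the graph. This toy backend does
# no rewriting at all; a real backend like torch_ttnn additionally converts
# graph operations to run on the target device.
def passthrough_backend(gm: torch.fx.GraphModule, example_inputs):
    return gm.forward

@torch.compile(backend=passthrough_backend)
def f(x):
    return torch.relu(x) + 1

out = f(torch.tensor([-1.0, 2.0]))
print(out)  # tensor([1., 3.])
```

Because the backend only returns `gm.forward`, the compiled function behaves exactly like the eager one; a device backend would instead return a wrapper that moves tensors and dispatches converted ops.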

Tracer

The tracer dumps information about the fx graph, such as each node's op_name and shape.

For example, you can run this script to collect that information:

```shell
PYTHONPATH=$(pwd) python3 tools/stat_models.py --trace_orig --backward --profile
ls stat/raw
```

By default, the raw results are stored in stat/raw. You can then run this script to generate the reports:

```shell
python3 tools/generate_report.py
ls stat/
```

Now the stat/ folder contains these reports:

node_count.csv lists how many nodes of each op_type appear in the fx graph. This report helps analyze how frequently each op type occurs in the graph.

The *_total_*_size_dist/ reports show the input/output size distribution of each op_type across all fx graphs recorded in stat/raw. They help analyze the memory footprint during the computation of each op_type.

The profile/ folder contains traces produced by PyTorch's profiler; you can open them in Chrome at chrome://tracing.
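
As a sketch of the kind of information these reports aggregate, the per-op node counts of an fx graph can be computed directly with torch.fx (an illustration, not the tool's actual implementation):

```python
import torch
from collections import Counter

class TinyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + x

# Symbolically trace the module into an fx graph, then tally how many times
# each operation appears -- the same kind of count node_count.csv records.
gm = torch.fx.symbolic_trace(TinyModule())
counts = Counter(
    getattr(node.target, "__name__", str(node.target))
    for node in gm.graph.nodes
    if node.op == "call_function"
)
print(dict(counts))  # e.g. {'relu': 1, 'add': 1}
```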

For developers

Install torch-ttnn with editable mode

During development, you may want to use the torch-ttnn package for testing. To do so, install the torch-ttnn package in "editable" mode with:

```shell
pip install -e .
```

Now you can use torch_ttnn in your Python code. Any modifications you make to the torch_ttnn package take effect immediately, with no need to reinstall via pip.

Build wheel file

Developers who want to deploy a wheel can build the wheel file with:

```shell
python -m build
```

Then you can upload the .whl file to PyPI (the Python Package Index).

Run transformer models

To run a transformer model with the ttnn backend, run:

```shell
PYTHONPATH="$TT_METAL_HOME:$(pwd)" python3 tools/run_transformers.py --model "phiyodr/bert-large-finetuned-squad2" --backend torch_ttnn
```

You can also substitute the backend with torch_stat to run a reference comparison.

Add a model test

If you want to record run-time metrics for a model or test, include the pytest fixture record_property as a parameter and set the "model_name" key.
If you also want to compile the model with the torch_ttnn backend, set the "torch_ttnn" key to a tuple in this order: (model, test_inputs, outputs). "model_name" still needs to be set. See the example code snippet below. torch.nn.Module models with a generate method are also supported.

```python
class Model(torch.nn.Module):
    def forward(self, x):
        # ...
        return outputs

# Add the compilation_xfail marker if the torch/CPU run passes
# but the compiled version is expected to fail
@pytest.mark.compilation_xfail
# Add the "record_property" parameter
def test_model_name(record_property):
    # Should be set as early as possible
    record_property("model_name", "Model Name")

    model = Model()
    # ...
    outputs = model(test_input)
    # outputs = model(**test_inputs) # dictionary inputs are also supported
    # ...

    # Can be set once all three objects for the tuple are defined
    record_property("torch_ttnn", (model, test_input, outputs))
```

If model.generate(inputs) is used, pass model.generate instead of model to record_property.
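
Putting the convention together for a generate-style model, a minimal toy test might look like this (ToyGenerator is purely illustrative, standing in for a real model whose forward pass is driven by generate()):

```python
import torch

class ToyGenerator(torch.nn.Module):
    """Toy stand-in for a model with a generate() method."""
    def generate(self, inputs):
        return inputs + 1

def test_toy_generator(record_property):
    record_property("model_name", "Toy Generator")
    model = ToyGenerator()
    test_input = torch.tensor([1.0, 2.0])
    outputs = model.generate(test_input)
    # Pass model.generate rather than model, since generate() drives the run
    record_property("torch_ttnn", (model.generate, test_input, outputs))
```

Under pytest, record_property is the built-in fixture described above; the recorded properties end up in the test report.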