This project allows you to run PyTorch code on Tenstorrent hardware.
The table below summarizes the results of running various ML models through our TTNN compiler. For each model, we track whether the run was successful, the number of operations before and after conversion, the number of `to_device` and `from_device` operations, performance metrics, and accuracy.
Model | Run Success | Torch Ops Before (Unique Ops) | Torch Ops Remain (Unique Ops) | To/From Device Ops | Original Run Time (ms) | Compiled Run Time (ms) | Accuracy (%) |
---|---|---|---|---|---|---|---|
[Autoencoder (conv)](<docs/models/Autoencoder (conv)>) | 🚧 | 9 (3) | 6 (2) | 3 | 1337 | 314.93 | 100.0 |
[Autoencoder (conv)-train](<docs/models/Autoencoder (conv)-train>) | 🚧 | 24 (7) | 21 (6) | 3 | 1988.33 | 1321.0 | 100.0 |
[Autoencoder (linear)](<docs/models/Autoencoder (linear)>) | 🚧 | 22 (3) | 1 (1) | 1 | 1325.56 | 812.88 | 100.0 |
[Autoencoder (linear)-train](<docs/models/Autoencoder (linear)-train>) | 🚧 | 104 (8) | 31 (4) | 40 | 1798.33 | 6112.67 | 100.0 |
BERT | ✅ | 1393 (21) | 0 (0) | 28 | 71575.9 | 49045.6 | 98.88 |
Bloom | 🚧 | 1407 (29) | 105 (9) | 78 | 78687.4 | 58629.53 | 41.09 |
CLIP | 🚧 | 1396 (30) | 632 (20) | 357 | 4798.48 | 33033.74 | 94.11 |
CLIP-train | ❌ | 3943 (44) | N/A | N/A | 25963.6 | N/A | N/A |
DETR | ❌ | 1668 (41) | N/A | N/A | 89729 | N/A | N/A |
DPR | ❌ | 720 (22) | N/A | N/A | 5577.48 | N/A | N/A |
FLAN-T5 | ❌ | 20106 (38) | N/A | N/A | 12582.4 | N/A | N/A |
Falcon | 🚧 | 2694 (29) | 583 (17) | 576 | 133002 | 191097.31 | 95.02 |
GLPN-KITTI | ❌ | 3074 (30) | N/A | N/A | 126610 | N/A | N/A |
GPT-2 | ❌ | 748 (31) | N/A | N/A | 7089.46 | N/A | N/A |
GPTNeo | ❌ | 2761 (36) | N/A | N/A | 14691.9 | N/A | N/A |
[Hand Landmark](<docs/models/Hand Landmark>) | ❌ | N/A | N/A | N/A | 6115.49 | N/A | N/A |
HardNet | 🚧 | 245 (10) | 240 (5) | 1 | 4511.35 | 12918.2 | 99.99 |
HardNet-train | 🚧 | 867 (21) | 650 (15) | 71 | 12115.8 | 45096.38 | 100.0 |
Llama | 🚧 | 104 (5) | 33 (2) | 35 | 242747 | 167580.53 | 100.0 |
MLPMixer | ❌ | 253 (11) | N/A | N/A | 5441.92 | N/A | N/A |
MLPMixer-train | ❌ | 616 (19) | N/A | N/A | 17382.3 | N/A | N/A |
Mnist | 🚧 | 14 (8) | 3 (2) | 3 | 3711.39 | 4745.71 | 98.88 |
Mnist-train | 🚧 | 46 (15) | 20 (8) | 14 | 3501.9 | 6950.09 | 100.0 |
MobileNetSSD | ❌ | 575 (34) | N/A | N/A | 1142.43 | N/A | N/A |
MobileNetV2 | 🚧 | 154 (9) | 139 (3) | 11 | 843.1 | 11656.39 | 99.98 |
OPT | ❌ | 4073 (32) | N/A | N/A | 31291.9 | N/A | N/A |
[OpenPose V2](<docs/models/OpenPose V2>) | 🚧 | 155 (7) | 98 (4) | 50 | 2817.68 | 7916.12 | 93.11 |
[OpenPose V2-train](<docs/models/OpenPose V2-train>) | 🚧 | 523 (14) | 450 (12) | 69 | 9860.72 | 27751.17 | 100.0 |
[Perceiver IO](<docs/models/Perceiver IO>) | 🚧 | 1532 (21) | 4 (3) | 30 | 66782 | 64844.16 | 99.94 |
ResNet18 | 🚧 | 70 (9) | 41 (3) | 17 | 2184.85 | 5046.11 | 99.99 |
ResNet18-train | 🚧 | 241 (19) | 197 (14) | 31 | 6564.42 | 19230.47 | 100.0 |
ResNet50 | 🚧 | 176 (9) | 107 (3) | 49 | 4149.54 | 17436.06 | 99.98 |
ResNet50-train | 🚧 | 616 (19) | 524 (14) | 71 | 16070.8 | 42933.82 | 100.0 |
RoBERTa | ❌ | 719 (21) | N/A | N/A | 17300.3 | N/A | N/A |
SegFormer | ❌ | 768 (27) | N/A | N/A | 26060.7 | N/A | N/A |
SegFormer-train | ❌ | 1872 (40) | N/A | N/A | 54648.9 | N/A | N/A |
SqueezeBERT | 🚧 | 16 (9) | 4 (2) | 4 | 7448.21 | 5310.57 | 100.0 |
[Stable Diffusion V2](<docs/models/Stable Diffusion V2>) | ❌ | 1883 (32) | N/A | N/A | 2.09198e+06 | N/A | N/A |
U-Net | 🚧 | 68 (6) | 49 (4) | 19 | 40201.9 | 41711.71 | 100.0 |
U-Net-train | 🚧 | 236 (15) | 205 (12) | 27 | 78688.1 | 90950.49 | 100.0 |
Unet-brain | 🚧 | 68 (6) | 49 (4) | 19 | 40387.7 | 42396.33 | N/A |
Unet-brain-train | 🚧 | 236 (15) | 205 (12) | 27 | 79632.8 | 85364.46 | 100.0 |
Unet-carvana | 🚧 | 67 (5) | 49 (4) | 18 | 62921.6 | 64848.64 | 100.0 |
Unet-carvana-train | 🚧 | 232 (13) | 202 (11) | 26 | 142637 | 152138.14 | 100.0 |
ViLT | 🚧 | 55 (18) | 31 (14) | 8 | 16538.5 | 19866.4 | 88.03 |
Whisper | ❌ | 4294 (19) | N/A | N/A | 243685 | N/A | N/A |
XGLM | 🚧 | 1459 (30) | 94 (18) | 78 | 45533 | 59245.47 | 87.99 |
YOLOS | 🚧 | 966 (28) | 50 (8) | 32 | 20144.5 | 48544.56 | 42.43 |
YOLOv3 | ❌ | 268 (10) | N/A | N/A | 214728 | N/A | N/A |
YOLOv5 | 🚧 | 3 (3) | 2 (2) | 1 | 31609.8 | 29510.87 | 100.0 |
albert/albert-base-v2 | 🚧 | 791 (21) | 41 (5) | 62 | 3401.37 | 32049.73 | 66.33 |
albert/albert-base-v2-classification | 🚧 | 779 (21) | 28 (3) | 63 | 3249.45 | 17310.19 | -99.53 |
albert/albert-large-v2 | 🚧 | 1547 (21) | 101 (5) | 170 | 5327.23 | 39768.73 | 17.62 |
albert/albert-xlarge-v2 | 🚧 | 1547 (21) | 77 (5) | 122 | 16889.7 | 42096.27 | 51.12 |
albert/albert-xxlarge-v2 | 🚧 | 791 (21) | 65 (6) | 39 | 42072 | 38159.16 | 15.46 |
codegen | ❌ | 9237 (37) | N/A | N/A | 11740.1 | N/A | N/A |
densenet121 | 🚧 | 432 (10) | 307 (5) | 121 | 2823.78 | 20091.98 | 99.99 |
densenet161 | 🚧 | 572 (10) | 407 (5) | 161 | 7373.44 | 38499.01 | 99.99 |
densenet169 | 🚧 | 600 (10) | 427 (5) | 169 | 3310.82 | 26175.12 | 99.99 |
densenet201 | 🚧 | 712 (10) | 507 (5) | 201 | 4847.3 | 55142.83 | 99.99 |
distilbert-base-uncased | ❌ | 367 (17) | N/A | N/A | 8361.94 | N/A | N/A |
dla34.in1k | 🚧 | 135 (9) | 85 (4) | 34 | 6069.93 | 8262.23 | 100.0 |
dla34.in1k-train | 🚧 | 469 (18) | 378 (13) | 57 | 12031.7 | 32889.43 | 100.0 |
ese_vovnet19b_dw.ra_in1k | 🚧 | 111 (12) | 75 (5) | 27 | 2065.5 | 16223.18 | 99.98 |
ese_vovnet19b_dw.ra_in1k-train | 🚧 | 360 (25) | 277 (16) | 60 | 5283.96 | 25003.03 | 100.0 |
facebook/deit-base-patch16-224 | 🚧 | 686 (18) | 257 (6) | 195 | 20954 | 21980.32 | 97.18 |
facebook/deit-base-patch16-224-train | 🚧 | 1856 (28) | 684 (15) | 658 | 72669.1 | 50721.55 | 100.0 |
ghostnet_100.in1k | 🚧 | 515 (14) | 262 (5) | 72 | 752.2 | 23403.22 | 99.97 |
ghostnet_100.in1k-train | 🚧 | 1468 (33) | 1044 (24) | 190 | 1627.17 | 62035.8 | 100.0 |
ghostnetv2_100.in1k | 🚧 | 809 (20) | 394 (12) | 95 | 1615.6 | 37644.62 | 99.93 |
ghostnetv2_100.in1k-train | ❌ | 2126 (41) | N/A | N/A | 2579.98 | N/A | N/A |
googlenet | 🚧 | 214 (15) | 140 (5) | 61 | 1773.25 | 16773.12 | 99.97 |
hrnet_w18.ms_aug_in1k | 🚧 | 1488 (14) | 705 (8) | 281 | 5566.64 | 62911.17 | 99.94 |
hrnet_w18.ms_aug_in1k-train | ❌ | 4277 (24) | N/A | N/A | 16857.6 | N/A | N/A |
inception_v4.tf_in1k | 🚧 | 495 (11) | 341 (5) | 150 | 13749.2 | 33889.79 | 99.98 |
inception_v4.tf_in1k-train | 🚧 | 1702 (24) | 1406 (16) | 236 | 37924.7 | 122028.81 | 100.0 |
microsoft/beit-base-patch16-224 | 🚧 | 793 (21) | 292 (6) | 195 | 15748.7 | 20194.94 | 98.95 |
microsoft/beit-base-patch16-224-train | 🚧 | 2229 (34) | 768 (19) | 809 | 76034.4 | 64984.99 | 100.0 |
microsoft/beit-large-patch16-224 | 🚧 | 1573 (21) | 580 (6) | 387 | 47573.4 | 44856.53 | 99.47 |
microsoft/beit-large-patch16-224-train | 🚧 | 4437 (34) | 1524 (19) | 1613 | 491441 | 128378.66 | 100.0 |
mixer_b16_224.goog_in21k | 🚧 | 356 (11) | 27 (4) | 27 | 17721.9 | 13685.3 | 58.85 |
mixer_b16_224.goog_in21k-train | 🚧 | 959 (18) | 155 (11) | 357 | 56765.7 | 38446.16 | 100.0 |
mobilenet_v2 | 🚧 | 154 (9) | 139 (3) | 11 | 824.02 | 10407.55 | 99.98 |
mobilenet_v3_large | 🚧 | 188 (11) | 137 (4) | 44 | 761.33 | 14060.74 | 99.94 |
mobilenet_v3_small | 🚧 | 158 (11) | 114 (4) | 39 | 505.8 | 16522.52 | 99.93 |
mobilenetv1_100.ra4_e3600_r224_in1k | 🚧 | 85 (7) | 81 (3) | 1 | 1312.87 | 6179.8 | 99.19 |
mobilenetv1_100.ra4_e3600_r224_in1k-train | 🚧 | 231 (15) | 220 (11) | 6 | 4085.98 | 15887.8 | 100.0 |
regnet_x_16gf | 🚧 | 235 (8) | 142 (2) | 67 | 17821.5 | 34501.8 | 99.96 |
regnet_x_1_6gf | 🚧 | 195 (8) | 118 (2) | 55 | 1742.44 | 12344.93 | 99.97 |
regnet_x_32gf | 🚧 | 245 (8) | 148 (2) | 70 | 35194.5 | 51785.7 | 99.97 |
regnet_x_3_2gf | 🚧 | 265 (8) | 160 (2) | 76 | 2976.18 | 14834.75 | 99.95 |
regnet_x_400mf | 🚧 | 235 (8) | 142 (2) | 67 | 872.66 | 12990.26 | 99.96 |
regnet_x_800mf | 🚧 | 175 (8) | 106 (2) | 49 | 1195.19 | 13594.22 | 99.96 |
regnet_x_8gf | 🚧 | 245 (8) | 148 (2) | 70 | 7216.56 | 18505.15 | 99.95 |
regnet_y_128gf | 🚧 | 447 (10) | 226 (2) | 136 | 516541 | 549899.82 | 99.86 |
regnet_y_16gf | 🚧 | 303 (10) | 154 (2) | 91 | 14877.1 | 44061.93 | 99.94 |
regnet_y_1_6gf | 🚧 | 447 (10) | 226 (2) | 136 | 2381.61 | 26895.32 | 99.9 |
regnet_y_32gf | 🚧 | 335 (10) | 170 (2) | 101 | 32296 | 57080.89 | 99.93 |
regnet_y_3_2gf | 🚧 | 351 (10) | 178 (2) | 106 | 3985.46 | 26390.28 | 99.93 |
regnet_y_400mf | 🚧 | 271 (10) | 138 (2) | 81 | 867.4 | 25439.88 | 99.77 |
regnet_y_800mf | 🚧 | 239 (10) | 122 (2) | 71 | 1425.48 | 27316.8 | 99.88 |
regnet_y_8gf | 🚧 | 287 (10) | 146 (2) | 86 | 9803.33 | 29674.72 | 99.96 |
resnet101 | 🚧 | 346 (9) | 209 (3) | 100 | 7526.94 | 19056.74 | 99.96 |
resnet152 | 🚧 | 516 (9) | 311 (3) | 151 | 10215.1 | 36087.0 | 99.95 |
resnet18 | 🚧 | 70 (9) | 41 (3) | 17 | 2087.8 | 10660.24 | 99.99 |
resnet34 | 🚧 | 126 (9) | 73 (3) | 33 | 3927.67 | 7546.99 | 99.99 |
resnet50 | 🚧 | 176 (9) | 107 (3) | 49 | 4071.94 | 12739.83 | 99.98 |
resnext101_32x8d | 🚧 | 346 (9) | 209 (3) | 100 | 14908.7 | 27322.34 | 99.97 |
resnext101_64x4d | 🚧 | 346 (9) | 209 (3) | 100 | 14004.3 | 26953.24 | 99.97 |
resnext50_32x4d | 🚧 | 176 (9) | 107 (3) | 49 | 4060.45 | 9983.94 | 99.98 |
retinanet_resnet50_fpn | ❌ | 1107 (32) | N/A | N/A | 2609.91 | N/A | N/A |
retinanet_resnet50_fpn_v2 | ❌ | 617 (33) | N/A | N/A | 2187.05 | N/A | N/A |
speecht5-tts | 🚧 | 862 (21) | 51 (9) | 66 | 53375.6 | 94771.96 | -0.99 |
ssd300_vgg16 | ❌ | 387 (32) | N/A | N/A | 3246.89 | N/A | N/A |
ssdlite320_mobilenet_v3_large | ❌ | 575 (34) | N/A | N/A | 581.4 | N/A | N/A |
swin_b | 🚧 | 1898 (30) | 336 (16) | 192 | 14116.5 | 87463.01 | 4.7 |
swin_s | 🚧 | 1898 (30) | 338 (16) | 194 | 10363.3 | 45186.43 | 12.53 |
swin_t | 🚧 | 968 (30) | 176 (16) | 98 | 5503.04 | 57937.4 | 15.54 |
swin_v2_b | 🚧 | 2474 (37) | 475 (18) | 243 | 17847.8 | 73365.17 | -0.22 |
swin_v2_s | 🚧 | 2474 (37) | 477 (18) | 245 | 10426.6 | 54625.29 | 1.73 |
swin_v2_t | 🚧 | 1256 (37) | 249 (18) | 119 | 5733.39 | 68914.5 | 10.99 |
t5-base | ❌ | 14731 (38) | N/A | N/A | 9095.16 | N/A | N/A |
t5-large | ❌ | 22738 (38) | N/A | N/A | 78308.8 | N/A | N/A |
t5-small | ❌ | 6160 (38) | N/A | N/A | 8292.52 | N/A | N/A |
textattack/albert-base-v2-imdb | 🚧 | 782 (22) | 42 (6) | 63 | 3066.9 | 17579.66 | 100.0 |
tf_efficientnet_lite0.in1k | 🚧 | 149 (9) | 136 (4) | 10 | 1765.16 | 10359.33 | 99.98 |
tf_efficientnet_lite0.in1k-train | 🚧 | 403 (17) | 374 (12) | 24 | 2768.62 | 22706.99 | 100.0 |
tf_efficientnet_lite1.in1k | 🚧 | 194 (9) | 176 (4) | 15 | 2020.45 | 13381.87 | 99.98 |
tf_efficientnet_lite1.in1k-train | 🚧 | 523 (17) | 484 (12) | 34 | 3276.86 | 24412.53 | 100.0 |
tf_efficientnet_lite2.in1k | 🚧 | 194 (9) | 176 (4) | 15 | 2060.03 | 12891.94 | 99.98 |
tf_efficientnet_lite2.in1k-train | 🚧 | 523 (17) | 484 (12) | 34 | 4588.13 | 25859.37 | 100.0 |
tf_efficientnet_lite3.in1k | 🚧 | 221 (9) | 200 (4) | 18 | 2463.76 | 11809.62 | 99.98 |
tf_efficientnet_lite3.in1k-train | 🚧 | 595 (17) | 550 (12) | 40 | 7127.07 | 30883.27 | 100.0 |
tf_efficientnet_lite4.in1k | 🚧 | 275 (9) | 248 (4) | 24 | 4759.59 | 16560.78 | 99.98 |
tf_efficientnet_lite4.in1k-train | 🚧 | 739 (17) | 682 (12) | 52 | 16333.4 | 59532.86 | 100.0 |
twmkn9/albert-base-v2-squad2 | 🚧 | 783 (23) | 28 (3) | 65 | 3510.35 | 26854.39 | 97.58 |
vgg11 | 🚧 | 33 (8) | 15 (4) | 11 | 11615.5 | 20310.32 | 99.97 |
vgg11_bn | 🚧 | 41 (9) | 23 (5) | 11 | 11707.1 | 15362.81 | 99.98 |
vgg13 | 🚧 | 37 (8) | 17 (4) | 13 | 19567.2 | 21567.41 | 99.98 |
vgg13_bn | 🚧 | 47 (9) | 27 (5) | 13 | 18124.7 | 22348.68 | 99.98 |
vgg16 | 🚧 | 43 (8) | 20 (4) | 16 | 25346.9 | 35202.22 | 99.98 |
vgg16_bn | 🚧 | 56 (9) | 33 (5) | 16 | 23502 | 27642.26 | 99.98 |
vgg19 | 🚧 | 49 (8) | 23 (4) | 19 | 31330.9 | 32977.1 | 99.98 |
vgg19_bn | 🚧 | 65 (9) | 39 (5) | 19 | 31725 | 36309.45 | 99.97 |
vit_b_16 | 🚧 | 552 (17) | 148 (6) | 111 | 12258.5 | 29569.52 | 99.12 |
vit_b_32 | 🚧 | 552 (17) | 149 (7) | 110 | 4265.64 | 15591.78 | 98.08 |
vit_h_14 | 🚧 | 1452 (17) | 164 (6) | 163 | 763516 | 821420.18 | 99.29 |
vit_l_16 | 🚧 | 1092 (17) | 292 (6) | 219 | 42210.7 | 59423.71 | 99.62 |
vit_l_32 | 🚧 | 1092 (17) | 124 (6) | 123 | 13802.6 | 34313.69 | 98.86 |
wide_resnet101_2 | 🚧 | 346 (9) | 209 (3) | 100 | 22240.9 | 36777.9 | 99.97 |
wide_resnet50_2 | 🚧 | 176 (9) | 107 (3) | 49 | 12236.3 | 23852.43 | 99.98 |
xception71.tf_in1k | 🚧 | 393 (9) | 292 (2) | 80 | 18169 | 46931.77 | 99.98 |
xception71.tf_in1k-train | 🚧 | 1370 (18) | 1240 (12) | 108 | 61299 | 134629.6 | 100.0 |
- Model: Name of the model.
- Run Success: Indicates whether the model runs successfully after conversion.
- Torch Ops Before (Unique Ops): The total number of operations used by the model in the original Torch implementation. The number in parentheses is the number of unique ops.
- Torch Ops Remain (Unique Ops): The total number of Torch operations that remain after conversion to TTNN. The number in parentheses is the number of unique ops.
- To/From Device Ops: The number of `to_device`/`from_device` operations (data transfers to/from the device).
- Original Run Time (ms): Execution time (in milliseconds) of the model before conversion.
- Compiled Run Time (ms): Execution time (in milliseconds) of the model after conversion.
- Accuracy (%): Model accuracy on a predefined test dataset after conversion.
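The "total (unique)" cells and the two run-time columns can be turned into derived metrics, such as speedup (original run time divided by compiled run time) and the fraction of Torch ops that were converted. The sketch below is a hypothetical helper, not part of the project's tooling; it assumes rows follow the exact pipe-separated layout of the table above and that `N/A` marks a missing value.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class ModelStatus:
    name: str
    ops_before: Optional[int]
    unique_before: Optional[int]
    ops_remain: Optional[int]
    original_ms: Optional[float]
    compiled_ms: Optional[float]

    @property
    def speedup(self) -> Optional[float]:
        """Original / compiled run time; above 1.0 means the compiled run was faster."""
        if self.original_ms is None or self.compiled_ms is None:
            return None
        return self.original_ms / self.compiled_ms

    @property
    def converted_fraction(self) -> Optional[float]:
        """Share of the original Torch ops that no longer remain after conversion."""
        if not self.ops_before or self.ops_remain is None:
            return None
        return 1 - self.ops_remain / self.ops_before


def _count(cell: str) -> Tuple[Optional[int], Optional[int]]:
    """Split a 'total (unique)' cell such as '1393 (21)'; 'N/A' gives (None, None)."""
    if cell == "N/A":
        return None, None
    total, unique = cell.split(" (")
    return int(total), int(unique.rstrip(")"))


def _number(cell: str) -> Optional[float]:
    return None if cell == "N/A" else float(cell)


def parse_row(row: str) -> ModelStatus:
    """Parse one pipe-separated row of the status table (hypothetical helper)."""
    cells = [c.strip() for c in row.strip().strip("|").split("|")]
    ops_before, unique_before = _count(cells[2])
    ops_remain, _ = _count(cells[3])
    return ModelStatus(cells[0], ops_before, unique_before, ops_remain,
                       _number(cells[5]), _number(cells[6]))
```

Applied to the ResNet18 row, for example, the speedup comes out below 1.0, which means the compiled run was slower than the original one.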
The `torch_ttnn` module has a `backend` function, which can be used with `torch.compile()`.
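```python
import torch
import torch_ttnn
import ttnn

# A torch Module (placeholder body; replace with your own model)
class FooModule(torch.nn.Module):
    def forward(self, x):
        return x + x

# Create a module
module = FooModule()

# Open the device and compile the module with the ttnn backend
device = ttnn.open_device(device_id=0)
option = torch_ttnn.TorchTtnnOption(device=device)
ttnn_module = torch.compile(module, backend=torch_ttnn.backend, options=option)

# Running inference / training (example input; use whatever your model expects)
input_data = torch.rand(32, 32)
output = ttnn_module(input_data)

# Release the device when done
ttnn.close_device(device)
```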
The tracer dumps information about the FX graph, such as each node's op_name and shape.
For example, you can run this script to parse the information:
```shell
PYTHONPATH=$(pwd) python3 tools/stat_models.py --trace_orig --backward --profile
ls stat/raw
```
By default, the raw results are stored in `stat/raw`, and you can run this script to generate the report:
```shell
python3 tools/generate_report.py
ls stat/
```
The `stat/` folder now contains these reports:
```
fw_node_count.csv
bw_node_count.csv
fw_total_input_size_dist/
bw_total_input_size_dist/
fw_total_output_size_dist/
bw_total_output_size_dist/
profile/
```
The `*_node_count.csv` reports show how many nodes of each `op_type` appear in the FX graphs. They help analyze how frequently each op type appears in the graph.
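As a rough illustration of what a node-count report holds, the sketch below counts `op_type` occurrences, writes them as CSV, and reads back the most frequent ones. The `op_type,count` column layout and the op names are assumptions made for illustration; check the generated `fw_node_count.csv` for its actual columns.

```python
import csv
import io
from collections import Counter
from typing import IO, Iterable, List, Tuple


def count_op_types(op_types: Iterable[str]) -> Counter:
    """Count how often each op_type appears across the recorded FX graph nodes."""
    return Counter(op_types)


def write_node_count_csv(counts: Counter, out: IO[str]) -> None:
    # Hypothetical layout: a header row, then one row per op_type, most frequent first
    writer = csv.writer(out)
    writer.writerow(["op_type", "count"])
    for op_type, n in counts.most_common():
        writer.writerow([op_type, n])


def top_op_types(csv_file: IO[str], k: int = 3) -> List[Tuple[str, int]]:
    """Return the k most frequent op types from a CSV in the layout above."""
    rows = [(r["op_type"], int(r["count"])) for r in csv.DictReader(csv_file)]
    rows.sort(key=lambda row: row[1], reverse=True)
    return rows[:k]


if __name__ == "__main__":
    ops = ["aten.convolution.default", "aten.relu.default",
           "aten.convolution.default", "aten.add.Tensor"]
    buf = io.StringIO()
    write_node_count_csv(count_op_types(ops), buf)
    buf.seek(0)
    print(top_op_types(buf, k=1))  # [('aten.convolution.default', 2)]
```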
The `*_total_*_size_dist/` folders contain statistics on the input/output size distribution of each `op_type` across all FX graphs recorded in `stat/raw`. These reports help analyze the memory footprint during the computation of each `op_type`.
Note: the default `input_shapes` in `tools/stat_torchvision.py` is `[1,3,224,224]`, and the `*_total_*_size_dist/` reports depend on it.
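To see why the size reports depend on the input shape, the sketch below computes a dense tensor's size in bytes and bins per-op input sizes into power-of-two buckets. The dtype table, the bucketing scheme, and the `(op_type, input_shapes)` record layout are assumptions, not the format used by `tools/generate_report.py`. With the default `[1,3,224,224]` input in float32, a single input tensor is 1 × 3 × 224 × 224 × 4 = 602,112 bytes.

```python
import math
from collections import Counter, defaultdict
from typing import Dict, Iterable, Sequence, Tuple

# Assumed bytes per element for a few common dtypes
DTYPE_BYTES = {"float32": 4, "bfloat16": 2, "float16": 2, "int64": 8}


def tensor_bytes(shape: Sequence[int], dtype: str = "float32") -> int:
    """Size of a dense tensor in bytes: number of elements times element size."""
    return math.prod(shape) * DTYPE_BYTES[dtype]


def size_bucket(nbytes: int) -> int:
    """Round up to the next power of two (sizes of 0 or 1 byte land in bucket 1)."""
    return 1 << max(nbytes - 1, 0).bit_length()


def input_size_distribution(
    records: Iterable[Tuple[str, Sequence[Sequence[int]]]], dtype: str = "float32"
) -> Dict[str, Counter]:
    """Histogram of total input bytes per op_type, keyed by power-of-two bucket."""
    dist: Dict[str, Counter] = defaultdict(Counter)
    for op_type, input_shapes in records:
        total = sum(tensor_bytes(shape, dtype) for shape in input_shapes)
        dist[op_type][size_bucket(total)] += 1
    return dict(dist)
```

Changing the input shape (for example, a larger batch) shifts every op's sizes into different buckets, which is why the reports should be read together with the shape used to generate them.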
Note: the ATen IR interface is listed there.
The `profile/` folder contains traces produced by the PyTorch profiler; you can open them at `chrome://tracing`.
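Besides viewing a trace in `chrome://tracing`, the file is plain JSON in the Chrome Trace Event Format, so it can also be summarized with a short script. The sketch below sums the durations of complete (`"ph": "X"`) events per name; durations in this format are in microseconds. Which file under `profile/` you load is up to you; reading it with `Path(...).read_text()` and passing the string in is an assumed usage.

```python
import json
from collections import defaultdict
from typing import Dict


def total_duration_by_name(trace_json: str) -> Dict[str, float]:
    """Sum the durations (in microseconds) of complete ("X") events per event name."""
    data = json.loads(trace_json)
    # The format allows either {"traceEvents": [...]} or a bare JSON array of events
    events = data["traceEvents"] if isinstance(data, dict) else data
    totals: Dict[str, float] = defaultdict(float)
    for event in events:
        if event.get("ph") == "X":
            totals[event["name"]] += float(event.get("dur", 0))
    return dict(totals)
```

Sorting the result by value gives a quick list of the most expensive ops without opening the trace viewer.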
During development, you may want to use the torch-ttnn package for testing. To do so, install the torch-ttnn package in "editable" mode:
```shell
pip install -e .
```
Now you can use `torch_ttnn` in your Python code. Any modifications you make to the `torch_ttnn` package take effect immediately, so you don't need to reinstall it with pip after each change.
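To double-check that Python picks up your working copy rather than some other installed copy, you can ask the import system where a module would be loaded from. This is a generic sketch built on `importlib.util.find_spec`; the clone path in the usage note is a placeholder.

```python
import importlib.util
from pathlib import Path
from typing import Optional


def import_location(module_name: str) -> Optional[Path]:
    """Return the file a top-level module would be imported from, or None if not found."""
    spec = importlib.util.find_spec(module_name)
    if spec is None or spec.origin is None:
        return None
    return Path(spec.origin).resolve()


def is_inside(path: Path, root: Path) -> bool:
    """True if path is root itself or somewhere below it."""
    try:
        path.resolve().relative_to(root.resolve())
    except ValueError:
        return False
    return True
```

For example, `is_inside(import_location("torch_ttnn"), Path("/path/to/your/clone"))` should be `True` after an editable install.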
Developers who want to deploy the wheel can build the wheel file with:
```shell
python -m build
```
Then you can upload the `.whl` file to PyPI (the Python Package Index).
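Before uploading, it can help to check the tags encoded in the built wheel's filename, for example whether the build produced a pure-Python `py3-none-any` wheel or a platform-specific one. The sketch below splits a filename according to the wheel naming convention from PEP 427, `{distribution}-{version}(-{build tag})?-{python tag}-{abi tag}-{platform tag}.whl`; the filename in the example is hypothetical.

```python
from typing import NamedTuple, Optional


class WheelName(NamedTuple):
    distribution: str
    version: str
    build: Optional[str]
    python_tag: str
    abi_tag: str
    platform_tag: str


def parse_wheel_filename(filename: str) -> WheelName:
    """Split a wheel filename into the fields defined by PEP 427."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) == 5:
        dist, version, py, abi, plat = parts
        build = None
    elif len(parts) == 6:  # optional build tag present
        dist, version, build, py, abi, plat = parts
    else:
        raise ValueError(f"unexpected number of fields in {filename}")
    return WheelName(dist, version, build, py, abi, plat)


if __name__ == "__main__":
    # Hypothetical filename, e.g. something found under dist/ after the build
    print(parse_wheel_filename("example_pkg-1.0.0-py3-none-any.whl").platform_tag)  # any
```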
To run a transformer model with the ttnn backend, run:
```shell
PYTHONPATH="$TT_METAL_HOME:$(pwd)" python3 tools/run_transformers.py --model "phiyodr/bert-large-finetuned-squad2" --backend torch_ttnn
```
You can also substitute the backend with `torch_stat` to run a reference comparison.
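One common way to summarize how closely two backends' outputs agree is the Pearson correlation coefficient, which ranges from -1 (perfectly anti-correlated) to 1 (perfectly correlated). This section does not say which metric the reference comparison uses, so treat the sketch below as an illustration on flattened output values, not as the project's comparison code.

```python
import math
from typing import Sequence


def pearson_correlation(expected: Sequence[float], actual: Sequence[float]) -> float:
    """Pearson correlation between two equal-length, flattened outputs."""
    if len(expected) != len(actual) or len(expected) < 2:
        raise ValueError("need two equal-length sequences with at least 2 values")
    mean_e = sum(expected) / len(expected)
    mean_a = sum(actual) / len(actual)
    cov = sum((e - mean_e) * (a - mean_a) for e, a in zip(expected, actual))
    var_e = sum((e - mean_e) ** 2 for e in expected)
    var_a = sum((a - mean_a) ** 2 for a in actual)
    if var_e == 0 or var_a == 0:
        # Undefined for constant outputs; this sketch treats identical outputs as a match
        return 1.0 if list(expected) == list(actual) else 0.0
    return cov / math.sqrt(var_e * var_a)
```

A correlation close to 1 only says the outputs move together; it does not detect a constant scale factor, so it is often paired with an absolute-difference check.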
If you want to record run time metrics for a model or test, include the built-in Pytest fixture `record_property` as a parameter and set the `"model_name"` key.
If you also want to compile the model with the torch_ttnn backend, set the `"torch_ttnn"` key to a tuple in the order `(model, test_inputs, outputs)`. `"model_name"` still needs to be set. See the example code snippet below. `torch.nn.Module` models with a `generate` method are supported.
```python
import pytest
import torch


class Model(torch.nn.Module):
    def forward(self, x):
        # ... (placeholder computation)
        return x * 2


# Add the compilation_xfail marker if the model runs on torch/CPU,
# but the compiled version is expected to fail
@pytest.mark.compilation_xfail
# Add "record_property" parameter
def test_model_name(record_property):
    # Should be set as early as possible
    record_property("model_name", "Model Name")

    model = Model()
    test_input = torch.rand(1, 8)  # placeholder input
    # ...
    outputs = model(test_input)
    # outputs = model(**test_inputs)  # dictionary inputs are also supported
    # ...

    # Can be set once all three objects for the tuple are defined
    record_property("torch_ttnn", (model, test_input, outputs))
```
If `model.generate(inputs)` is used, pass in `model.generate` instead of `model` to `record_property`.
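Passing `model.generate` works because whatever consumes the recorded tuple only needs something it can call with the recorded inputs. The sketch below is a hypothetical consumer, not the project's actual test harness: it unpacks `(model, test_inputs, outputs)`, expands dictionary inputs as keyword arguments, and re-runs the callable.

```python
from collections.abc import Mapping
from typing import Any, Callable, Tuple


def rerun_recorded(entry: Tuple[Callable[..., Any], Any, Any]) -> Tuple[Any, Any]:
    """Re-run a recorded (model, test_inputs, outputs) tuple; return (new, recorded) outputs."""
    model, test_inputs, outputs = entry
    if isinstance(test_inputs, Mapping):
        new_outputs = model(**test_inputs)  # dictionary inputs become keyword arguments
    else:
        new_outputs = model(test_inputs)
    return new_outputs, outputs
```

A bound method such as `model.generate` is just another callable here, so no special casing is needed in this sketch.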