nod-ai / iree-amd-aie

IREE plugin repository for the AMD AIE accelerator
Apache License 2.0

matmul-elementwise bf16 model failed compilation #372

Status: Open. yzhang93 opened this issue 1 month ago.

yzhang93 commented 1 month ago

Input IR

!lhs = tensor<1024x512xbf16>
!rhs = tensor<512x1024xbf16>
!ele = tensor<1024x1024xf32>
!res = tensor<1024x1024xbf16>

func.func @matmul_elementwise_bf16(%lhs : !lhs, %rhs : !rhs, %ele : !ele) -> !res {
  %cst = arith.constant 0.0 : f32
  %0 = tensor.empty() : !ele
  %1 = tensor.empty() : !res
  %fill = linalg.fill ins(%cst : f32) outs(%0 : !ele) -> !ele
  %2 = linalg.matmul ins(%lhs, %rhs : !lhs, !rhs) outs(%fill : !ele) -> !ele
  %res = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%2, %ele : !ele, !ele) outs(%1 : !res) {
  ^bb0(%in: f32, %in_0: f32, %out: bf16):
    %11 = arith.addf %in, %in_0 : f32
    %12 = arith.truncf %11 : f32 to bf16
    linalg.yield %12 : bf16
  } -> !res
  return %res : !res
}

Error:

LLVM ERROR: unable to legalize instruction: %1730:_(<1024 x s16>) = G_SHUFFLE_VECTOR %1729:_(<1024 x s16>), %1475:_, shufflemask(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 
396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 
796, 797, 798, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1024, 1025, 1026, 1027) (in function: core_0_2)
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.  Program arguments: /proj/xsjhdstaff4/vivizhan/llvm-aie/install/bin/llc /proj/xsjhdstaff4/vivizhan/iree-amd-aie/build_tools/ci/cpu_comparison/test_result_bf16/module_matmul_elementwise_bf16_dispatch_0_amdaie_xclbin_fb/input.opt.ll -O2 --march=aie2 --function-sections --filetype=obj -o /proj/xsjhdstaff4/vivizhan/iree-amd-aie/build_tools/ci/cpu_comparison/test_result_bf16/module_matmul_elementwise_bf16_dispatch_0_amdaie_xclbin_fb/input.o
1.  Running pass 'Function Pass Manager' on module '/proj/xsjhdstaff4/vivizhan/iree-amd-aie/build_tools/ci/cpu_comparison/test_result_bf16/module_matmul_elementwise_bf16_dispatch_0_amdaie_xclbin_fb/input.opt.ll'.
2.  Running pass 'Legalizer' on function '@core_0_2'
 #0 0x000055ae9b6ceebf llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/Support/Unix/Signals.inc:567:22
 #1 0x000055ae9b6ccfc4 llvm::sys::RunSignalHandlers() /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/Support/Signals.cpp:104:20
 #2 0x000055ae9b6cd146 SignalHandler(int) /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/Support/Unix/Signals.inc:412:1
 #3 0x00007fa6da842520 (/lib/x86_64-linux-gnu/libc.so.6+0x42520)
 #4 0x00007fa6da8969fc __pthread_kill_implementation ./nptl/pthread_kill.c:44:76
 #5 0x00007fa6da8969fc __pthread_kill_internal ./nptl/pthread_kill.c:78:10
 #6 0x00007fa6da8969fc pthread_kill ./nptl/pthread_kill.c:89:10
 #7 0x00007fa6da842476 gsignal ./signal/../sysdeps/posix/raise.c:27:6
 #8 0x00007fa6da8287f3 abort ./stdlib/abort.c:81:7
 #9 0x000055ae9b6438d3 (/proj/xsjhdstaff4/vivizhan/llvm-aie/install/bin/llc+0x2cd98d3)
#10 0x000055ae9bb25532 reportGISelDiagnostic(llvm::DiagnosticSeverity, llvm::MachineFunction&, llvm::TargetPassConfig const&, llvm::MachineOptimizationRemarkEmitter&, llvm::MachineOptimizationRemarkMissed&) /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/CodeGen/GlobalISel/Utils.cpp:257:23
#11 0x000055ae9bb26f5b llvm::DiagnosticInfoOptimizationBase::~DiagnosticInfoOptimizationBase() /proj/rdi/staff/vivizhan/llvm-aie/llvm/include/llvm/IR/DiagnosticInfo.h:413:7
#12 0x000055ae9bb26f5b llvm::DiagnosticInfoMIROptimization::~DiagnosticInfoMIROptimization() /proj/rdi/staff/vivizhan/llvm-aie/llvm/include/llvm/CodeGen/MachineOptimizationRemarkEmitter.h:30:7
#13 0x000055ae9bb26f5b llvm::MachineOptimizationRemarkMissed::~MachineOptimizationRemarkMissed() /proj/rdi/staff/vivizhan/llvm-aie/llvm/include/llvm/CodeGen/MachineOptimizationRemarkEmitter.h:84:7
#14 0x000055ae9bb26f5b llvm::reportGISelFailure(llvm::MachineFunction&, llvm::TargetPassConfig const&, llvm::MachineOptimizationRemarkEmitter&, char const*, llvm::StringRef, llvm::MachineInstr const&) /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/CodeGen/GlobalISel/Utils.cpp:286:1
#15 0x000055ae9babdb82 llvm::Legalizer::runOnMachineFunction(llvm::MachineFunction&) (.part.0) /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/CodeGen/GlobalISel/Legalizer.cpp:348:12
#16 0x000055ae9a7f9b3b llvm::MachineFunctionPass::runOnFunction(llvm::Function&) (.part.0) /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/CodeGen/MachineFunctionPass.cpp:91:33
#17 0x000055ae9ad2eaec llvm::FPPassManager::runOnFunction(llvm::Function&) /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/IR/LegacyPassManager.cpp:1440:7
#18 0x000055ae9ad2ed19 llvm::ilist_node_base<true>::getNext() const /proj/rdi/staff/vivizhan/llvm-aie/llvm/include/llvm/ADT/ilist_node_base.h:43:45
#19 0x000055ae9ad2ed19 llvm::ilist_node_impl<llvm::ilist_detail::node_options<llvm::Function, true, false, void>>::getNext() /proj/rdi/staff/vivizhan/llvm-aie/llvm/include/llvm/ADT/ilist_node.h:67:66
#20 0x000055ae9ad2ed19 llvm::ilist_iterator<llvm::ilist_detail::node_options<llvm::Function, true, false, void>, false, false>::operator++() /proj/rdi/staff/vivizhan/llvm-aie/llvm/include/llvm/ADT/ilist_iterator.h:157:25
#21 0x000055ae9ad2ed19 llvm::FPPassManager::runOnModule(llvm::Module&) /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/IR/LegacyPassManager.cpp:1475:22
#22 0x000055ae9ad2f59e runOnModule /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/IR/LegacyPassManager.cpp:1552:7
#23 0x000055ae9ad2f59e llvm::legacy::PassManagerImpl::run(llvm::Module&) /proj/rdi/staff/vivizhan/llvm-aie/llvm/lib/IR/LegacyPassManager.cpp:535:55
#24 0x000055ae99e4601e compileModule(char**, llvm::LLVMContext&) /proj/rdi/staff/vivizhan/llvm-aie/llvm/tools/llc/llc.cpp:736:66
#25 0x000055ae99e46f86 main /proj/rdi/staff/vivizhan/llvm-aie/llvm/tools/llc/llc.cpp:420:35
#26 0x00007fa6da829d90 __libc_start_call_main ./csu/../sysdeps/nptl/libc_start_call_main.h:58:16
#27 0x00007fa6da829e40 call_init ./csu/../csu/libc-start.c:128:20
#28 0x00007fa6da829e40 __libc_start_main ./csu/../csu/libc-start.c:379:5
#29 0x000055ae99e3a2e5 _start (/proj/xsjhdstaff4/vivizhan/llvm-aie/install/bin/llc+0x14d02e5)
yzhang93 commented 1 month ago

In contrast, the bf16-f32 model below, which omits arith.truncf %11 : f32 to bf16 and returns the f32 result directly, does not hit this error.

!lhs = tensor<1024x512xbf16>
!rhs = tensor<512x1024xbf16>
!ele = tensor<1024x1024xf32>
!res = tensor<1024x1024xf32>

func.func @matmul_elementwise_bf16(%lhs : !lhs, %rhs : !rhs, %ele : !ele) -> !res {
  %cst = arith.constant 0.0 : f32
  %0 = tensor.empty() : !ele
  %fill = linalg.fill ins(%cst : f32) outs(%0 : !ele) -> !ele
  %2 = linalg.matmul ins(%lhs, %rhs : !lhs, !rhs) outs(%fill : !ele) -> !ele
  %res = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%2, %ele : !ele, !ele) outs(%0 : !ele) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %11 = arith.addf %in, %in_0 : f32
    linalg.yield %11 : f32
  } -> !res
  return %res : !res
}
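The only op the failing model adds is arith.truncf from f32 to bf16. The Python sketch below models what that op computes per element, assuming round-to-nearest-ties-to-even (LLVM's default rounding for fptrunc); the helper names are hypothetical. Its point is that the result is a plain 16-bit pattern, which is consistent with (but does not prove the origin of) the 16-bit lanes in the failing <1024 x s16> G_SHUFFLE_VECTOR.

```python
import struct

def f32_bits(x: float) -> int:
    """IEEE-754 binary32 bit pattern of x (x is first rounded to f32).

    Raises OverflowError if x is outside the finite f32 range.
    """
    return struct.unpack("<I", struct.pack("<f", x))[0]

def truncf_f32_to_bf16(x: float) -> int:
    """Sketch of `arith.truncf %x : f32 to bf16`, returning the 16-bit result.

    Assumes round-to-nearest-ties-to-even; NaN payloads are not preserved.
    """
    bits = f32_bits(x)
    if (bits & 0x7F800000) == 0x7F800000 and (bits & 0x007FFFFF):
        return ((bits >> 16) & 0x8000) | 0x7FC0  # quiet NaN, sign kept
    # Add 0x7FFF, plus 1 when the kept LSB is odd, so exact ties round to even.
    bits += 0x7FFF + ((bits >> 16) & 1)
    return (bits >> 16) & 0xFFFF

def bf16_to_f32(h: int) -> float:
    """Widen a bf16 bit pattern back to f32 (exact: append 16 zero bits)."""
    return struct.unpack("<f", struct.pack("<I", (h & 0xFFFF) << 16))[0]

if __name__ == "__main__":
    for v in (1.0, 1.0 + 2**-8, 1.0 + 2**-7 + 2**-8, 3.14159):
        h = truncf_f32_to_bf16(v)
        print(f"{v!r:>22} -> 0x{h:04X} ({bf16_to_f32(h)})")
```

Dropping the low 16 bits after biasing is a common way to implement f32-to-bf16 rounding; whether the AIE2 lowering uses this exact sequence or a dedicated intrinsic is not stated in this thread.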

@MaheshRavishankar @stephenneuendorffer @newling @erwei-xilinx Any insight into this issue?

MaheshRavishankar commented 1 month ago

I don't know whether Peano handles bf16 natively.

stephenneuendorffer commented 1 month ago

I believe there's work going on to implement shuffle_vector. Currently, the assumption is that vector ops always go through intrinsics. FYI, for Peano issues you're better off capturing the .ll code and opening an issue in the Peano repo.
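Once the .ll file is captured, one way to shrink it into a small reproducer is llvm-reduce with an interestingness test. The script below is a hypothetical sketch: it reruns llc with the flags from the failing invocation in the stack dump and reports "interesting" (exit code 0) only while the same G_SHUFFLE_VECTOR legalization failure still appears. The LLC environment variable is an assumption; point it at the llvm-aie (Peano) llc, such as the install/bin/llc shown in the stack dump.

```python
#!/usr/bin/env python3
"""Hypothetical interestingness test for llvm-reduce.

Exits 0 when llc still fails with the same GlobalISel legalization error
on G_SHUFFLE_VECTOR, and 1 otherwise.
"""
import os
import subprocess
import sys

# Assumption: LLC names the llvm-aie (Peano) llc; falls back to "llc" on PATH.
LLC = os.environ.get("LLC", "llc")
# Flags copied from the failing llc invocation in the stack dump.
LLC_FLAGS = ["-O2", "--march=aie2", "--function-sections", "--filetype=obj"]

def is_interesting(stderr: str) -> bool:
    """True if llc's stderr shows the same kind of legalizer failure."""
    return ("unable to legalize instruction" in stderr
            and "G_SHUFFLE_VECTOR" in stderr)

def main(ll_path: str) -> int:
    # Compile the candidate module; the object file itself is discarded.
    proc = subprocess.run(
        [LLC, ll_path, *LLC_FLAGS, "-o", os.devnull],
        capture_output=True,
        text=True,
    )
    return 0 if is_interesting(proc.stderr) else 1

if __name__ == "__main__" and len(sys.argv) >= 2:
    # llvm-reduce passes the candidate module as the last argument.
    sys.exit(main(sys.argv[-1]))
```

A possible invocation, after making the script executable, is `llvm-reduce --test=./interesting.py input.opt.ll`, ideally with the llvm-reduce from the same llvm-aie build (see `llvm-reduce --help` for where the reduced module is written). Matching on the error text rather than on the exit code keeps the reduction from drifting to an unrelated crash.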

gbossu commented 1 month ago

Peano does support bf16 types, and there is indeed work to support more and more cases of generic shuffle_vector. However, I think the problem here is rather that %1730:_(<1024 x s16>) is a huge vector, and we do not have the capability yet to properly legalize those. As Stephen said, it would be very useful if you could get us a small .ll reproducer, then we can investigate what's really happening here :)
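To put "huge vector" in numbers, the sketch below decodes the shuffle mask from the LLVM ERROR line using the generic shufflevector rule: an index i below the lane count selects lane i of the first operand, and a larger index selects lane i - 1024 of the second operand (assuming %1475 has the same <1024 x s16> type, as G_SHUFFLE_VECTOR requires of its two sources). Read this way, the instruction keeps lanes 0 to 1019 of %1729 and fills the last four lanes with lanes 0 to 3 of %1475, on a single 16384-bit value. The helper names are hypothetical.

```python
from itertools import groupby

LANES = 1024     # <1024 x s16>: elements per operand
LANE_BITS = 16   # s16: a generic 16-bit scalar (bf16 values are 16 bits wide)
VECTOR_BITS = LANES * LANE_BITS

# Mask from the LLVM ERROR line: 0, 1, ..., 1019, then 1024, 1025, 1026, 1027.
MASK = list(range(1020)) + [1024, 1025, 1026, 1027]

def decode(mask, lanes):
    """Map each result lane to (operand, source_lane).

    An index i < lanes selects lane i of the first operand; i >= lanes selects
    lane i - lanes of the second. A negative (undefined) index maps to None.
    """
    out = []
    for i in mask:
        if i < 0:
            out.append(None)
        elif i < lanes:
            out.append((0, i))
        else:
            out.append((1, i - lanes))
    return out

def summarize(decoded):
    """Group consecutive result lanes that read from the same operand.

    Each run is (operand, first_source_lane, last_source_lane); runs of
    undefined lanes are ("undef", count).
    """
    runs = []
    for operand, group in groupby(decoded, key=lambda e: None if e is None else e[0]):
        group = list(group)
        if operand is None:
            runs.append(("undef", len(group)))
        else:
            runs.append((operand, group[0][1], group[-1][1]))
    return runs

if __name__ == "__main__":
    print(f"vector size: {VECTOR_BITS} bits ({VECTOR_BITS // 8} bytes)")
    print(summarize(decode(MASK, LANES)))
```

This prints a size of 16384 bits (2048 bytes) and the runs [(0, 0, 1019), (1, 0, 3)]: the shuffle itself is simple, so the difficulty gbossu describes lies in the width of the vector rather than in the lane pattern.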

ValentijnvdBeek commented 1 month ago

Support for G_SHUFFLE_VECTOR in Peano will be up for review soon, so it should land shortly. The failing instruction operates on generic 16-bit scalars (s16), so the problem is not bf16 support in any case. There are two problems with the code as is:

yzhang93 commented 1 month ago

Thanks @stephenneuendorffer @gbossu @ValentijnvdBeek for looking into the issue! Here are the .ll files generated from the above example. Please let me know if you need me to provide other sources. input_ll.zip