Closed bbradelTT closed 4 months ago
Ran a grep command in the bmm directory. The result is a good starting point for what needs to change:
/proj_sw/user_dev/bbradel/tt-metal/tt_eager/tt_dnn/op_library/bmm > grep 'shape\[[0-9]' -r * -n
bmm_op.cpp:102: int64_t num_mul_adds_per_elem = in_a_shape[3] * 2; // 1 multiply and 1 add per element
bmm_op.cpp:103: int64_t num_mul_adds = num_mul_adds_per_elem * out_shape[2] * out_shape[3] * out_shape[1] * out_shape[0];
bmm_op.cpp:127: tt::log_info(tt::LogOp, "\t Batch: ({}, {})", out_shape[0], out_shape[1]);
bmm_op.cpp:128: tt::log_info(tt::LogOp, "\t In A (H, W): ({}, {})", in_a_shape[2], in_a_shape[3]);
bmm_op.cpp:129: tt::log_info(tt::LogOp, "\t In B (H, W): ({}, {})", in_b_shape[2], in_b_shape[3]);
bmm_op.cpp:130: tt::log_info(tt::LogOp, "\t Out (H, W): ({}, {})", out_shape[2], out_shape[3]);
bmm_op.cpp:376: in0_block_w = shard_shape[1] / TILE_WIDTH;
bmm_op.cpp:379: per_core_M = shard_shape[0] / TILE_HEIGHT;
bmm_op.cpp:422: virtual_y == (M / (shard_shape[0] / TILE_HEIGHT)), "Num cores along y must match provided grid size!");
bmm_op.cpp:424: virtual_x == (K / (shard_shape[1] / TILE_WIDTH)), "Num cores along x must match provided grid size!");
bmm_op.cpp:428: uint32_t in0_block_w = shard_shape[1] / TILE_WIDTH;
bmm_op.cpp:462: uint32_t per_core_M = in0_shard_shape[0] / TILE_HEIGHT;
bmm_op.cpp:464: uint32_t in0_block_w = in0_shard_shape[1] / TILE_WIDTH;
bmm_op.cpp:743: uint32_t num_output_tiles = ashape[0] * ashape[1] * ashape[2] * bshape[3] / TILE_HW; // Output M x N
bmm_op.cpp:746: uint32_t B = ashape[0] * ashape[1];
bmm_op.cpp:747: uint32_t Mt = ashape[2] / TILE_HEIGHT;
bmm_op.cpp:748: uint32_t Kt = ashape[3] / TILE_WIDTH;
bmm_op.cpp:749: uint32_t Nt = bshape[3] / TILE_WIDTH;
bmm_op.cpp:904: "Dimension K (A.shape[3] and B.shape[2]) must match for A and B in bmm_op"); // A.K == B.K
bmm_op.cpp:959: TT_FATAL(per_core_M == (shard_shape[0] / TILE_HEIGHT));
bmm_op.cpp:961: TT_FATAL((shard_shape[1] / TILE_WIDTH) % program_config.in0_block_w == 0);
bmm_op.cpp:997: TT_FATAL(per_core_M == (shard_shape[0] / TILE_HEIGHT));
bmm_op.cpp:999: TT_FATAL(program_config.in0_block_w == (shard_shape[1] / TILE_WIDTH));
bmm_op.cpp:1041: TT_FATAL(K / (shard_shape[1] / TILE_WIDTH) == div_up(N, program_config.per_core_N));
bmm_op.cpp:1044: TT_FATAL(program_config.in0_block_w == (shard_shape[1] / TILE_WIDTH));
bmm_op.cpp:1050: TT_FATAL(per_core_M == (shard_shape[0] / TILE_HEIGHT));
bmm_op.cpp:1051: TT_FATAL((shard_shape[1] / TILE_WIDTH) % program_config.in0_block_w == 0);
bmm_op.cpp:1057: TT_FATAL(program_config.per_core_N == (input_tensor_b.shard_spec().value().shape[0] / TILE_WIDTH));
bmm_op.cpp:1090: TT_FATAL(K == in0_shard_shape[1]);
bmm_op.cpp:1091: TT_FATAL(in0_shard_shape[1] == program_config.in0_block_w * TILE_WIDTH);
bmm_op.cpp:1092: TT_FATAL(per_core_M * TILE_HEIGHT == in0_shard_shape[0]);
bmm_op.cpp:1120: TT_FATAL(in1_shard_shape[1] == input_tensor_b.get_legacy_shape()[-1]);
bmm_op.cpp:1121: TT_FATAL(per_core_N * TILE_HEIGHT == in1_shard_shape[1]);
bmm_op.cpp:1122: TT_FATAL(in1_shard_shape[0] % K == 0);
bmm_op.cpp:1547: m_tiles_per_core = shard_shape[0] / ttnn::TILE_SIZE;
bmm_op.cpp:1549: k_tiles_per_core = shard_shape[1] / ttnn::TILE_SIZE;
bmm_op.cpp:1556: n_tiles_per_core = shard_shape[1] / ttnn::TILE_SIZE;
bmm_op.cpp:1595: m_tiles_per_core = shard_shape[0] / ttnn::TILE_SIZE;
bmm_op.cpp:1596: n_tiles_per_core = (n * shard_shape[1]) / (k * ttnn::TILE_SIZE);
bmm_op.cpp:1597: k_tiles_per_core = shard_shape[1] / ttnn::TILE_SIZE;
multi_core/bmm_op_multi_core.cpp:42: auto num_output_tiles_total = cshape[0] * cshape[1] * cshape[2] * cshape[3] / TILE_HW;
multi_core/bmm_op_multi_core.cpp:50: uint32_t B = ashape[0]*ashape[1];
multi_core/bmm_op_multi_core.cpp:51: uint32_t Mt = ashape[2]/TILE_HEIGHT;
multi_core/bmm_op_multi_core.cpp:52: uint32_t Kt = ashape[3]/TILE_WIDTH;
multi_core/bmm_op_multi_core.cpp:53: uint32_t Nt = bshape[3]/TILE_WIDTH;
multi_core_reuse/bmm_op_multi_core_reuse.cpp:255: uint32_t B = ashape[0]*ashape[1];
multi_core_reuse/bmm_op_multi_core_reuse.cpp:256: uint32_t Mt = ashape[2]/TILE_HEIGHT;
multi_core_reuse/bmm_op_multi_core_reuse.cpp:257: uint32_t Kt = ashape[3]/TILE_WIDTH;
multi_core_reuse/bmm_op_multi_core_reuse.cpp:258: uint32_t Nt = bshape[3]/TILE_WIDTH;
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1454: bshape[0] * bshape[1] == 1 &&
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1459: ashape[1] == bshape[1] && ashape[0] == bshape[0] &&
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1466: ashape[3] == bshape[2] &&
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1467: "Dimension K (A.shape[3] and B.shape[2]) must match for A and B in bmm_op"); // A.K == B.K
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1468: TT_FATAL(ashape[2] % TILE_HEIGHT == 0);
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1469: TT_FATAL(ashape[3] % TILE_WIDTH == 0);
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1470: TT_FATAL(bshape[2] % TILE_HEIGHT == 0);
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1471: TT_FATAL(bshape[3] % TILE_WIDTH == 0);
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1510: uint32_t B = ashape[0] * ashape[1];
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1511: uint32_t Mt = ashape[2] / TILE_HEIGHT;
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1512: uint32_t Kt = ashape[3] / TILE_WIDTH;
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1513: uint32_t Nt = bshape[3] / TILE_WIDTH;
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1029: bshape[0] * bshape[1] == 1 &&
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1034: ashape[1] == bshape[1] && ashape[0] == bshape[0] &&
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1041: ashape[3] == bshape[2] &&
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1042: "Dimension K (A.shape[3] and B.shape[2]) must match for A and B in bmm_op"); // A.K == B.K
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1043: TT_FATAL(ashape[2] % TILE_HEIGHT == 0);
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1044: TT_FATAL(ashape[3] % TILE_WIDTH == 0);
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1045: TT_FATAL(bshape[2] % TILE_HEIGHT == 0);
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1046: TT_FATAL(bshape[3] % TILE_WIDTH == 0);
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1085: uint32_t B = ashape[0] * ashape[1];
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1086: uint32_t Mt = ashape[2] / TILE_HEIGHT;
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1087: uint32_t Kt = ashape[3] / TILE_WIDTH;
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1088: uint32_t Nt = bshape[3] / TILE_WIDTH;
multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp:453: TT_FATAL((bcast_batch == false) or (ashape[0] == 1), "Bcast batch not supported for this parallelization");
multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp:467: TT_FATAL(bshape[0]*bshape[1] == 1 && "matmul (batch bcast variant) expects input tensors of shapes BCMK*11KN=BCMN");
multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp:470: TT_FATAL(ashape[1] == bshape[1] && ashape[0] == bshape[0]
multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp:508: uint32_t B = ashape[0]*ashape[1];
multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp:509: uint32_t Mt = ashape[2]/TILE_HEIGHT;
multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp:510: uint32_t Kt = ashape[3]/TILE_WIDTH;
multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp:511: uint32_t Nt = bshape[3]/TILE_WIDTH;
multi_core_reuse_padding/bmm_op_multi_core_reuse_padding.cpp:281: uint32_t B = ashape[0]*ashape[1];
multi_core_reuse_padding/bmm_op_multi_core_reuse_padding.cpp:282: uint32_t Mt = ashape[2]/TILE_HEIGHT;
multi_core_reuse_padding/bmm_op_multi_core_reuse_padding.cpp:283: uint32_t Kt = ashape[3]/TILE_WIDTH;
multi_core_reuse_padding/bmm_op_multi_core_reuse_padding.cpp:284: uint32_t Nt = bshape[3]/TILE_WIDTH;
The shard_shape is already 2d and should be ok.
The grep did not capture all potential locations.
The following grep/awk command, done after editing the previous set of locations, identified more locations:
grep 'shape.*\[[0-9]\]' -r * -n | awk '{if ($0 !~ /shard/) { print $0 }}' -
bmm_op.cpp:581: auto seq_len = input_tensor_a.get_legacy_shape()[2];
bmm_op.cpp:628: auto seq_len = input_tensor_a.get_legacy_shape()[2];
bmm_op.cpp:1110: (input_tensor_a.get_legacy_shape()[0] * input_tensor_a.get_legacy_shape()[1] > 1 and
bmm_op.cpp:1111: input_tensor_b.get_legacy_shape()[0] * input_tensor_b.get_legacy_shape()[1] == 1);
bmm_op.cpp:1301: // TODO: If input_tensor_a.get_legacy_shape()[0] * input_tensor_a.get_legacy_shape()[1] * ... except last two dimensions == 1, does matmuls work if
multi_core/bmm_op_multi_core.cpp:42: auto num_output_tiles_total = cshape[0] * cshape[1] * cshape[2] * cshape[3] / TILE_HW;
multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp:453: TT_FATAL((bcast_batch == false) or (ashape[0] == 1), "Bcast batch not supported for this parallelization");
single_core/bmm_op_single_core_tilize_untilize.cpp:172: uint32_t in0_batch = in0.get_legacy_shape()[0];
single_core/bmm_op_single_core_tilize_untilize.cpp:173: uint32_t in0_channel = in0.get_legacy_shape()[1];
single_core/bmm_op_single_core_tilize_untilize.cpp:174: uint32_t in0_height = in0.get_legacy_shape()[2];
single_core/bmm_op_single_core_tilize_untilize.cpp:175: uint32_t in0_width = in0.get_legacy_shape()[3];
single_core/bmm_op_single_core_tilize_untilize.cpp:176: uint32_t in1_batch = in1.get_legacy_shape()[0];
single_core/bmm_op_single_core_tilize_untilize.cpp:177: uint32_t in1_channel = in1.get_legacy_shape()[1];
single_core/bmm_op_single_core_tilize_untilize.cpp:178: uint32_t in1_height = in1.get_legacy_shape()[2];
single_core/bmm_op_single_core_tilize_untilize.cpp:179: uint32_t in1_width = in1.get_legacy_shape()[3];
single_core/bmm_op_single_core_tilize_untilize.cpp:187: TT_FATAL(bias.get_legacy_shape()[3] == in1.get_legacy_shape()[3], "Bias shape mismatch");
single_core/bmm_op_single_core_tilize_untilize.cpp:196: TT_FATAL(bias.get_legacy_shape()[2] % constants::TILE_HEIGHT == 0);
single_core/bmm_op_single_core_tilize_untilize.cpp:197: TT_FATAL(bias.get_legacy_shape()[3] % constants::TILE_WIDTH == 0);
single_core/bmm_op_single_core_tilize_untilize.cpp:329: bias_ntiles_w = bias.get_legacy_shape()[3] / constants::TILE_WIDTH;
single_core/bmm_op_single_core_tilize_untilize.cpp:604: auto in0_batch = in0.get_legacy_shape()[0];
single_core/bmm_op_single_core_tilize_untilize.cpp:605: auto in0_channel = in0.get_legacy_shape()[1];
single_core/bmm_op_single_core_tilize_untilize.cpp:606: auto in0_height = in0.get_legacy_shape()[2];
single_core/bmm_op_single_core_tilize_untilize.cpp:607: auto in0_width = in0.get_legacy_shape()[3];
single_core/bmm_op_single_core_tilize_untilize.cpp:608: auto in1_batch = in1.get_legacy_shape()[0];
single_core/bmm_op_single_core_tilize_untilize.cpp:609: auto in1_channel = in1.get_legacy_shape()[1];
single_core/bmm_op_single_core_tilize_untilize.cpp:610: auto in1_height = in1.get_legacy_shape()[2];
single_core/bmm_op_single_core_tilize_untilize.cpp:611: auto in1_width = in1.get_legacy_shape()[3];
The next set of issues are related to the call to bcast in ttnn/cpp/ttnn/operations/matmul.cpp
First, there were multiple index references other than -1 and -2.
Second the tt_metal path uses launch_with_autoformat which required 4D tensors:
tt_eager/tt_dnn/op_library/auto_format.hpp:
static Shape pad_to_tile_shape(const Shape& unpadded_shape, bool pad_c=false, bool pad_n=false, bool pad_h=true, bool pad_w=true) {
auto n = pad_n ? round_up(unpadded_shape[0], TILE_HEIGHT) : unpadded_shape[0];
auto c = pad_c ? round_up(unpadded_shape[1], TILE_WIDTH) : unpadded_shape[1];
auto h = pad_h ? round_up(unpadded_shape[2], TILE_HEIGHT) : unpadded_shape[2];
auto w = pad_w ? round_up(unpadded_shape[3], TILE_WIDTH) : unpadded_shape[3];
Shape padded_shape = {n, c, h, w};
return padded_shape;
}
which is used by operation::launch_with_autoformat
I switched to operations::primary::bcast
I looked at what torch does for tensors of different ranks. The output tensor rank may also need to be adjusted to take into account ranks for both inputs since torch uses the larger rank, and currently rank is based on input 0 in bmm_op.
I ran the following:
a = [
torch.rand([1,1,64,512],dtype=torch.bfloat16),
torch.rand([2,1,64,512],dtype=torch.bfloat16),
torch.rand([64,512],dtype=torch.bfloat16),
torch.rand([1,64,512],dtype=torch.bfloat16),
torch.rand([1,1,1,1,64,512],dtype=torch.bfloat16),
torch.rand([1,1,2,1,64,512],dtype=torch.bfloat16),
]
b = [
torch.rand([1,1,512,64],dtype=torch.bfloat16),
torch.rand([1,512,64],dtype=torch.bfloat16),
torch.rand([2,512,64],dtype=torch.bfloat16),
torch.rand([1,2,512,64],dtype=torch.bfloat16),
]
for in1 in a:
    for in2 in b:
        print(f'matmul in1={in1.size()} in2={in2.size()}')
        c = torch.matmul(in1,in2)
        print(f'in1={in1.size()} in2={in2.size()} c={c.size()}')
The output is
matmul in1=torch.Size([1, 1, 64, 512]) in2=torch.Size([1, 1, 512, 64])
in1=torch.Size([1, 1, 64, 512]) in2=torch.Size([1, 1, 512, 64]) c=torch.Size([1, 1, 64, 64])
matmul in1=torch.Size([1, 1, 64, 512]) in2=torch.Size([1, 512, 64])
in1=torch.Size([1, 1, 64, 512]) in2=torch.Size([1, 512, 64]) c=torch.Size([1, 1, 64, 64])
matmul in1=torch.Size([1, 1, 64, 512]) in2=torch.Size([2, 512, 64])
in1=torch.Size([1, 1, 64, 512]) in2=torch.Size([2, 512, 64]) c=torch.Size([1, 2, 64, 64])
matmul in1=torch.Size([1, 1, 64, 512]) in2=torch.Size([1, 2, 512, 64])
in1=torch.Size([1, 1, 64, 512]) in2=torch.Size([1, 2, 512, 64]) c=torch.Size([1, 2, 64, 64])
matmul in1=torch.Size([2, 1, 64, 512]) in2=torch.Size([1, 1, 512, 64])
in1=torch.Size([2, 1, 64, 512]) in2=torch.Size([1, 1, 512, 64]) c=torch.Size([2, 1, 64, 64])
matmul in1=torch.Size([2, 1, 64, 512]) in2=torch.Size([1, 512, 64])
in1=torch.Size([2, 1, 64, 512]) in2=torch.Size([1, 512, 64]) c=torch.Size([2, 1, 64, 64])
matmul in1=torch.Size([2, 1, 64, 512]) in2=torch.Size([2, 512, 64])
in1=torch.Size([2, 1, 64, 512]) in2=torch.Size([2, 512, 64]) c=torch.Size([2, 2, 64, 64])
matmul in1=torch.Size([2, 1, 64, 512]) in2=torch.Size([1, 2, 512, 64])
in1=torch.Size([2, 1, 64, 512]) in2=torch.Size([1, 2, 512, 64]) c=torch.Size([2, 2, 64, 64])
matmul in1=torch.Size([64, 512]) in2=torch.Size([1, 1, 512, 64])
in1=torch.Size([64, 512]) in2=torch.Size([1, 1, 512, 64]) c=torch.Size([1, 1, 64, 64])
matmul in1=torch.Size([64, 512]) in2=torch.Size([1, 512, 64])
in1=torch.Size([64, 512]) in2=torch.Size([1, 512, 64]) c=torch.Size([1, 64, 64])
matmul in1=torch.Size([64, 512]) in2=torch.Size([2, 512, 64])
in1=torch.Size([64, 512]) in2=torch.Size([2, 512, 64]) c=torch.Size([2, 64, 64])
matmul in1=torch.Size([64, 512]) in2=torch.Size([1, 2, 512, 64])
in1=torch.Size([64, 512]) in2=torch.Size([1, 2, 512, 64]) c=torch.Size([1, 2, 64, 64])
matmul in1=torch.Size([1, 64, 512]) in2=torch.Size([1, 1, 512, 64])
in1=torch.Size([1, 64, 512]) in2=torch.Size([1, 1, 512, 64]) c=torch.Size([1, 1, 64, 64])
matmul in1=torch.Size([1, 64, 512]) in2=torch.Size([1, 512, 64])
in1=torch.Size([1, 64, 512]) in2=torch.Size([1, 512, 64]) c=torch.Size([1, 64, 64])
matmul in1=torch.Size([1, 64, 512]) in2=torch.Size([2, 512, 64])
in1=torch.Size([1, 64, 512]) in2=torch.Size([2, 512, 64]) c=torch.Size([2, 64, 64])
matmul in1=torch.Size([1, 64, 512]) in2=torch.Size([1, 2, 512, 64])
in1=torch.Size([1, 64, 512]) in2=torch.Size([1, 2, 512, 64]) c=torch.Size([1, 2, 64, 64])
matmul in1=torch.Size([1, 1, 1, 1, 64, 512]) in2=torch.Size([1, 1, 512, 64])
in1=torch.Size([1, 1, 1, 1, 64, 512]) in2=torch.Size([1, 1, 512, 64]) c=torch.Size([1, 1, 1, 1, 64, 64])
matmul in1=torch.Size([1, 1, 1, 1, 64, 512]) in2=torch.Size([1, 512, 64])
in1=torch.Size([1, 1, 1, 1, 64, 512]) in2=torch.Size([1, 512, 64]) c=torch.Size([1, 1, 1, 1, 64, 64])
matmul in1=torch.Size([1, 1, 1, 1, 64, 512]) in2=torch.Size([2, 512, 64])
in1=torch.Size([1, 1, 1, 1, 64, 512]) in2=torch.Size([2, 512, 64]) c=torch.Size([1, 1, 1, 2, 64, 64])
matmul in1=torch.Size([1, 1, 1, 1, 64, 512]) in2=torch.Size([1, 2, 512, 64])
in1=torch.Size([1, 1, 1, 1, 64, 512]) in2=torch.Size([1, 2, 512, 64]) c=torch.Size([1, 1, 1, 2, 64, 64])
matmul in1=torch.Size([1, 1, 2, 1, 64, 512]) in2=torch.Size([1, 1, 512, 64])
in1=torch.Size([1, 1, 2, 1, 64, 512]) in2=torch.Size([1, 1, 512, 64]) c=torch.Size([1, 1, 2, 1, 64, 64])
matmul in1=torch.Size([1, 1, 2, 1, 64, 512]) in2=torch.Size([1, 512, 64])
in1=torch.Size([1, 1, 2, 1, 64, 512]) in2=torch.Size([1, 512, 64]) c=torch.Size([1, 1, 2, 1, 64, 64])
matmul in1=torch.Size([1, 1, 2, 1, 64, 512]) in2=torch.Size([2, 512, 64])
in1=torch.Size([1, 1, 2, 1, 64, 512]) in2=torch.Size([2, 512, 64]) c=torch.Size([1, 1, 2, 2, 64, 64])
matmul in1=torch.Size([1, 1, 2, 1, 64, 512]) in2=torch.Size([1, 2, 512, 64])
in1=torch.Size([1, 1, 2, 1, 64, 512]) in2=torch.Size([1, 2, 512, 64]) c=torch.Size([1, 1, 2, 2, 64, 64])
Tensors beyond rank 4 are currently not allowed:
RuntimeError: TT_THROW @ ../ttnn/cpp/ttnn/validation.hpp:40: tt::exception
info:
ttnn.matmul: Tensor rank is not valid: rank is 6 but must be 2 <= rank <= 4
For the reshape, the code should be able to handle different shapes the same way as pytorch with a few exceptions where pytorch takes the cross product of shapes and tt-metal will take the shape of the first input tensor:
shape equal False in1=torch.Size([2, 1, 64, 512]) in2=torch.Size([2, 512, 64]) c torch.Size([2, 2, 64, 64]) d ttnn.Shape([2, 1, 64, 64])
shape equal False in1=torch.Size([2, 1, 64, 512]) in2=torch.Size([1, 2, 512, 64]) c torch.Size([2, 2, 64, 64]) d ttnn.Shape([2, 1, 64, 64])
shape equal False in1=torch.Size([1, 2, 64, 512]) in2=torch.Size([2, 1, 512, 64]) c torch.Size([2, 2, 64, 64]) d ttnn.Shape([1, 2, 64, 64])
Since tt-metal matmul either broadcasts input B's batch dimensions or takes the output batch dimensions from input A (rather than computing the cross product of shapes the way pytorch does), the script below skips the combinations where input B's batch is larger than input A's.
Code:
import torch
import ttnn
import time
device_id = 0
device = ttnn.open_device(device_id=device_id)
ttnn.enable_program_cache(device)
a = [
torch.rand([1,1,64,512],dtype=torch.bfloat16),
torch.rand([2,1,64,512],dtype=torch.bfloat16),
torch.rand([64,512],dtype=torch.bfloat16),
torch.rand([1,64,512],dtype=torch.bfloat16),
torch.rand([2,64,512],dtype=torch.bfloat16),
torch.rand([1,1,64,512],dtype=torch.bfloat16),
torch.rand([1,2,64,512],dtype=torch.bfloat16),
]
b = [
torch.rand([1,1,512,64],dtype=torch.bfloat16),
torch.rand([1,512,64],dtype=torch.bfloat16),
torch.rand([512,64],dtype=torch.bfloat16),
torch.rand([2,512,64],dtype=torch.bfloat16),
torch.rand([1,2,512,64],dtype=torch.bfloat16),
torch.rand([2,1,512,64],dtype=torch.bfloat16),
]
def batch(t):
    result = 1
    for i in range(len(t) - 2):
        result *= t[i]
    return result
for in1 in a:
    t1=ttnn.from_torch(in1, layout=ttnn.TILE_LAYOUT, device=device)
    batch_a = batch(in1.size())
    for in2 in b:
        batch_b = batch(in2.size())
        if batch_b > batch_a:
            continue
        len1 = len(in1.size())
        len2 = len(in2.size())
        print(f'in1={in1.size()} in2={in2.size()} len1={len1} len2={len2} in2.size()[0]={in2.size()[0]}')
        if len2 > len1 and in2.size()[0] > 1:
            continue
        t2=ttnn.from_torch(in2, layout=ttnn.TILE_LAYOUT, device=device)
        print(f'matmul in1={in1.size()} in2={in2.size()}')
        c = torch.matmul(in1,in2)
        print(f'torch matmul done in1={in1.size()} in2={in2.size()} c={c.size()}')
        d = ttnn.matmul(t1, t2)
        print(f'ttnn matmul done in1={in1.size()} in2={in2.size()} d={d}')
        print(f'shape equal {c.size()==d.shape} in1={in1.size()} in2={in2.size()} c {c.size()} d {d.shape}')
        shapes_equal = c.size()==d.shape
        ct=ttnn.from_torch(c, layout=ttnn.TILE_LAYOUT, device=device)
        if shapes_equal:
            dt = d.cpu().to_torch()
            diff = ((dt-c)/c).abs()
            max_pcc = 1.0 - diff.max()
            tensors_equal = max_pcc > 0.7
            print(f'tensors equal {tensors_equal} {max_pcc} {diff} ct {ct} d {d}')
ttnn.close_device(device)
Will need to handle nD tensors for matmul.
Tensors of rank 1 should fail since that's not even a matrix.
For other tensors, should assume the batch size is the product of the shape's dimensions 0 through rank-3 (i.e. all dimensions except the last two).
Will also need to replace any uses of shapes via indices other than -1 and -2
Also: may clean up code a bit. E.g. make bias required for linear()