tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
Apache License 2.0
430 stars 59 forks source link

ttnn.matmul - allow ND tensors to remove reshape #8112

Closed bbradelTT closed 4 months ago

bbradelTT commented 5 months ago

Will need to handle nD tensors for matmul.

Tensors of rank 1 should fail since that's not even a matrix.

For other tensors, assume the batch size is the product of all dimensions except the last two (i.e. dimensions 0 through rank-3).

Will also need to replace any shape indexing that uses indices other than -1 and -2.

Also, may clean up the code a bit, e.g. make bias required for linear().

bbradelTT commented 4 months ago

Ran a grep command in the bmm directory. The result is a good starting point for what needs to change:

/proj_sw/user_dev/bbradel/tt-metal/tt_eager/tt_dnn/op_library/bmm > grep 'shape\[[0-9]' -r * -n
bmm_op.cpp:102:    int64_t num_mul_adds_per_elem = in_a_shape[3] * 2;  // 1 multiply and 1 add per element
bmm_op.cpp:103:    int64_t num_mul_adds = num_mul_adds_per_elem * out_shape[2] * out_shape[3] * out_shape[1] * out_shape[0];
bmm_op.cpp:127:    tt::log_info(tt::LogOp, "\t Batch: ({}, {})", out_shape[0], out_shape[1]);
bmm_op.cpp:128:    tt::log_info(tt::LogOp, "\t In A (H, W): ({}, {})", in_a_shape[2], in_a_shape[3]);
bmm_op.cpp:129:    tt::log_info(tt::LogOp, "\t In B (H, W): ({}, {})", in_b_shape[2], in_b_shape[3]);
bmm_op.cpp:130:    tt::log_info(tt::LogOp, "\t Out (H, W): ({}, {})", out_shape[2], out_shape[3]);
bmm_op.cpp:376:                in0_block_w = shard_shape[1] / TILE_WIDTH;
bmm_op.cpp:379:                per_core_M = shard_shape[0] / TILE_HEIGHT;
bmm_op.cpp:422:                virtual_y == (M / (shard_shape[0] / TILE_HEIGHT)), "Num cores along y must match provided grid size!");
bmm_op.cpp:424:                virtual_x == (K / (shard_shape[1] / TILE_WIDTH)), "Num cores along x must match provided grid size!");
bmm_op.cpp:428:            uint32_t in0_block_w = shard_shape[1] / TILE_WIDTH;
bmm_op.cpp:462:        uint32_t per_core_M = in0_shard_shape[0] / TILE_HEIGHT;
bmm_op.cpp:464:        uint32_t in0_block_w = in0_shard_shape[1] / TILE_WIDTH;
bmm_op.cpp:743:    uint32_t num_output_tiles = ashape[0] * ashape[1] * ashape[2] * bshape[3] / TILE_HW;  // Output M x N
bmm_op.cpp:746:    uint32_t B = ashape[0] * ashape[1];
bmm_op.cpp:747:    uint32_t Mt = ashape[2] / TILE_HEIGHT;
bmm_op.cpp:748:    uint32_t Kt = ashape[3] / TILE_WIDTH;
bmm_op.cpp:749:    uint32_t Nt = bshape[3] / TILE_WIDTH;
bmm_op.cpp:904:        "Dimension K (A.shape[3] and B.shape[2]) must match for A and B in bmm_op");  // A.K == B.K
bmm_op.cpp:959:                        TT_FATAL(per_core_M == (shard_shape[0] / TILE_HEIGHT));
bmm_op.cpp:961:                        TT_FATAL((shard_shape[1] / TILE_WIDTH) % program_config.in0_block_w == 0);
bmm_op.cpp:997:                        TT_FATAL(per_core_M == (shard_shape[0] / TILE_HEIGHT));
bmm_op.cpp:999:                        TT_FATAL(program_config.in0_block_w == (shard_shape[1] / TILE_WIDTH));
bmm_op.cpp:1041:                        TT_FATAL(K / (shard_shape[1] / TILE_WIDTH) == div_up(N, program_config.per_core_N));
bmm_op.cpp:1044:                        TT_FATAL(program_config.in0_block_w == (shard_shape[1] / TILE_WIDTH));
bmm_op.cpp:1050:                    TT_FATAL(per_core_M == (shard_shape[0] / TILE_HEIGHT));
bmm_op.cpp:1051:                    TT_FATAL((shard_shape[1] / TILE_WIDTH) % program_config.in0_block_w == 0);
bmm_op.cpp:1057:                    TT_FATAL(program_config.per_core_N == (input_tensor_b.shard_spec().value().shape[0] / TILE_WIDTH));
bmm_op.cpp:1090:                    TT_FATAL(K == in0_shard_shape[1]);
bmm_op.cpp:1091:                    TT_FATAL(in0_shard_shape[1] == program_config.in0_block_w * TILE_WIDTH);
bmm_op.cpp:1092:                    TT_FATAL(per_core_M * TILE_HEIGHT == in0_shard_shape[0]);
bmm_op.cpp:1120:                    TT_FATAL(in1_shard_shape[1] == input_tensor_b.get_legacy_shape()[-1]);
bmm_op.cpp:1121:                    TT_FATAL(per_core_N * TILE_HEIGHT == in1_shard_shape[1]);
bmm_op.cpp:1122:                    TT_FATAL(in1_shard_shape[0] % K == 0);
bmm_op.cpp:1547:            m_tiles_per_core = shard_shape[0] / ttnn::TILE_SIZE;
bmm_op.cpp:1549:            k_tiles_per_core = shard_shape[1] / ttnn::TILE_SIZE;
bmm_op.cpp:1556:            n_tiles_per_core = shard_shape[1] / ttnn::TILE_SIZE;
bmm_op.cpp:1595:        m_tiles_per_core = shard_shape[0] / ttnn::TILE_SIZE;
bmm_op.cpp:1596:        n_tiles_per_core = (n * shard_shape[1]) / (k * ttnn::TILE_SIZE);
bmm_op.cpp:1597:        k_tiles_per_core = shard_shape[1] / ttnn::TILE_SIZE;
multi_core/bmm_op_multi_core.cpp:42:    auto num_output_tiles_total = cshape[0] * cshape[1] * cshape[2] * cshape[3] / TILE_HW;
multi_core/bmm_op_multi_core.cpp:50:    uint32_t B = ashape[0]*ashape[1];
multi_core/bmm_op_multi_core.cpp:51:    uint32_t Mt = ashape[2]/TILE_HEIGHT;
multi_core/bmm_op_multi_core.cpp:52:    uint32_t Kt = ashape[3]/TILE_WIDTH;
multi_core/bmm_op_multi_core.cpp:53:    uint32_t Nt = bshape[3]/TILE_WIDTH;
multi_core_reuse/bmm_op_multi_core_reuse.cpp:255:    uint32_t B = ashape[0]*ashape[1];
multi_core_reuse/bmm_op_multi_core_reuse.cpp:256:    uint32_t Mt = ashape[2]/TILE_HEIGHT;
multi_core_reuse/bmm_op_multi_core_reuse.cpp:257:    uint32_t Kt = ashape[3]/TILE_WIDTH;
multi_core_reuse/bmm_op_multi_core_reuse.cpp:258:    uint32_t Nt = bshape[3]/TILE_WIDTH;
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1454:            bshape[0] * shape[1] == 1 &&
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1459:            ashape[1] == bshape[1] && ashape[0] == bshape[0] &&
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1466:        ashape[3] == bshape[2] &&
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1467:        "Dimension K (A.shape[3] and B.shape[2]) must match for A and B in bmm_op");  // A.K == B.K
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1468:    TT_FATAL(ashape[2] % TILE_HEIGHT == 0);
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1469:    TT_FATAL(ashape[3] % TILE_WIDTH == 0);
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1470:    TT_FATAL(bshape[2] % TILE_HEIGHT == 0);
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1471:    TT_FATAL(bshape[3] % TILE_WIDTH == 0);
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1510:    uint32_t B = ashape[] * ashape[1];
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1511:    uint32_t Mt = ashape[2] / TILE_HEIGHT;
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1512:    uint32_t Kt = ashape[3] / TILE_WIDTH;
multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp:1513:    uint32_t Nt = bshape[3] / TILE_WIDTH;
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1029:            bshape[0] * shape[1] == 1 &&
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1034:            ashape[1] == bshape[1] && ashape[0] == bshape[0] &&
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1041:        ashape[3] == bshape[2] &&
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1042:        "Dimension K (A.shape[3] and B.shape[2]) must match for A and B in bmm_op");  // A.K == B.K
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1043:    TT_FATAL(ashape[2] % TILE_HEIGHT == 0);
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1044:    TT_FATAL(ashape[3] % TILE_WIDTH == 0);
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1045:    TT_FATAL(bshape[2] % TILE_HEIGHT == 0);
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1046:    TT_FATAL(bshape[3] % TILE_WIDTH == 0);
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1085:    uint32_t B = ashape[] * ashape[1];
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1086:    uint32_t Mt = ashape[2] / TILE_HEIGHT;
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1087:    uint32_t Kt = ashape[3] / TILE_WIDTH;
multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp:1088:    uint32_t Nt = bshape[3] / TILE_WIDTH;
multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp:453:    TT_FATAL((bcast_batch == false) or (ashape[0] == 1), "Bcast batch not supported for this parallelization");
multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp:467:        TT_FATAL(bshape[0]*bshape[1] == 1 && "matmul (batch bcast variant) expects input tensors of shapes BCMK*11KN=BCMN");
multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp:470:        TT_FATAL(ashape[1] == bshape[1] && shape[0] == bshape[0]
multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp:508:    uint32_t B = ashape[0]*ashape[1];
multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp:509:    uint32_t Mt = ashape[2]/TILE_HEIGHT;
multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp:510:    uint32_t Kt = ashape[3]/TILE_WIDTH;
multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp:511:    uint32_t Nt = bshape[3]/TILE_WIDTH;
multi_core_reuse_padding/bmm_op_multi_core_reuse_padding.cpp:281:    uint32_t B = ashape[0]*ashape[1];
multi_core_reuse_padding/bmm_op_multi_core_reuse_padding.cpp:282:    uint32_t Mt = ashape[2]/TILE_HEIGHT;
multi_core_reuse_padding/bmm_op_multi_core_reuse_padding.cpp:283:    uint32_t Kt = ashape[3]/TILE_WIDTH;
multi_core_reuse_padding/bmm_op_multi_core_reuse_padding.cpp:284:    uint32_t Nt = bshape[3]/TILE_WIDTH;
bbradelTT commented 4 months ago

The shard_shape is already 2d and should be ok.

The grep did not capture all potential locations.

The following grep/awk command, done after editing the previous set of locations, identified more locations:

grep 'shape.*\[[0-9]\]' -r * -n | awk '{if ($0 !~ /shard/) { print $0 }}' -
bmm_op.cpp:581:    auto seq_len = input_tensor_a.get_legacy_shape()[2];
bmm_op.cpp:628:    auto seq_len = input_tensor_a.get_legacy_shape()[2];
bmm_op.cpp:1110:                    (input_tensor_a.get_legacy_shape()[0] * input_tensor_a.get_legacy_shape()[1] > 1 and
bmm_op.cpp:1111:                     input_tensor_b.get_legacy_shape()[0] * input_tensor_b.get_legacy_shape()[1] == 1);
bmm_op.cpp:1301:    // TODO: If input_tensor_a.get_legacy_shape()[0] * input_tensor_a.get_legacy_shape()[1] * ... except last two dimensions == 1, does matmuls work if
multi_core/bmm_op_multi_core.cpp:42:    auto num_output_tiles_total = cshape[0] * cshape[1] * cshape[2] * cshape[3] / TILE_HW;
multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp:453:    TT_FATAL((bcast_batch == false) or (ashape[0] == 1), "Bcast batch not supported for this parallelization");
single_core/bmm_op_single_core_tilize_untilize.cpp:172:    uint32_t in0_batch = in0.get_legacy_shape()[0];
single_core/bmm_op_single_core_tilize_untilize.cpp:173:    uint32_t in0_channel = in0.get_legacy_shape()[1];
single_core/bmm_op_single_core_tilize_untilize.cpp:174:    uint32_t in0_height = in0.get_legacy_shape()[2];
single_core/bmm_op_single_core_tilize_untilize.cpp:175:    uint32_t in0_width = in0.get_legacy_shape()[3];
single_core/bmm_op_single_core_tilize_untilize.cpp:176:    uint32_t in1_batch = in1.get_legacy_shape()[0];
single_core/bmm_op_single_core_tilize_untilize.cpp:177:    uint32_t in1_channel = in1.get_legacy_shape()[1];
single_core/bmm_op_single_core_tilize_untilize.cpp:178:    uint32_t in1_height = in1.get_legacy_shape()[2];
single_core/bmm_op_single_core_tilize_untilize.cpp:179:    uint32_t in1_width = in1.get_legacy_shape()[3];
single_core/bmm_op_single_core_tilize_untilize.cpp:187:        TT_FATAL(bias.get_legacy_shape()[3] == in1.get_legacy_shape()[3], "Bias shape mismatch");
single_core/bmm_op_single_core_tilize_untilize.cpp:196:        TT_FATAL(bias.get_legacy_shape()[2] % constants::TILE_HEIGHT == 0);
single_core/bmm_op_single_core_tilize_untilize.cpp:197:        TT_FATAL(bias.get_legacy_shape()[3] % constants::TILE_WIDTH == 0);
single_core/bmm_op_single_core_tilize_untilize.cpp:329:        bias_ntiles_w = bias.get_legacy_shape()[3] / constants::TILE_WIDTH;
single_core/bmm_op_single_core_tilize_untilize.cpp:604:    auto in0_batch = in0.get_legacy_shape()[0];
single_core/bmm_op_single_core_tilize_untilize.cpp:605:    auto in0_channel = in0.get_legacy_shape()[1];
single_core/bmm_op_single_core_tilize_untilize.cpp:606:    auto in0_height = in0.get_legacy_shape()[2];
single_core/bmm_op_single_core_tilize_untilize.cpp:607:    auto in0_width = in0.get_legacy_shape()[3];
single_core/bmm_op_single_core_tilize_untilize.cpp:608:    auto in1_batch = in1.get_legacy_shape()[0];
single_core/bmm_op_single_core_tilize_untilize.cpp:609:    auto in1_channel = in1.get_legacy_shape()[1];
single_core/bmm_op_single_core_tilize_untilize.cpp:610:    auto in1_height = in1.get_legacy_shape()[2];
single_core/bmm_op_single_core_tilize_untilize.cpp:611:    auto in1_width = in1.get_legacy_shape()[3];
bbradelTT commented 4 months ago

The next set of issues relates to the call to bcast in ttnn/cpp/ttnn/operations/matmul.cpp.

First, there were multiple index references other than -1 and -2.

Second, the tt_metal path uses launch_with_autoformat, which requires 4D tensors. From tt_eager/tt_dnn/op_library/auto_format.hpp:

    static Shape pad_to_tile_shape(const Shape& unpadded_shape, bool pad_c=false, bool pad_n=false, bool pad_h=true, bool pad_w=true) {
        auto n = pad_n ? round_up(unpadded_shape[0], TILE_HEIGHT) : unpadded_shape[0];
        auto c = pad_c ? round_up(unpadded_shape[1], TILE_WIDTH) : unpadded_shape[1];
        auto h = pad_h ? round_up(unpadded_shape[2], TILE_HEIGHT) : unpadded_shape[2];
        auto w = pad_w ? round_up(unpadded_shape[3], TILE_WIDTH) : unpadded_shape[3];
        Shape padded_shape = {n, c, h, w};
        return padded_shape;
    }

This function is used by operation::launch_with_autoformat. It indexes dimensions 0-3 directly, so it only handles 4D shapes. I therefore switched to operations::primary::bcast.

bbradelTT commented 4 months ago

I looked at what torch does for tensors of different ranks. The output tensor rank may also need to be adjusted to account for the ranks of both inputs. Torch uses the larger of the two ranks, whereas bmm_op currently bases the output rank on input 0 only.

I ran the following:

# Survey how torch.matmul chooses the output shape when the two inputs differ
# in rank and/or batch dimensions. Assumes `import torch` is already in scope.
# Left-hand operands: every shape ends in (64, 512), i.e. M=64, K=512.
a = [
torch.rand([1,1,64,512],dtype=torch.bfloat16),
torch.rand([2,1,64,512],dtype=torch.bfloat16),
torch.rand([64,512],dtype=torch.bfloat16),
torch.rand([1,64,512],dtype=torch.bfloat16),
torch.rand([1,1,1,1,64,512],dtype=torch.bfloat16),
torch.rand([1,1,2,1,64,512],dtype=torch.bfloat16),
]
# Right-hand operands: every shape ends in (512, 64), i.e. K=512, N=64.
b = [
torch.rand([1,1,512,64],dtype=torch.bfloat16),
torch.rand([1,512,64],dtype=torch.bfloat16),
torch.rand([2,512,64],dtype=torch.bfloat16),
torch.rand([1,2,512,64],dtype=torch.bfloat16),
]

# For every (in1, in2) pair, print the input shapes before the call and the
# result shape after it (presumably so a failing pair is identifiable from the
# last line printed).
for in1 in a:
    for in2 in b:
        print(f'matmul in1={in1.size()} in2={in2.size()}')
        c = torch.matmul(in1,in2)
        print(f'in1={in1.size()} in2={in2.size()} c={c.size()}')

The output is

matmul in1=torch.Size([1, 1, 64, 512]) in2=torch.Size([1, 1, 512, 64])
in1=torch.Size([1, 1, 64, 512]) in2=torch.Size([1, 1, 512, 64]) c=torch.Size([1, 1, 64, 64])
matmul in1=torch.Size([1, 1, 64, 512]) in2=torch.Size([1, 512, 64])
in1=torch.Size([1, 1, 64, 512]) in2=torch.Size([1, 512, 64]) c=torch.Size([1, 1, 64, 64])
matmul in1=torch.Size([1, 1, 64, 512]) in2=torch.Size([2, 512, 64])
in1=torch.Size([1, 1, 64, 512]) in2=torch.Size([2, 512, 64]) c=torch.Size([1, 2, 64, 64])
matmul in1=torch.Size([1, 1, 64, 512]) in2=torch.Size([1, 2, 512, 64])
in1=torch.Size([1, 1, 64, 512]) in2=torch.Size([1, 2, 512, 64]) c=torch.Size([1, 2, 64, 64])
matmul in1=torch.Size([2, 1, 64, 512]) in2=torch.Size([1, 1, 512, 64])
in1=torch.Size([2, 1, 64, 512]) in2=torch.Size([1, 1, 512, 64]) c=torch.Size([2, 1, 64, 64])
matmul in1=torch.Size([2, 1, 64, 512]) in2=torch.Size([1, 512, 64])
in1=torch.Size([2, 1, 64, 512]) in2=torch.Size([1, 512, 64]) c=torch.Size([2, 1, 64, 64])
matmul in1=torch.Size([2, 1, 64, 512]) in2=torch.Size([2, 512, 64])
in1=torch.Size([2, 1, 64, 512]) in2=torch.Size([2, 512, 64]) c=torch.Size([2, 2, 64, 64])
matmul in1=torch.Size([2, 1, 64, 512]) in2=torch.Size([1, 2, 512, 64])
in1=torch.Size([2, 1, 64, 512]) in2=torch.Size([1, 2, 512, 64]) c=torch.Size([2, 2, 64, 64])
matmul in1=torch.Size([64, 512]) in2=torch.Size([1, 1, 512, 64])
in1=torch.Size([64, 512]) in2=torch.Size([1, 1, 512, 64]) c=torch.Size([1, 1, 64, 64])
matmul in1=torch.Size([64, 512]) in2=torch.Size([1, 512, 64])
in1=torch.Size([64, 512]) in2=torch.Size([1, 512, 64]) c=torch.Size([1, 64, 64])
matmul in1=torch.Size([64, 512]) in2=torch.Size([2, 512, 64])
in1=torch.Size([64, 512]) in2=torch.Size([2, 512, 64]) c=torch.Size([2, 64, 64])
matmul in1=torch.Size([64, 512]) in2=torch.Size([1, 2, 512, 64])
in1=torch.Size([64, 512]) in2=torch.Size([1, 2, 512, 64]) c=torch.Size([1, 2, 64, 64])
matmul in1=torch.Size([1, 64, 512]) in2=torch.Size([1, 1, 512, 64])
in1=torch.Size([1, 64, 512]) in2=torch.Size([1, 1, 512, 64]) c=torch.Size([1, 1, 64, 64])
matmul in1=torch.Size([1, 64, 512]) in2=torch.Size([1, 512, 64])
in1=torch.Size([1, 64, 512]) in2=torch.Size([1, 512, 64]) c=torch.Size([1, 64, 64])
matmul in1=torch.Size([1, 64, 512]) in2=torch.Size([2, 512, 64])
in1=torch.Size([1, 64, 512]) in2=torch.Size([2, 512, 64]) c=torch.Size([2, 64, 64])
matmul in1=torch.Size([1, 64, 512]) in2=torch.Size([1, 2, 512, 64])
in1=torch.Size([1, 64, 512]) in2=torch.Size([1, 2, 512, 64]) c=torch.Size([1, 2, 64, 64])
matmul in1=torch.Size([1, 1, 1, 1, 64, 512]) in2=torch.Size([1, 1, 512, 64])
in1=torch.Size([1, 1, 1, 1, 64, 512]) in2=torch.Size([1, 1, 512, 64]) c=torch.Size([1, 1, 1, 1, 64, 64])
matmul in1=torch.Size([1, 1, 1, 1, 64, 512]) in2=torch.Size([1, 512, 64])
in1=torch.Size([1, 1, 1, 1, 64, 512]) in2=torch.Size([1, 512, 64]) c=torch.Size([1, 1, 1, 1, 64, 64])
matmul in1=torch.Size([1, 1, 1, 1, 64, 512]) in2=torch.Size([2, 512, 64])
in1=torch.Size([1, 1, 1, 1, 64, 512]) in2=torch.Size([2, 512, 64]) c=torch.Size([1, 1, 1, 2, 64, 64])
matmul in1=torch.Size([1, 1, 1, 1, 64, 512]) in2=torch.Size([1, 2, 512, 64])
in1=torch.Size([1, 1, 1, 1, 64, 512]) in2=torch.Size([1, 2, 512, 64]) c=torch.Size([1, 1, 1, 2, 64, 64])
matmul in1=torch.Size([1, 1, 2, 1, 64, 512]) in2=torch.Size([1, 1, 512, 64])
in1=torch.Size([1, 1, 2, 1, 64, 512]) in2=torch.Size([1, 1, 512, 64]) c=torch.Size([1, 1, 2, 1, 64, 64])
matmul in1=torch.Size([1, 1, 2, 1, 64, 512]) in2=torch.Size([1, 512, 64])
in1=torch.Size([1, 1, 2, 1, 64, 512]) in2=torch.Size([1, 512, 64]) c=torch.Size([1, 1, 2, 1, 64, 64])
matmul in1=torch.Size([1, 1, 2, 1, 64, 512]) in2=torch.Size([2, 512, 64])
in1=torch.Size([1, 1, 2, 1, 64, 512]) in2=torch.Size([2, 512, 64]) c=torch.Size([1, 1, 2, 2, 64, 64])
matmul in1=torch.Size([1, 1, 2, 1, 64, 512]) in2=torch.Size([1, 2, 512, 64])
in1=torch.Size([1, 1, 2, 1, 64, 512]) in2=torch.Size([1, 2, 512, 64]) c=torch.Size([1, 1, 2, 2, 64, 64])
bbradelTT commented 4 months ago

Tensors with rank greater than 4 are currently not allowed:

RuntimeError: TT_THROW @ ../ttnn/cpp/ttnn/validation.hpp:40: tt::exception
info:
ttnn.matmul: Tensor rank is not valid: rank is 6 but must be  2 <= rank <- 4
bbradelTT commented 4 months ago

For the reshape, the code should handle different shapes the same way as PyTorch, with a few exceptions. In those cases PyTorch broadcasts the batch dimensions of the two inputs against each other, while tt-metal takes the batch shape of the first input tensor:

shape equal False in1=torch.Size([2, 1, 64, 512]) in2=torch.Size([2, 512, 64]) c torch.Size([2, 2, 64, 64]) d ttnn.Shape([2, 1, 64, 64])
shape equal False in1=torch.Size([2, 1, 64, 512]) in2=torch.Size([1, 2, 512, 64]) c torch.Size([2, 2, 64, 64]) d ttnn.Shape([2, 1, 64, 64])
shape equal False in1=torch.Size([1, 2, 64, 512]) in2=torch.Size([2, 1, 512, 64]) c torch.Size([2, 2, 64, 64]) d ttnn.Shape([1, 2, 64, 64])

Since tt-metal matmul either

Code:

import torch
import ttnn
import time
# Open device 0 and enable ttnn's program cache for the comparison below.
device_id = 0
device = ttnn.open_device(device_id=device_id)
ttnn.enable_program_cache(device)

# Left-hand operands (M=64, K=512): ranks 2-4, batch sizes 1 and 2.
# NOTE(review): the sixth entry repeats the first shape [1,1,64,512].
a = [          
torch.rand([1,1,64,512],dtype=torch.bfloat16),
torch.rand([2,1,64,512],dtype=torch.bfloat16),
torch.rand([64,512],dtype=torch.bfloat16),
torch.rand([1,64,512],dtype=torch.bfloat16),
torch.rand([2,64,512],dtype=torch.bfloat16),
torch.rand([1,1,64,512],dtype=torch.bfloat16),
torch.rand([1,2,64,512],dtype=torch.bfloat16),
]              
# Right-hand operands (K=512, N=64): ranks 2-4, batch sizes 1 and 2.
b = [
torch.rand([1,1,512,64],dtype=torch.bfloat16),
torch.rand([1,512,64],dtype=torch.bfloat16),
torch.rand([512,64],dtype=torch.bfloat16),
torch.rand([2,512,64],dtype=torch.bfloat16),
torch.rand([1,2,512,64],dtype=torch.bfloat16),
torch.rand([2,1,512,64],dtype=torch.bfloat16),
]

def batch(t):
    """Return the batch size of shape `t`.

    The batch size is the product of every dimension except the last two.
    Shapes with fewer than three dimensions therefore have a batch size of 1.
    """
    leading_dims = [t[i] for i in range(len(t) - 2)]
    size = 1
    for dim in leading_dims:
        size *= dim
    return size

# For each (in1, in2) pair, compare ttnn.matmul with torch.matmul: first the
# output shapes, then (only when the shapes agree) the values.
for in1 in a:  
    t1=ttnn.from_torch(in1, layout=ttnn.TILE_LAYOUT, device=device)
    batch_a = batch(in1.size())
    for in2 in b:
        batch_b = batch(in2.size())
        # Skip pairs where in2 has a larger batch than in1 (presumably not
        # supported by ttnn.matmul at the time -- confirm).
        if batch_b > batch_a:
            continue
        len1 = len(in1.size())
        len2 = len(in2.size())
        print(f'in1={in1.size()} in2={in2.size()} len1={len1} len2={len2} in2.size()[0]={in2.size()[0]}')
        # Also skip pairs where in2 has the higher rank and a leading dim > 1.
        if len2  > len1 and in2.size()[0] > 1:
            continue
        t2=ttnn.from_torch(in2, layout=ttnn.TILE_LAYOUT, device=device)
        print(f'matmul in1={in1.size()} in2={in2.size()}')
        c = torch.matmul(in1,in2)
        print(f'torch matmul done in1={in1.size()} in2={in2.size()} c={c.size()}')
        d = ttnn.matmul(t1, t2)
        print(f'ttnn matmul done in1={in1.size()} in2={in2.size()} d={d}')
        print(f'shape equal {c.size()==d.shape} in1={in1.size()} in2={in2.size()} c {c.size()} d {d.shape}')
        shapes_equal = c.size()==d.shape
        # ct (torch result moved to device) is only used in the print below.
        ct=ttnn.from_torch(c, layout=ttnn.TILE_LAYOUT, device=device)
        if shapes_equal:
            dt = d.cpu().to_torch()
            # Element-wise relative error of the ttnn result against torch.
            diff = ((dt-c)/c).abs()
            # NOTE(review): despite its name, max_pcc is not a Pearson
            # correlation coefficient. It is 1 - (max relative error), and
            # the 0.7 threshold below is a loose tolerance.
            max_pcc = 1.0 - diff.max()
            tensors_equal = max_pcc > 0.7
            print(f'tensors equal {tensors_equal} {max_pcc} {diff} ct {ct} d {d}')

ttnn.close_device(device)