secretflow / spu

SPU (Secure Processing Unit) aims to be a provable, measurable secure computation device, which provides computation ability while keeping your private data protected.
https://www.secretflow.org.cn/docs/spu/en/
Apache License 2.0

[Question]: About the support for efficient Winograd convolution in SPU #738

Closed warpoons closed 4 months ago

warpoons commented 5 months ago

Issue Type

Build/Install

Modules Involved

MPC protocol

Have you reproduced the bug with SPU HEAD?

Yes

Have you searched existing issues?

Yes

SPU Version

spu 0.9.0.dev20240311

OS Platform and Distribution

Ubuntu 18.04.6 LTS by WSL

Python Version

3.10

Compiler Version

GCC 11.3.0

Current Behavior?

Hi dear SPU team,

In recent weeks I have been working on using SPU to evaluate the private inference cost of Winograd-based ConvNets, since Winograd reduces the number of multiplications and should therefore lower the communication cost.

For context, using Winograd for more efficient PPML is an emerging direction and has already been shown to improve communication efficiency, for example:

  1. Zeng W, Li M, Yang H, et al. Copriv: Network/protocol co-optimization for communication-efficient private inference[J]. Advances in Neural Information Processing Systems, 2023, 36: 78906-78925.
  2. Zeng W, Xu T, Li M, et al. EQO: Exploring Ultra-Efficient Private Inference with Winograd-Based Protocol and Quantization Co-Optimization[J]. arXiv preprint arXiv:2404.09404, 2024.

But I found that the reduction in multiplications does NOT actually decrease the communication in SPU. For a single conv layer, the respective communication costs are: standard conv: 759296 bytes; Winograd-EWMM (element-wise matmul) conv: 6291456 bytes; Winograd-GEMM (general matmul) conv: 3964928 bytes. In other words, the Winograd algorithm increases the communication by 8.28x and 5.22x, respectively, and the increase mainly comes from the frequent invocation of truncations. For details please see here.
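For reference, the construction discussed in this thread is the standard F(2x2, 3x3) Winograd convolution (Lavin-Gray formulation): each 2x2 output tile $Y$ is computed from a 4x4 input tile $d$ and a 3x3 kernel $g$ as

$$Y = A^{T}\big[(G\,g\,G^{T}) \odot (B^{T}\,d\,B)\big]A$$

where $G$ (4x3), $B$ (4x4) and $A$ (4x2) are small constant matrices and $\odot$ denotes element-wise multiplication; the same $G$, $B$ and $A$ appear as constants in the code posted later in this thread.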

So if I now want to modify the C++ backend op to support Winograd in SPU, could you tell me where I should start and which part of SPU I should look at? Thanks! (If possible, I am willing to contribute Winograd support to the SPU community.)

Standalone code to reproduce the issue

N/A

Relevant log output

N/A
anakinxc commented 5 months ago

Hi @warpoons

To change conv implementation, please check here

warpoons commented 5 months ago

Hi @anakinxc, thanks for your quick response!

Currently I have implemented the Winograd conv class in Flax in the SPU frontend, but the communication improvement is far from what I expected.

Should I directly re-implement the Winograd convolution in the libspu/kernel/hlo/convolution.cc you mentioned here?

If so, for the model definition such as ResNet18 in the frontend of SPU, should I retain the definition of conv layers as regular nn.conv or something else?

Thanks.

anakinxc commented 5 months ago

Should I directly re-implement the Winograd convolution in the libspu/kernel/hlo/convolution.cc you mentioned here?

yes

If so, for the model definition such as ResNet18 in the frontend of SPU, should I retain the definition of conv layers as regular nn.conv or something else?

regular conv should be fine

warpoons commented 4 months ago

Hi @anakinxc. Thanks for your responses and for keeping this issue active.

In recent days I have been trying to implement the C++ backend op for Winograd in SPU, but I am running into difficulties because of my unfamiliarity with C++, so I would sincerely appreciate your guidance and assistance. Thanks.

First of all, many apologies if these are trivial questions.

As previously suggested, I have re-implemented the convolution in libspu/kernel/hlo/convolution.cc, closely following the python/torch version here. The modified convolution.cc is:

#include "libspu/kernel/hlo/convolution.h"
#include "libspu/core/value.h"
#include "libspu/kernel/hal/polymorphic.h"
#include "libspu/kernel/hal/shape_ops.h"
#include "libspu/kernel/hal/permute.h" // Newly added for Winograd transformation

namespace spu::kernel::hlo {

// This is an optimized conv2D with im2col
spu::Value Convolution2D(SPUContext *ctx, const spu::Value &input,
                         const spu::Value &kernel,
                         const ConvolutionConfig &config,
                         const Shape &result_shape) {
  SPU_ENFORCE(!input.isComplex() && !kernel.isComplex());
  // input  : (N, H, W, C)
  // kernel : (h, w, C, O)
  // output : (N, hh,ww,O), where hh=(H-h)/sh+1, ww=(W-w)/sw+1

  // Alias input dimensions.
  auto N = input.shape()[0];
  auto H = input.shape()[1];
  auto W = input.shape()[2];
  auto C = input.shape()[3];

  auto h = kernel.shape()[0];
  auto w = kernel.shape()[1];
  SPU_ENFORCE_EQ(kernel.shape()[2], C, "input/kernel channel mismatch");
  auto O = kernel.shape()[3];

  SPU_ENFORCE_EQ(result_shape[0], N, "result batch mismatch");
  auto hh = result_shape[1];
  auto ww = result_shape[2];
  SPU_ENFORCE_EQ(result_shape[3], O, "result filters mismatch");

  SPU_ENFORCE_EQ(config.window_strides.size(), 2U);
  int64_t sh = config.window_strides[0];
  int64_t sw = config.window_strides[1];

  SPU_ENFORCE_EQ(hh, (H - h) / sh + 1);
  SPU_ENFORCE_EQ(ww, (W - w) / sw + 1);

  // For 3x3 kernel size, use Winograd convolution
  if (h == 3 && w == 3) {
    // Parameters
    auto r = h;  // kernel size, which is fixed to 3 here
    auto m = 2; // output tile size, which is fixed to 2 here
    auto a = m + r -1; // input tile size, which is fixed to 4 here
    auto T = hal::ceil((H - r + 1) / m); // number of tiles along height/weight
    auto P = N * T * T; // total number of tiles
    auto tile_stride = a - r + 1; // stride for extracting input tiles

    // Define Winograd transformation matrices
    auto G = { {1.0, 0.0, 0.0},
             {0.5, 0.5, 0.5},
             {0.5, -0.5, 0.5},
             {0.0, 0.0, 1.0} };
    auto G_T = { {1.0, 0.5, 0.5, 0.0},
               {0.0, 0.5, -0.5, 0.0},
               {0.0, 0.5, 0.5, 1.0} };
    auto B = { {1, 0, 0, 0},
             {0, 1, -1, 1},
             {-1, 1, 1, 0},
             {0, 0, 0, -1} };
    auto B_T = { {1, 0, -1, 0},
               {0, 1, 1, 0},
               {0, -1, 1, 0},
               {0, 1, 0, -1} };
    auto A = { {1, 0},
             {1, 1},
             {1, -1},
             {0, -1}};
    auto A_T = { {1, 1, 1, 0},
               {0, 1, -1, -1}};
    // Transform kernel to Winograd domain
    auto U = hal::matmul(ctx, G, hal::matmul(ctx, kernel, G_T));
    // Transform input to Winograd domain
    Value expanded;
    {
      std::vector<spu::Value> tiled_input;
      // Separate the input into axa tiles
      for (int64_t x = 0; x <= H - a; x += tile_stride) {
        for (int64_t y = 0; y <= W - a; y += tile_stride) {
          auto tile = hal::slice(ctx, input, {0, x, y, 0}, {N, x + a, y + a, C}, {});
          tiled_input.emplace_back(hal::reshape(ctx, tile, {1, N, a, a, C}));
        }
      }
      auto stacked = hal::concatenate(ctx, tiled_input, 0);
      expanded = hal::reshape(ctx, stacked, {P, a, a, C});
    }
    auto V = hal::matmul(ctx, B_T, hal::matmul(ctx, expanded, B));

    // Transform Winograd input and kernel to GEMM format
    U = hal::permute(ctx, U, {0, 1, 3, 2}); // U (a, a, C, O) -> (a, a, O, C)
    V = hal::permute(ctx, V, {1, 2, 3, 0}); // V (P, a, a, C) -> (a, a, C, P)

    // Do GEMM
    auto M = hal::matmul(ctx, U, V); // M (a, a, O, P)
    M = hal::permute(ctx, M, {3, 0, 1, 2}); // M (a, a, O, P) -> (P, a, a, O)

    // Transform Winograd output to regular output
    auto output_tile = hal::matmul(ctx, A_T, hal::matmul(ctx, M, A)); // output_tile (P, 2, 2, O)
    auto out_size = H - r + 1;
    auto result = hal::reshape(ctx, hal::permute(ctx, hal::reshape(ctx, output_tile, {N, T, T, m, m, O}), {0, 1, 3, 2, 4, 5}), {N, out_size, out_size, O});

    return result;

  } else {
    // Fallback, use im2col + dot to implement convolution
    // expand the image according to the kernel size.
    // assumption:
    // - padding is erased by some compiler pass.
    // - input  : NxHxWxC
    // - kernel : hxwxCxO
    Value expanded;
    {
      std::vector<spu::Value> images;
      for (int64_t x = 0; x <= H - h; x += sh) {
        for (int64_t y = 0; y <= W - w; y += sw) {
          auto window =
              hal::slice(ctx, input, {0, x, y, 0}, {N, x + h, y + w, C}, {});
          images.emplace_back(hal::reshape(ctx, window, {N, 1, h, w, C}));
        }
      }
      auto stacked = hal::concatenate(ctx, images, 1);
      SPU_ENFORCE_EQ(stacked.shape()[1], hh * ww);
      expanded = hal::reshape(ctx, stacked, {N, hh * ww, h * w, C});
    }

    // TODO(jint): the below method is much slower then the code above, consider
    // to use slice+reshape+concat to rewrite expandWindow.
    //
    // std::vector<std::pair<int64_t, int64_t>> padding(4, {0, 0});
    // auto expanded = expandWindow(ctx, input,      // input
    //                                    {N, h, w, C},    // window_shape
    //                                    {1, sh, sw, 1},  // strides
    //                                    padding);

    // Now expanded shape is (N, hh*ww, h*w, C)
    SPU_ENFORCE_EQ(expanded.shape()[0], N);
    SPU_ENFORCE_EQ(expanded.shape()[1], hh * ww);
    SPU_ENFORCE_EQ(expanded.shape()[2], h * w);
    SPU_ENFORCE_EQ(expanded.shape()[3], C);

    // Reshape it to (N, hh, ww, h, w, C)
    expanded = hal::reshape(ctx, expanded, {N, hh, ww, h, w, C});

    // Contract on h, w, C
    // expanded:  (N, hh, ww, h, w, C)
    // kernel:               (h, w, C, O)
    // result:    (N, hh, ww,          O)
    auto result = hal::tensordot(ctx, expanded, kernel, {3, 4, 5}, {0, 1, 2});
    SPU_ENFORCE_EQ(result.shape()[0], N);
    SPU_ENFORCE_EQ(result.shape()[1], hh);
    SPU_ENFORCE_EQ(result.shape()[2], ww);
    SPU_ENFORCE_EQ(result.shape()[3], O);

    return result;
  }
}

}  // namespace spu::kernel::hlo

Please note that I only convert 3x3 convs to the Winograd version; the others are retained as regular convs. The input/output tile sizes are fixed to 4/2, respectively.
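As a rough sanity check on the expected savings (standard Winograd arithmetic, independent of SPU): with $m = 2$ and $r = 3$, each output tile costs

$$(m + r - 1)^{2} = 16 \quad \text{vs.} \quad m^{2} r^{2} = 36$$

multiplications in the Winograd domain versus direct convolution, i.e. about a 2.25x reduction per tile before accounting for the transforms themselves.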

This implementation is a one-to-one port of the torch version, but it did not work. I have some issues:

  • Winograd conv with GEMM requires the permute function (tensor.permute in torch) to transform the Winograd input and kernel to GEMM format; I don't know if my use of the permute function is correct.

SPU is a grand and complex project, so it is quite difficult for me to fully understand it. If possible, could you please describe the required modifications to the above code in as much detail as possible. Thank you very much.

anakinxc commented 4 months ago

Hi @warpoons

Nice work! Thanks

for A B G, you need to build a proper Value with constant api.

tensordot is a more general version of dot, which supports > 2D inputs. If you only need 2D dot, matmul is good enough.

tpppppub commented 4 months ago

Hi @warpoons , SPU transpose does the same thing as torch.permute.

warpoons commented 4 months ago

Hi @anakinxc. Thanks for your respones!

for A B G, you need to build a proper Value with constant api.

Could you please kindly provide a specific example of how to define one of the Winograd transformation matrices A, B, G with the constant API? For instance, how to define:

G = { {1.0, 0.0, 0.0},
       {0.5, 0.5, 0.5},
       {0.5, -0.5, 0.5},
       {0.0, 0.0, 1.0} };

Thanks!

tensordot is a more general version of dot, which supports > 2D inputs. If you only need 2D dot, matmul is good enough.

Does this mean that tensordot supports both regular matmul and higher-dimensional multiplications? Can I use tensordot everywhere instead of matmul to avoid potential incompatibilities?
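A minimal sketch of the equivalence in question (my own illustration, assuming A and B are 2-D spu::Values of shapes (M, K) and (K, N); the headers and call forms are guesses based on the convolution.cc code above):

#include "libspu/core/value.h"
#include "libspu/kernel/hal/polymorphic.h"  // hal::matmul / hal::tensordot, per convolution.cc's includes

namespace spu::kernel::hlo {

// Sketch only: A is (M, K), B is (K, N); both are spu::Value.
spu::Value dot_two_ways(SPUContext *ctx, const spu::Value &A, const spu::Value &B) {
  auto C1 = hal::matmul(ctx, A, B);               // plain 2-D product, rank <= 2 only
  auto C2 = hal::tensordot(ctx, A, B, {1}, {0});  // same contraction (axis 1 of A vs axis 0 of B),
                                                  // but also generalizes to rank > 2 operands
  return C2;  // C1 and C2 should both be (M, N)
}

}  // namespace spu::kernel::hlo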


SPU transpose does the same thing as torch.permute.

Hi @tpppppub. Thanks for your response!

Would you mind taking a look at whether the following usage of transpose is correct? I searched the entire spu library but did not find an instructive example of using transpose. Thanks!

 // Transform Winograd input and kernel to General Matmul format
    U = hal::transpose(ctx, U, {0, 1, 3, 2}); // U (a, a, C, O) -> (a, a, O, C)
    V = hal::transpose(ctx, V, {1, 2, 3, 0}); // V (P, a, a, C) -> (a, a, C, P)
warpoons commented 4 months ago

Hi @anakinxc @tpppppub. Would you mind taking a look at my questions? Thanks! :)

tpppppub commented 4 months ago

Your understanding of transpose is correct. In addition, the following tips may help:

  1. Most SPU APIs (hlo/hal) are consistent with or similar to stablehlo. You can read the doc for reference.
  2. There are comments on the API declarations that explain them.
  3. You can also read the tests for reference.
  4. Run your code to verify your guesses.
warpoons commented 4 months ago

Hi @tpppppub. Thanks for your quick and detailed response! I have modified convolution.cc as follows:

#include "libspu/kernel/hlo/convolution.h"
#include "libspu/core/value.h"
#include "libspu/kernel/hal/polymorphic.h"
#include "libspu/kernel/hal/shape_ops.h"
#include "libspu/kernel/hal/constants.h"  // Newly added for building winograd transformation matrices

namespace spu::kernel::hlo {

// This is an optimized conv2D with im2col
spu::Value Convolution2D(SPUContext *ctx, const spu::Value &input,
                         const spu::Value &kernel,
                         const ConvolutionConfig &config,
                         const Shape &result_shape) {
  SPU_ENFORCE(!input.isComplex() && !kernel.isComplex());
  // input  : (N, H, W, C)
  // kernel : (h, w, C, O)
  // output : (N, hh,ww,O), where hh=(H-h)/sh+1, ww=(W-w)/sw+1

  // Alias input dimensions.
  auto N = input.shape()[0];
  auto H = input.shape()[1];
  auto W = input.shape()[2];
  auto C = input.shape()[3];

  auto h = kernel.shape()[0];
  auto w = kernel.shape()[1];
  SPU_ENFORCE_EQ(kernel.shape()[2], C, "input/kernel channel mismatch");
  auto O = kernel.shape()[3];

  SPU_ENFORCE_EQ(result_shape[0], N, "result batch mismatch");
  auto hh = result_shape[1];
  auto ww = result_shape[2];
  SPU_ENFORCE_EQ(result_shape[3], O, "result filters mismatch");

  SPU_ENFORCE_EQ(config.window_strides.size(), 2U);
  int64_t sh = config.window_strides[0];
  int64_t sw = config.window_strides[1];

  SPU_ENFORCE_EQ(hh, (H - h) / sh + 1);
  SPU_ENFORCE_EQ(ww, (W - w) / sw + 1);

  // For 3x3 kernel size, use Winograd convolution
  if (h == 3 && w == 3) {
    // Parameters
    auto r = h;  // kernel size, which is fixed to 3 here
    auto m = 2; // output tile size, which is fixed to 2 here
    auto a = m + r -1; // input tile size, which is fixed to 4 here
    auto T = hal::ceil(ctx, (H - r + 1) / m); // number of tiles along height/weight
    auto P = N * T * T; // total number of tiles
    auto tile_stride = a - r + 1; // stride for extracting input tiles

    // Define Winograd transformation matrices
    std::vector<float> G_coefficients = { 1.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5, -0.5, 0.5, 0.0, 0.0, 1.0 };
    std::vector<float> G_T_coefficients = { 1.0, 0.5, 0.5, 0.0, 0.0, 0.5, -0.5, 0.0, 0.0, 0.5, 0.5, 1.0 };
    std::vector<float> B_coefficients = { 1, 0, 0, 0, 0, 1, -1, 1, -1, 1, 1, 0, 0, 0, 0, -1 };
    std::vector<float> B_T_coefficients = { 1, 0, -1, 0, 0, 1, 1, 0, 0, -1, 1, 0, 0, 1, 0, -1 };
    std::vector<float> A_coefficients = { 1, 0, 1, 1, 1, -1, 0, -1};
    std::vector<float> A_T_coefficients = { 1, 1, 1, 0, 0, 1, -1, -1};

    auto G = spu::kernel::hal::constant(ctx, G_coefficients, input.dtype(), {4, 3});
    auto G_T = spu::kernel::hal::constant(ctx, G_T_coefficients, input.dtype(), {3, 4});
    auto B = spu::kernel::hal::constant(ctx, B_coefficients, input.dtype(), {4, 4});
    auto B_T = spu::kernel::hal::constant(ctx, B_T_coefficients, input.dtype(), {4, 4});
    auto A = spu::kernel::hal::constant(ctx, A_coefficients, input.dtype(), {4, 2});
    auto A_T = spu::kernel::hal::constant(ctx, A_T_coefficients, input.dtype(), {2, 4});

    // Transform kernel to Winograd domain
    auto U = hal::matmul(ctx, G, hal::matmul(ctx, kernel, G_T));
    // Transform input to Winograd domain
    Value expanded;
    {
      std::vector<spu::Value> tiled_input;
      // Separate the input into axa tiles
      for (int64_t x = 0; x <= H - a; x += tile_stride) {
        for (int64_t y = 0; y <= W - a; y += tile_stride) {
          auto tile = hal::slice(ctx, input, {0, x, y, 0}, {N, x + a, y + a, C}, {});
          tiled_input.emplace_back(hal::reshape(ctx, tile, {1, N, a, a, C}));
        }
      }
      auto stacked = hal::concatenate(ctx, tiled_input, 0);
      expanded = hal::reshape(ctx, stacked, {P, a, a, C});
    }
    auto V = hal::matmul(ctx, B_T, hal::matmul(ctx, expanded, B));

    // Transform Winograd input and kernel to General Matmul format
    U = hal::transpose(ctx, U, {0, 1, 3, 2}); // U (a, a, C, O) -> (a, a, O, C)
    V = hal::transpose(ctx, V, {1, 2, 3, 0}); // V (P, a, a, C) -> (a, a, C, P)

    // Do General Matmul
    auto M = hal::matmul(ctx, U, V); // M (a, a, O, P)
    M = hal::transpose(ctx, M, {3, 0, 1, 2}); // M (a, a, O, P) -> (P, a, a, O)

    // Transform Winograd output to regular output
    auto output_tile = hal::matmul(ctx, A_T, hal::matmul(ctx, M, A)); // output_tile (P, 2, 2, O)
    auto out_size = H - r + 1;
    auto result = hal::reshape(ctx, hal::transpose(ctx, hal::reshape(ctx, output_tile, {N, T, T, m, m, O}), {0, 1, 3, 2, 4, 5}), {N, out_size, out_size, O});

    return result;

  } else {
    // Fallback, use im2col + dot to implement convolution
    // expand the image according to the kernel size.
    // assumption:
    // - padding is erased by some compiler pass.
    // - input  : NxHxWxC
    // - kernel : hxwxCxO
    Value expanded;
    {
      std::vector<spu::Value> images;
      for (int64_t x = 0; x <= H - h; x += sh) {
        for (int64_t y = 0; y <= W - w; y += sw) {
          auto window =
              hal::slice(ctx, input, {0, x, y, 0}, {N, x + h, y + w, C}, {});
          images.emplace_back(hal::reshape(ctx, window, {N, 1, h, w, C}));
        }
      }
      auto stacked = hal::concatenate(ctx, images, 1);
      SPU_ENFORCE_EQ(stacked.shape()[1], hh * ww);
      expanded = hal::reshape(ctx, stacked, {N, hh * ww, h * w, C});
    }

    // TODO(jint): the below method is much slower then the code above, consider
    // to use slice+reshape+concat to rewrite expandWindow.
    //
    // std::vector<std::pair<int64_t, int64_t>> padding(4, {0, 0});
    // auto expanded = expandWindow(ctx, input,      // input
    //                                    {N, h, w, C},    // window_shape
    //                                    {1, sh, sw, 1},  // strides
    //                                    padding);

    // Now expanded shape is (N, hh*ww, h*w, C)
    SPU_ENFORCE_EQ(expanded.shape()[0], N);
    SPU_ENFORCE_EQ(expanded.shape()[1], hh * ww);
    SPU_ENFORCE_EQ(expanded.shape()[2], h * w);
    SPU_ENFORCE_EQ(expanded.shape()[3], C);

    // Reshape it to (N, hh, ww, h, w, C)
    expanded = hal::reshape(ctx, expanded, {N, hh, ww, h, w, C});

    // Contract on h, w, C
    // expanded:  (N, hh, ww, h, w, C)
    // kernel:               (h, w, C, O)
    // result:    (N, hh, ww,          O)
    auto result = hal::tensordot(ctx, expanded, kernel, {3, 4, 5}, {0, 1, 2});
    SPU_ENFORCE_EQ(result.shape()[0], N);
    SPU_ENFORCE_EQ(result.shape()[1], hh);
    SPU_ENFORCE_EQ(result.shape()[2], ww);
    SPU_ENFORCE_EQ(result.shape()[3], O);

    return result;
  }
}

}  // namespace spu::kernel::hlo

After modifying the convolution.cc file, I ran the command bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/flax_resnet18/2pc.json up but unexpectedly got errors like:

ERROR: /home/warpoons/Desktop/spu-0.9.1b0/libspu/kernel/hlo/BUILD.bazel:152:15: Compiling libspu/kernel/hlo/convolution.cc failed: (Exit 1): gcc failed: error executing command (from target //libspu/kernel/hlo:convolution) /usr/bin/gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections ... (remaining 179 arguments skipped)

Use --sandbox_debug to see verbose messages from the sandbox and retain the sandbox build root for debugging
libspu/kernel/hlo/convolution.cc: In function 'spu::Value spu::kernel::hlo::Convolution2D(spu::SPUContext*, const spu::Value&, const spu::Value&, const spu::kernel::hlo::ConvolutionConfig&, const spu::Shape&)':
libspu/kernel/hlo/convolution.cc:62:41: error: invalid initialization of reference of type 'const spu::Value&' from expression of type 'long int'
   62 |     auto T = hal::ceil(ctx, (H - r + 1) / m); // number of tiles along height/weight
      |                             ~~~~~~~~~~~~^~~
In file included from libspu/kernel/hlo/convolution.cc:17:
./libspu/kernel/hal/polymorphic.h:88:42: note: in passing argument 2 of 'spu::Value spu::kernel::hal::ceil(spu::SPUContext*, const spu::Value&)'
   88 | Value ceil(SPUContext* ctx, const Value& in);
      |                             ~~~~~~~~~~~~~^~
libspu/kernel/hlo/convolution.cc:124:30: error: invalid initialization of reference of type 'const spu::Shape&' from expression of type '<brace-enclosed initializer list>'
  124 |       expanded = hal::reshape(ctx, stacked, {P, a, a, C});
      |                  ~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~
In file included from libspu/kernel/hlo/convolution.cc:18:
./libspu/kernel/hal/shape_ops.h:36:62: note: in passing argument 3 of 'spu::Value spu::kernel::hal::reshape(spu::SPUContext*, const spu::Value&, const spu::Shape&)'
   36 | Value reshape(SPUContext* ctx, const Value& in, const Shape& to_shape);
      |                                                 ~~~~~~~~~~~~~^~~~~~~~
libspu/kernel/hlo/convolution.cc:139:69: error: invalid initialization of reference of type 'const spu::Shape&' from expression of type '<brace-enclosed initializer list>'
  139 |     auto result = hal::reshape(ctx, hal::transpose(ctx, hal::reshape(ctx, output_tile, {N, T, T, m, m, O}), {0, 1, 3, 2, 4, 5}), {N, out_size, out_size, O});
      |                                                         ~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In file included from libspu/kernel/hlo/convolution.cc:18:
./libspu/kernel/hal/shape_ops.h:36:62: note: in passing argument 3 of 'spu::Value spu::kernel::hal::reshape(spu::SPUContext*, const spu::Value&, const spu::Shape&)'
   36 | Value reshape(SPUContext* ctx, const Value& in, const Shape& to_shape);
      |                                                 ~~~~~~~~~~~~~^~~~~~~~
Target //examples/python/utils:nodectl failed to build
Use --verbose_failures to see the command lines of failed build steps.
INFO: Elapsed time: 116.662s, Critical Path: 47.96s
INFO: 418 processes: 15 internal, 403 linux-sandbox.
FAILED: Build did NOT complete successfully
ERROR: Build failed. Not running target

Would you mind taking a look at what went wrong there and suggesting feasible solutions? Thanks! :)

tpppppub commented 4 months ago

libspu/kernel/hlo/convolution.cc:62:41: error: invalid initialization of reference of type 'const spu::Value&' from expression of type 'long int'

hal::ceil accepts the type spu::Value as argument, not builtin scalar types.
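A minimal sketch of one way around this (my own suggestion, not necessarily the intended fix): since H, r and m are plain host-side int64_t values inside Convolution2D, the tile count can be computed with ordinary integer arithmetic instead of going through hal::ceil at all:

#include <cstdint>

// Hypothetical helper, not part of libspu: ceiling division for positive integers,
// usable when the operands are plain host scalars rather than spu::Value.
inline int64_t ceil_div(int64_t x, int64_t y) { return (x + y - 1) / y; }

// e.g. in Convolution2D:
//   auto T = ceil_div(H - r + 1, m);  // number of tiles along height/width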

warpoons commented 4 months ago

libspu/kernel/hlo/convolution.cc:62:41: error: invalid initialization of reference of type 'const spu::Value&' from expression of type 'long int'

hal::ceil accepts the type spu::Value as argument, not builtin scalar types.

Hi @tpppppub. Thanks for pointing that out! In the modified version I used auto T = input.shape()[1] / 2 - 1; without ceil, and the SPU backend runtime can now be launched successfully. But when I tried to test ResNet18 inference with this new convolution.cc, the following errors happened:

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu. Traceback (most recent call last): File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/flax_resnet18/flax_resnet_inference.runfiles/spulib/examples/python/ml/flax_resnet18/flax_resnet_inference.py", line 142, in run_on_spu(inputs, params) File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/flax_resnet18/flax_resnet_inference.runfiles/spulib/examples/python/ml/flax_resnet18/flax_resnet_inference.py", line 118, in run_on_spu out = infer(params, inputs) File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/flax_resnet18/flax_resnet_inference.runfiles/spulib/spu/utils/distributed_impl.py", line 693, in call results = [future.result() for future in futures] File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/flax_resnet18/flax_resnet_inference.runfiles/spulib/spu/utils/distributed_impl.py", line 693, in results = [future.result() for future in futures] File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/concurrent/futures/_base.py", line 451, in result return self.__get_result() File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result raise self._exception File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/concurrent/futures/thread.py", line 58, in run result = self.fn(*self.args, self.kwargs) File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/flax_resnet18/flax_resnet_inference.runfiles/spulib/spu/utils/distributed_impl.py", line 247, in run return self._call(self._stub.Run, fn, *args, *kwargs) File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/flax_resnet18/flax_resnet_inference.runfiles/spulib/spu/utils/distributed_impl.py", line 240, in _call raise Exception("remote exception", result) Exception: ('remote exception', Exception('Traceback (most recent call last):\n File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/utils/distributed_impl.py", line 326, in Run\n ret_objs = fn(self, args, kwargs)\n File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/utils/distributed_impl.py", line 589, in builtin_spu_run\n rt.run(spu_exec)\n File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/api.py", line 44, in run\n return self._vm.Run(executable.SerializeToString())\nRuntimeError: what: \n\t[Enforce fail at libspu/kernel/hal/ring.cc:230] (lhs.ndim() > 0 && lhs.ndim() <= 2). 
\nStacktrace:\n#0 spu::kernel::hal::f_mmul()+0x735cd124b610\n#1 spu::kernel::hal::(anonymous namespace)::dtypeBinaryDispatch<>()+0x735cd122d2e2\n#2 spu::kernel::hal::matmul()+0x735cd122df4d\n#3 spu::kernel::hlo::Convolution2D()+0x735cd11d11ee\n#4 spu::device::pphlo::dispatchOp<>()+0x735cd0ae146e\n#5 spu::device::pphlo::dispatchOp<>()+0x735cd0ae24b3\n#6 spu::device::pphlo::dispatchOp<>()+0x735cd0ae424a\n#7 spu::device::pphlo::dispatchOp<>()+0x735cd0ae65f4\n#8 spu::device::runBlock()+0x735cd0c66a3d\n#9 spu::device::runRegion()+0x735cd0c68ae3\n#10 spu::device::executeImpl()+0x735cd059f0e7\n#11 spu::RuntimeWrapper::Run()+0x735ccf92f4cc\n#12 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()+0x735ccf90121d\n#13 pybind11::cpp_function::dispatcher()+0x735ccf8f6006\n#14 cfunction_call+0x4fdc87\n\nstacktrace: \n#0 spu::kernel::hal::f_mmul()+0x735cd124b610\n#1 spu::kernel::hal::(anonymous namespace)::dtypeBinaryDispatch<>()+0x735cd122d2e2\n#2 spu::kernel::hal::matmul()+0x735cd122df4d\n#3 spu::kernel::hlo::Convolution2D()+0x735cd11d11ee\n#4 spu::device::pphlo::dispatchOp<>()+0x735cd0ae146e\n#5 spu::device::pphlo::dispatchOp<>()+0x735cd0ae24b3\n#6 spu::device::pphlo::dispatchOp<>()+0x735cd0ae424a\n#7 spu::device::pphlo::dispatchOp<>()+0x735cd0ae65f4\n#8 spu::device::runBlock()+0x735cd0c66a3d\n#9 spu::device::runRegion()+0x735cd0c68ae3\n#10 spu::device::executeImpl()+0x735cd059f0e7\n#11 spu::RuntimeWrapper::Run()+0x735ccf92f4cc\n#12 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()+0x735ccf90121d\n#13 pybind11::cpp_function::dispatcher()+0x735ccf8f6006\n#14 cfunction_call+0x4fdc87\n\n\n'))


- And the backend runtime reported something like:

[2024-07-16 17:19:27,187] [ForkServerProcess-1] Traceback (most recent call last):
  File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/utils/distributed_impl.py", line 326, in Run
    ret_objs = fn(self, *args, **kwargs)
  File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/utils/distributed_impl.py", line 589, in builtin_spu_run
    rt.run(spu_exec)
  File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/api.py", line 44, in run
    return self._vm.Run(executable.SerializeToString())
RuntimeError: what: 
        [Enforce fail at libspu/kernel/hal/ring.cc:230] (lhs.ndim() > 0 && lhs.ndim() <= 2). 
Stacktrace:
#0 spu::kernel::hal::f_mmul()+0x735cd124b610
#1 spu::kernel::hal::(anonymous namespace)::dtypeBinaryDispatch<>()+0x735cd122d2e2
#2 spu::kernel::hal::matmul()+0x735cd122df4d
#3 spu::kernel::hlo::Convolution2D()+0x735cd11d11ee
#4 spu::device::pphlo::dispatchOp<>()+0x735cd0ae146e
#5 spu::device::pphlo::dispatchOp<>()+0x735cd0ae24b3
#6 spu::device::pphlo::dispatchOp<>()+0x735cd0ae424a
#7 spu::device::pphlo::dispatchOp<>()+0x735cd0ae65f4
#8 spu::device::runBlock()+0x735cd0c66a3d
#9 spu::device::runRegion()+0x735cd0c68ae3
#10 spu::device::executeImpl()+0x735cd059f0e7
#11 spu::RuntimeWrapper::Run()+0x735ccf92f4cc
#12 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()+0x735ccf90121d
#13 pybind11::cpp_function::dispatcher()+0x735ccf8f6006
#14 cfunction_call+0x4fdc87


Would you mind taking a look at what went wrong here and suggesting feasible solutions? Sorry for taking up your time. Thanks! :)
anakinxc commented 4 months ago

libspu/kernel/hlo/convolution.cc:62:41: error: invalid initialization of reference of type 'const spu::Value&' from expression of type 'long int' hal::ceil accepts the type spu::Value as argument, not builtin scalar types.

Hi @tpppppub. Thanks for pointing that out! In the modified version I used auto T = input.shape()[1] / 2 - 1; without ceil, and the SPU backend runtime can now be launched successfully. But when I tried to test ResNet18 inference with this new convolution.cc, the following errors happened:

  • By running the flax_resnet_inference, I got:
INFO: Running command line: bazel-bin/examples/python/ml/flax_resnet18/flax_resnet_inference --config /home/warpoons/Desktop/spu-0.9.1b0/examples/python/ml/flax_resnet18/2pc.json
Run on SPU
------

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Traceback (most recent call last):
  File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/flax_resnet18/flax_resnet_inference.runfiles/spulib/examples/python/ml/flax_resnet18/flax_resnet_inference.py", line 142, in <module>
    run_on_spu(inputs, params)
  File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/flax_resnet18/flax_resnet_inference.runfiles/spulib/examples/python/ml/flax_resnet18/flax_resnet_inference.py", line 118, in run_on_spu
    out = infer(params, inputs)
  File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/flax_resnet18/flax_resnet_inference.runfiles/spulib/spu/utils/distributed_impl.py", line 693, in __call__
    results = [future.result() for future in futures]
  File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/flax_resnet18/flax_resnet_inference.runfiles/spulib/spu/utils/distributed_impl.py", line 693, in <listcomp>
    results = [future.result() for future in futures]
  File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/concurrent/futures/_base.py", line 451, in result
    return self.__get_result()
  File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/flax_resnet18/flax_resnet_inference.runfiles/spulib/spu/utils/distributed_impl.py", line 247, in run
    return self._call(self._stub.Run, fn, *args, **kwargs)
  File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/flax_resnet18/flax_resnet_inference.runfiles/spulib/spu/utils/distributed_impl.py", line 240, in _call
    raise Exception("remote exception", result)
Exception: ('remote exception', Exception('Traceback (most recent call last):\n  File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/utils/distributed_impl.py", line 326, in Run\n    ret_objs = fn(self, *args, **kwargs)\n  File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/utils/distributed_impl.py", line 589, in builtin_spu_run\n    rt.run(spu_exec)\n  File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/api.py", line 44, in run\n    return self._vm.Run(executable.SerializeToString())\nRuntimeError: what: \n\t[Enforce fail at libspu/kernel/hal/ring.cc:230] (lhs.ndim() > 0 && lhs.ndim() <= 2). \nStacktrace:\n#0 spu::kernel::hal::f_mmul()+0x735cd124b610\n#1 spu::kernel::hal::(anonymous namespace)::dtypeBinaryDispatch<>()+0x735cd122d2e2\n#2 spu::kernel::hal::matmul()+0x735cd122df4d\n#3 spu::kernel::hlo::Convolution2D()+0x735cd11d11ee\n#4 spu::device::pphlo::dispatchOp<>()+0x735cd0ae146e\n#5 spu::device::pphlo::dispatchOp<>()+0x735cd0ae24b3\n#6 spu::device::pphlo::dispatchOp<>()+0x735cd0ae424a\n#7 spu::device::pphlo::dispatchOp<>()+0x735cd0ae65f4\n#8 spu::device::runBlock()+0x735cd0c66a3d\n#9 spu::device::runRegion()+0x735cd0c68ae3\n#10 spu::device::executeImpl()+0x735cd059f0e7\n#11 spu::RuntimeWrapper::Run()+0x735ccf92f4cc\n#12 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()+0x735ccf90121d\n#13 pybind11::cpp_function::dispatcher()+0x735ccf8f6006\n#14 cfunction_call+0x4fdc87\n\nstacktrace: \n#0 spu::kernel::hal::f_mmul()+0x735cd124b610\n#1 spu::kernel::hal::(anonymous namespace)::dtypeBinaryDispatch<>()+0x735cd122d2e2\n#2 spu::kernel::hal::matmul()+0x735cd122df4d\n#3 spu::kernel::hlo::Convolution2D()+0x735cd11d11ee\n#4 spu::device::pphlo::dispatchOp<>()+0x735cd0ae146e\n#5 spu::device::pphlo::dispatchOp<>()+0x735cd0ae24b3\n#6 spu::device::pphlo::dispatchOp<>()+0x735cd0ae424a\n#7 spu::device::pphlo::dispatchOp<>()+0x735cd0ae65f4\n#8 spu::device::runBlock()+0x735cd0c66a3d\n#9 spu::device::runRegion()+0x735cd0c68ae3\n#10 spu::device::executeImpl()+0x735cd059f0e7\n#11 spu::RuntimeWrapper::Run()+0x735ccf92f4cc\n#12 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()+0x735ccf90121d\n#13 pybind11::cpp_function::dispatcher()+0x735ccf8f6006\n#14 cfunction_call+0x4fdc87\n\n\n'))
  • And the backend runtime reported something like:
[2024-07-16 17:19:27,187] [ForkServerProcess-1] Traceback (most recent call last):
  File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/utils/distributed_impl.py", line 326, in Run
    ret_objs = fn(self, *args, **kwargs)
  File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/utils/distributed_impl.py", line 589, in builtin_spu_run
    rt.run(spu_exec)
  File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/api.py", line 44, in run
    return self._vm.Run(executable.SerializeToString())
RuntimeError: what: 
        [Enforce fail at libspu/kernel/hal/ring.cc:230] (lhs.ndim() > 0 && lhs.ndim() <= 2). 
Stacktrace:
#0 spu::kernel::hal::f_mmul()+0x735cd124b610
#1 spu::kernel::hal::(anonymous namespace)::dtypeBinaryDispatch<>()+0x735cd122d2e2
#2 spu::kernel::hal::matmul()+0x735cd122df4d
#3 spu::kernel::hlo::Convolution2D()+0x735cd11d11ee
#4 spu::device::pphlo::dispatchOp<>()+0x735cd0ae146e
#5 spu::device::pphlo::dispatchOp<>()+0x735cd0ae24b3
#6 spu::device::pphlo::dispatchOp<>()+0x735cd0ae424a
#7 spu::device::pphlo::dispatchOp<>()+0x735cd0ae65f4
#8 spu::device::runBlock()+0x735cd0c66a3d
#9 spu::device::runRegion()+0x735cd0c68ae3
#10 spu::device::executeImpl()+0x735cd059f0e7
#11 spu::RuntimeWrapper::Run()+0x735ccf92f4cc
#12 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()+0x735ccf90121d
#13 pybind11::cpp_function::dispatcher()+0x735ccf8f6006
#14 cfunction_call+0x4fdc87

Would you mind taking a look at what went wrong here and suggesting feasible solutions? Sorry for taking up your time. Thanks! :)

matmul only supports up to 2D dot. It seems you are doing a 4D dot.

warpoons commented 4 months ago

Hi @anakinxc. Thanks for the reminder! I realized that tensordot should be used here instead of matmul. After several necessary modifications, the code now executes successfully without any errors.

#include "libspu/kernel/hlo/convolution.h"
#include "libspu/core/value.h"
#include "libspu/kernel/hal/polymorphic.h"
#include "libspu/kernel/hal/shape_ops.h"
#include "libspu/kernel/hal/constants.h"  // Newly added for building winograd transformation matrices

namespace spu::kernel::hlo {

// This is an optimized conv2D with im2col
spu::Value Convolution2D(SPUContext *ctx, const spu::Value &input,
                         const spu::Value &kernel,
                         const ConvolutionConfig &config,
                         const Shape &result_shape) {
  SPU_ENFORCE(!input.isComplex() && !kernel.isComplex());
  // input  : (N, H, W, C)
  // kernel : (h, w, C, O)
  // output : (N, hh,ww,O), where hh=(H-h)/sh+1, ww=(W-w)/sw+1

  // Alias input dimensions.
  auto N = input.shape()[0];
  auto H = input.shape()[1];
  auto W = input.shape()[2];
  auto C = input.shape()[3];

  auto h = kernel.shape()[0];
  auto w = kernel.shape()[1];
  SPU_ENFORCE_EQ(kernel.shape()[2], C, "input/kernel channel mismatch");
  auto O = kernel.shape()[3];

  SPU_ENFORCE_EQ(result_shape[0], N, "result batch mismatch");
  auto hh = result_shape[1];
  auto ww = result_shape[2];
  SPU_ENFORCE_EQ(result_shape[3], O, "result filters mismatch");

  SPU_ENFORCE_EQ(config.window_strides.size(), 2U);
  int64_t sh = config.window_strides[0];
  int64_t sw = config.window_strides[1];

  SPU_ENFORCE_EQ(hh, (H - h) / sh + 1);
  SPU_ENFORCE_EQ(ww, (W - w) / sw + 1);

  // For 3x3 kernel size with 1 stride, use Winograd convolution
  if (h == 3 && w == 3 && config.window_strides[0] == 1 && config.window_strides[1] == 1) {

    // Parameters
    auto r = kernel.shape()[0];  // kernel size, which is fixed to 3 here
    auto m = 2; // output tile size, which is fixed to 2 here
    auto a = m + r -1; // input tile size, which is fixed to 4 here
    auto T = input.shape()[1] / 2 - 1;
    auto P = N * T * T; // total number of tiles
    auto tile_stride = a - r + 1; // stride for extracting input tiles

    // Define Winograd transformation matrices
    std::vector<float> G_coefficients = { 1.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5, -0.5, 0.5, 0.0, 0.0, 1.0 };
    std::vector<float> G_T_coefficients = { 1.0, 0.5, 0.5, 0.0, 0.0, 0.5, -0.5, 0.0, 0.0, 0.5, 0.5, 1.0 };
    std::vector<float> B_coefficients = { 1, 0, 0, 0, 0, 1, -1, 1, -1, 1, 1, 0, 0, 0, 0, -1 };
    std::vector<float> B_T_coefficients = { 1, 0, -1, 0, 0, 1, 1, 0, 0, -1, 1, 0, 0, 1, 0, -1 };
    std::vector<float> A_coefficients = { 1, 0, 1, 1, 1, -1, 0, -1};
    std::vector<float> A_T_coefficients = { 1, 1, 1, 0, 0, 1, -1, -1};

    auto G = spu::kernel::hal::constant(ctx, G_coefficients, input.dtype(), {4, 3});
    auto G_T = spu::kernel::hal::constant(ctx, G_T_coefficients, input.dtype(), {3, 4});
    auto B = spu::kernel::hal::constant(ctx, B_coefficients, input.dtype(), {4, 4});
    auto B_T = spu::kernel::hal::constant(ctx, B_T_coefficients, input.dtype(), {4, 4});
    auto A = spu::kernel::hal::constant(ctx, A_coefficients, input.dtype(), {4, 2});
    auto A_T = spu::kernel::hal::constant(ctx, A_T_coefficients, input.dtype(), {2, 4});

    // Transform kernel to Winograd domain (matmul replaced with tensordot)
    auto U = hal::tensordot(ctx, G, hal::tensordot(ctx, kernel, G_T, {1}, {0}), {1}, {0});
    U = hal::transpose(ctx, U, {0, 3, 1, 2});

    // Transform input to Winograd domain
    Value expanded;
    {
      std::vector<spu::Value> tiled_input;
      // Separate the input into axa tiles
      for (int64_t x = 0; x <= H - a; x += tile_stride) {
        for (int64_t y = 0; y <= W - a; y += tile_stride) {
          auto tile = hal::slice(ctx, input, {0, x, y, 0}, {N, x + a, y + a, C}, {});
          tiled_input.emplace_back(hal::reshape(ctx, tile, {1, N, a, a, C}));
        }
      }
      auto stacked = hal::concatenate(ctx, tiled_input, 0);
      expanded = hal::reshape(ctx, stacked, {P, a, a, C});

    }
    auto V = hal::tensordot(ctx, B_T, hal::tensordot(ctx, expanded, B, {2}, {0}), {1}, {1});
    V = hal::transpose(ctx, V, {1, 0, 3, 2});

    // Transform Winograd input and kernel to General Matmul format
    U = hal::transpose(ctx, U, {0, 1, 3, 2}); // U (a, a, C, O) -> (a, a, O, C)
    V = hal::transpose(ctx, V, {1, 2, 3, 0}); // V (P, a, a, C) -> (a, a, C, P)

    // Do General Matmul
    auto M = hal::tensordot(ctx, U, V, {1, 3}, {0, 2});
    M = hal::transpose(ctx, M, {3, 0, 2, 1}); // M -> (P, a, a, O)

    // Transform Winograd output to regular output
    auto output_tile = hal::tensordot(ctx, A_T, hal::tensordot(ctx, M, A, {2}, {0}), {1}, {1}); // output_tile (P, 2, 2, O)
    output_tile = hal::transpose(ctx, output_tile, {2, 1, 0, 3});
    auto out_size = H - r + 1;
    auto result = hal::reshape(ctx, hal::transpose(ctx, hal::reshape(ctx, output_tile, {N, T, T, m, m, O}), {0, 1, 3, 2, 4, 5}), {N, out_size, out_size, O});

    return result;

  } else {
    // Fallback, use im2col + dot to implement convolution
    // expand the image according to the kernel size.
    // assumption:
    // - padding is erased by some compiler pass.
    // - input  : NxHxWxC
    // - kernel : hxwxCxO
    Value expanded;
    {
      std::vector<spu::Value> images;
      for (int64_t x = 0; x <= H - h; x += sh) {
        for (int64_t y = 0; y <= W - w; y += sw) {
          auto window =
              hal::slice(ctx, input, {0, x, y, 0}, {N, x + h, y + w, C}, {});
          images.emplace_back(hal::reshape(ctx, window, {N, 1, h, w, C}));
        }
      }
      auto stacked = hal::concatenate(ctx, images, 1);
      SPU_ENFORCE_EQ(stacked.shape()[1], hh * ww);
      expanded = hal::reshape(ctx, stacked, {N, hh * ww, h * w, C});
    }

    // TODO(jint): the below method is much slower then the code above, consider
    // to use slice+reshape+concat to rewrite expandWindow.
    //
    // std::vector<std::pair<int64_t, int64_t>> padding(4, {0, 0});
    // auto expanded = expandWindow(ctx, input,      // input
    //                                    {N, h, w, C},    // window_shape
    //                                    {1, sh, sw, 1},  // strides
    //                                    padding);

    // Now expanded shape is (N, hh*ww, h*w, C)
    SPU_ENFORCE_EQ(expanded.shape()[0], N);
    SPU_ENFORCE_EQ(expanded.shape()[1], hh * ww);
    SPU_ENFORCE_EQ(expanded.shape()[2], h * w);
    SPU_ENFORCE_EQ(expanded.shape()[3], C);

    // Reshape it to (N, hh, ww, h, w, C)
    expanded = hal::reshape(ctx, expanded, {N, hh, ww, h, w, C});

    // Contract on h, w, C
    // expanded:  (N, hh, ww, h, w, C)
    // kernel:               (h, w, C, O)
    // result:    (N, hh, ww,          O)
//    std::cout << "The shape of expanded is: " << expanded.shape() << std::endl;
//    std::cout << "The shape of kernel is: " << kernel.shape() << std::endl;
    auto result = hal::tensordot(ctx, expanded, kernel, {3, 4, 5}, {0, 1, 2});
    std::cout << "The shape of std conv result is: " << result.shape() << std::endl;
    SPU_ENFORCE_EQ(result.shape()[0], N);
    SPU_ENFORCE_EQ(result.shape()[1], hh);
    SPU_ENFORCE_EQ(result.shape()[2], ww);
    SPU_ENFORCE_EQ(result.shape()[3], O);

    return result;
  }
}

}  // namespace spu::kernel::hlo

But the profile of ResNet18 on CIFAR-100 was still weird:

[2024-07-18 17:18:03.364] [info] [api.cc:165] [Profiling] SPU execution infer completed, input processing took 1.6291e-05s, execution took 0.442374948s, output processing took 3.8402e-05s, total time 0.442429641s.
[2024-07-18 17:18:03.364] [info] [api.cc:211] HLO profiling: total time 0.43840665800000006
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.convolution, executed 20 times, duration 0.173768735s, send bytes 47269376 recv bytes 47269376
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.rsqrt, executed 20 times, duration 0.089064589s, send bytes 2174400 recv bytes 2174400
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.less, executed 37 times, duration 0.06104458s, send bytes 2906624 recv bytes 2906624
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.multiply, executed 222 times, duration 0.060172622s, send bytes 3673792 recv bytes 3673792
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.reduce_window, executed 1 times, duration 0.038431921s, send bytes 2392064 recv bytes 2392064
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.pad, executed 15 times, duration 0.005943561s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.reduce, executed 30 times, duration 0.003742128s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.negate, executed 40 times, duration 0.002317673s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.add, executed 129 times, duration 0.001958798s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.dot, executed 1 times, duration 0.001353351s, send bytes 414496 recv bytes 414496
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.free, executed 612 times, duration 0.000157907s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.broadcast, executed 45 times, duration 0.000149092s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.transpose, executed 21 times, duration 0.000109685s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - pphlo.reshape, executed 29 times, duration 7.4707e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - pphlo.constant, executed 30 times, duration 5.5505e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - pphlo.convert, executed 2 times, duration 3.7019e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - pphlo.reverse, executed 11 times, duration 2.4785e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:211] HAL profiling: total time 0.39880832400000005
[2024-07-18 17:18:03.365] [info] [api.cc:214] - f_tensordot, executed 20 times, duration 0.155717562s, send bytes 47269376 recv bytes 47269376
[2024-07-18 17:18:03.365] [info] [api.cc:214] - f_rsqrt, executed 20 times, duration 0.089025885s, send bytes 2174400 recv bytes 2174400
[2024-07-18 17:18:03.365] [info] [api.cc:214] - f_less, executed 41 times, duration 0.081797061s, send bytes 4741632 recv bytes 4741632
[2024-07-18 17:18:03.365] [info] [api.cc:214] - f_mul, executed 185 times, duration 0.046109298s, send bytes 2791424 recv bytes 2791424
[2024-07-18 17:18:03.365] [info] [api.cc:214] - mixed_mul, executed 37 times, duration 0.013188393s, send bytes 882368 recv bytes 882368
[2024-07-18 17:18:03.365] [info] [api.cc:214] - _mux, executed 4 times, duration 0.005991638s, send bytes 557056 recv bytes 557056
[2024-07-18 17:18:03.365] [info] [api.cc:214] - f_add, executed 283 times, duration 0.003328688s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - f_negate, executed 40 times, duration 0.002271631s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - f_mmul, executed 1 times, duration 0.00134875s, send bytes 414496 recv bytes 414496
[2024-07-18 17:18:03.365] [info] [api.cc:214] - seal, executed 2 times, duration 2.9418e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:211] MPC profiling: total time 0.4245964730000001
[2024-07-18 17:18:03.365] [info] [api.cc:214] - mmul_aa, executed 21 times, duration 0.12947814s, send bytes 47277568 recv bytes 47277568
[2024-07-18 17:18:03.365] [info] [api.cc:214] - msb_a2b, executed 41 times, duration 0.07886524s, send bytes 4741632 recv bytes 4741632
[2024-07-18 17:18:03.365] [info] [api.cc:214] - trunc_a, executed 326 times, duration 0.068162238s, send bytes 1693472 recv bytes 1693472
[2024-07-18 17:18:03.365] [info] [api.cc:214] - mul_aa, executed 236 times, duration 0.047058035s, send bytes 3550208 recv bytes 3550208
[2024-07-18 17:18:03.365] [info] [api.cc:214] - a2b, executed 20 times, duration 0.021607477s, send bytes 998400 recv bytes 998400
[2024-07-18 17:18:03.365] [info] [api.cc:214] - concatenate, executed 16 times, duration 0.018228241s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - b2a, executed 101 times, duration 0.018212783s, send bytes 147072 recv bytes 147072
[2024-07-18 17:18:03.365] [info] [api.cc:214] - and_bb, executed 120 times, duration 0.014937339s, send bytes 422400 recv bytes 422400
[2024-07-18 17:18:03.365] [info] [api.cc:214] - pad, executed 15 times, duration 0.005907683s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - add_aa, executed 335 times, duration 0.004702058s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - extract_slice, executed 4968 times, duration 0.004617872s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - not_a, executed 85 times, duration 0.003940147s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - reshape, executed 4936 times, duration 0.002173678s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - xor_bb, executed 580 times, duration 0.001375718s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - add_ap, executed 202 times, duration 0.00123122s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - bitrev_b, executed 40 times, duration 0.000840037s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - and_bp, executed 360 times, duration 0.000659155s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - rshift_b, executed 360 times, duration 0.000602192s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - mul_ap, executed 210 times, duration 0.000551214s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - transpose, executed 91 times, duration 0.000488421s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - make_p, executed 515 times, duration 0.000357092s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - lshift_b, executed 120 times, duration 0.000202838s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - broadcast, executed 103 times, duration 0.000112163s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - add_pp, executed 40 times, duration 8.4106e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - xor_bp, executed 20 times, duration 6.2188e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - mul_pp, executed 20 times, duration 6.1299e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - not_p, executed 20 times, duration 4.1154e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - p2a, executed 2 times, duration 2.6339e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - reverse, executed 11 times, duration 1.0406e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:224] Link details: total send bytes 58830752, recv bytes 58830752, send actions 1231

Additionally, I noted that the Winograd conv uses three hal::tensordot calls: two for the input/kernel transformations and one for the GEMM, whereas the regular conv uses hal::tensordot only once.
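A rough way to see why the communication does not drop (my own back-of-the-envelope estimate, under the assumption that every fixed-point matmul/tensordot output has to be truncated once): with $T$ tiles per spatial dimension, the regular conv truncates a single output tensor of roughly

$$N \cdot hh \cdot ww \cdot O \approx 4NT^{2}O$$

elements, while the Winograd path truncates after every stage: the input transform $B^{T} d B$ produces $P \cdot a^{2} \cdot C = 16NT^{2}C$ elements, the GEMM produces $P \cdot a^{2} \cdot O = 16NT^{2}O$ elements, and the output transform $A^{T} M A$ produces $P \cdot m^{2} \cdot O = 4NT^{2}O$ elements, because the 4x4 input tiles overlap while the 2x2 output tiles do not. So even though the per-tile multiplication count drops, roughly 4x more elements flow through truncation, which seems consistent with the communication numbers reported earlier in this thread.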

Remember that for a fixed model with known weights, the kernel transformation can be done offline. Therefore, I replaced the hal::tensordot in the kernel transformation with a constant all-ones matrix to simulate this situation. The resulting profile is:

[2024-07-18 17:29:36.238] [info] [api.cc:165] [Profiling] SPU execution infer completed, input processing took 1.7796e-05s, execution took 0.438780767s, output processing took 3.7401e-05s, total time 0.438835964s.
[2024-07-18 17:29:36.238] [info] [api.cc:211] HLO profiling: total time 0.4354281060000001
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.convolution, executed 20 times, duration 0.172220784s, send bytes 43451904 recv bytes 43451904
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.rsqrt, executed 20 times, duration 0.089114364s, send bytes 2174400 recv bytes 2174400
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.less, executed 37 times, duration 0.060312438s, send bytes 2906624 recv bytes 2906624
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.multiply, executed 222 times, duration 0.058106095s, send bytes 3673792 recv bytes 3673792
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.reduce_window, executed 1 times, duration 0.038393581s, send bytes 2392064 recv bytes 2392064
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.pad, executed 15 times, duration 0.007641689s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.reduce, executed 30 times, duration 0.003542044s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.negate, executed 40 times, duration 0.002279123s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.add, executed 129 times, duration 0.001889133s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.dot, executed 1 times, duration 0.001350285s, send bytes 414496 recv bytes 414496
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.free, executed 612 times, duration 0.000147391s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.broadcast, executed 45 times, duration 0.000138275s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.transpose, executed 21 times, duration 0.000107958s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.reshape, executed 29 times, duration 7.3047e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.constant, executed 30 times, duration 5.1171e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.convert, executed 2 times, duration 3.5468e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.reverse, executed 11 times, duration 2.526e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:211] HAL profiling: total time 0.39483937999999996
[2024-07-18 17:29:36.238] [info] [api.cc:214] - f_tensordot, executed 48 times, duration 0.154494692s, send bytes 43451904 recv bytes 43451904
[2024-07-18 17:29:36.238] [info] [api.cc:214] - f_rsqrt, executed 20 times, duration 0.089073318s, send bytes 2174400 recv bytes 2174400
[2024-07-18 17:29:36.238] [info] [api.cc:214] - f_less, executed 41 times, duration 0.081052562s, send bytes 4741632 recv bytes 4741632
[2024-07-18 17:29:36.238] [info] [api.cc:214] - f_mul, executed 185 times, duration 0.044366134s, send bytes 2791424 recv bytes 2791424
[2024-07-18 17:29:36.238] [info] [api.cc:214] - mixed_mul, executed 37 times, duration 0.012888425s, send bytes 882368 recv bytes 882368
[2024-07-18 17:29:36.238] [info] [api.cc:214] - _mux, executed 4 times, duration 0.006011205s, send bytes 557056 recv bytes 557056
[2024-07-18 17:29:36.238] [info] [api.cc:214] - f_add, executed 283 times, duration 0.003353158s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - f_negate, executed 40 times, duration 0.002226449s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - f_mmul, executed 1 times, duration 0.001345725s, send bytes 414496 recv bytes 414496
[2024-07-18 17:29:36.238] [info] [api.cc:214] - seal, executed 2 times, duration 2.7712e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:211] MPC profiling: total time 0.421531365
[2024-07-18 17:29:36.238] [info] [api.cc:214] - mmul_aa, executed 14 times, duration 0.107591987s, send bytes 40936960 recv bytes 40936960
[2024-07-18 17:29:36.238] [info] [api.cc:214] - trunc_a, executed 354 times, duration 0.083716544s, send bytes 4216608 recv bytes 4216608
[2024-07-18 17:29:36.239] [info] [api.cc:214] - msb_a2b, executed 41 times, duration 0.078090014s, send bytes 4741632 recv bytes 4741632
[2024-07-18 17:29:36.239] [info] [api.cc:214] - mul_aa, executed 236 times, duration 0.045529196s, send bytes 3550208 recv bytes 3550208
[2024-07-18 17:29:36.239] [info] [api.cc:214] - a2b, executed 20 times, duration 0.022105726s, send bytes 998400 recv bytes 998400
[2024-07-18 17:29:36.239] [info] [api.cc:214] - b2a, executed 101 times, duration 0.018533551s, send bytes 147072 recv bytes 147072
[2024-07-18 17:29:36.239] [info] [api.cc:214] - concatenate, executed 16 times, duration 0.018099175s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - and_bb, executed 120 times, duration 0.01441098s, send bytes 422400 recv bytes 422400
[2024-07-18 17:29:36.239] [info] [api.cc:214] - pad, executed 15 times, duration 0.007607729s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - add_aa, executed 335 times, duration 0.004777058s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - extract_slice, executed 4740 times, duration 0.004058945s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - not_a, executed 85 times, duration 0.003906662s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - reshape, executed 4799 times, duration 0.003217838s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - mmul_ap, executed 35 times, duration 0.002834911s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - xor_bb, executed 580 times, duration 0.001369924s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - add_ap, executed 202 times, duration 0.001247104s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - transpose, executed 252 times, duration 0.000901744s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - bitrev_b, executed 40 times, duration 0.000813418s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - and_bp, executed 360 times, duration 0.000654728s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - rshift_b, executed 360 times, duration 0.000631001s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - mul_ap, executed 210 times, duration 0.000514174s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - make_p, executed 515 times, duration 0.000340834s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - lshift_b, executed 120 times, duration 0.000188551s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - broadcast, executed 103 times, duration 0.000107943s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - add_pp, executed 40 times, duration 8.318e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - xor_bp, executed 20 times, duration 6.9109e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - mul_pp, executed 20 times, duration 5.5609e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - not_p, executed 20 times, duration 3.9903e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - p2a, executed 2 times, duration 2.3822e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - reverse, executed 11 times, duration 1.0005e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:224] Link details: total send bytes 55013280, recv bytes 55013280, send actions 1252

The total comm decreases slightly but is still far from my expectations, which is quite confusing.

Would you mind taking a look at this? Are there any suggestions? BTW, I used 2pc.json with "protocol": "SEMI2K", "field": "FM64" in this test. Thanks! :)

anakinxc commented 4 months ago

Hi @warpoons

Please try to benchmark with experimental_disable_mmul_split set to True.
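For reference, that flag lives in the runtime config of 2pc.json; the relevant fragment would look roughly like this (surrounding file structure omitted, and the exact nesting may differ in your config):

"runtime_config": {
  "protocol": "SEMI2K",
  "field": "FM64",
  "experimental_disable_mmul_split": true
}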

Looks like your implementation requires 7 tensordot calls where the original implementation requires only 1.

Tensordot with non-integer inputs will introduce a truncation operation, which is expensive and needs communication.

One thing you can try is to reduce the number of truncations; for example, trunc(dot(x, y)) + trunc(dot(y, z)) can be simplified into trunc(dot(x, y) + dot(y, z)).
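To see why this helps (a sketch of the fixed-point reasoning, not SPU code): with fixed-point scale 2^f, a dot of two scaled operands lands at scale 2^(2f), and each trunc rescales by 2^f at the cost of communication. Truncation is (approximately) linear, so

  trunc(dot(x, y)) + trunc(dot(y, z))
    ≈ (dot(x, y) + dot(y, z)) / 2^f
    ≈ trunc(dot(x, y) + dot(y, z))

and the merged form pays for one truncation instead of two.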

warpoons commented 4 months ago


Hi @anakinxc. Thanks for your suggestion! I have benchmarked with "experimental_disable_mmul_split": true in 2pc.json, and the profile of the Winograd conv is:

[2024-07-20 09:58:35.945] [info] [api.cc:165] [Profiling] SPU execution infer completed, input processing took 1.5933e-05s, execution took 0.580790941s, output processing took 3.9173e-05s, total time 0.580846047s.
[2024-07-20 09:58:35.946] [info] [api.cc:211] HLO profiling: total time 0.5768848610000001
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.convolution, executed 20 times, duration 0.302007006s, send bytes 67241472 recv bytes 67241472
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.rsqrt, executed 20 times, duration 0.091915749s, send bytes 2174400 recv bytes 2174400
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.less, executed 37 times, duration 0.064329796s, send bytes 2906624 recv bytes 2906624
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.multiply, executed 222 times, duration 0.062256372s, send bytes 3673792 recv bytes 3673792
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.reduce_window, executed 1 times, duration 0.039781986s, send bytes 2392064 recv bytes 2392064
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.pad, executed 15 times, duration 0.006391185s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.reduce, executed 30 times, duration 0.003279638s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.negate, executed 40 times, duration 0.002689619s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.add, executed 129 times, duration 0.002069318s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.dot, executed 1 times, duration 0.001498637s, send bytes 414496 recv bytes 414496
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.free, executed 612 times, duration 0.000175185s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.broadcast, executed 45 times, duration 0.000169835s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.transpose, executed 21 times, duration 0.000117636s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.reshape, executed 29 times, duration 7.9414e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.constant, executed 30 times, duration 5.7852e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.convert, executed 2 times, duration 4.0959e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.reverse, executed 11 times, duration 2.4674e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:211] HAL profiling: total time 0.536832459
[2024-07-20 09:58:35.946] [info] [api.cc:214] - f_tensordot, executed 62 times, duration 0.284671008s, send bytes 67241472 recv bytes 67241472
[2024-07-20 09:58:35.946] [info] [api.cc:214] - f_rsqrt, executed 20 times, duration 0.091869937s, send bytes 2174400 recv bytes 2174400
[2024-07-20 09:58:35.946] [info] [api.cc:214] - f_less, executed 41 times, duration 0.085656933s, send bytes 4741632 recv bytes 4741632
[2024-07-20 09:58:35.946] [info] [api.cc:214] - f_mul, executed 185 times, duration 0.046768409s, send bytes 2791424 recv bytes 2791424
[2024-07-20 09:58:35.946] [info] [api.cc:214] - mixed_mul, executed 37 times, duration 0.01449265s, send bytes 882368 recv bytes 882368
[2024-07-20 09:58:35.946] [info] [api.cc:214] - _mux, executed 4 times, duration 0.006170996s, send bytes 557056 recv bytes 557056
[2024-07-20 09:58:35.946] [info] [api.cc:214] - f_add, executed 283 times, duration 0.003050888s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - f_negate, executed 40 times, duration 0.002625307s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - f_mmul, executed 1 times, duration 0.001493948s, send bytes 414496 recv bytes 414496
[2024-07-20 09:58:35.946] [info] [api.cc:214] - seal, executed 2 times, duration 3.2383e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:211] MPC profiling: total time 0.5615995250000001
[2024-07-20 09:58:35.946] [info] [api.cc:214] - trunc_a, executed 368 times, duration 0.165186115s, send bytes 18896672 recv bytes 18896672
[2024-07-20 09:58:35.946] [info] [api.cc:214] - mmul_aa, executed 21 times, duration 0.152509485s, send bytes 50046464 recv bytes 50046464
[2024-07-20 09:58:35.946] [info] [api.cc:214] - msb_a2b, executed 41 times, duration 0.082573462s, send bytes 4741632 recv bytes 4741632
[2024-07-20 09:58:35.946] [info] [api.cc:214] - mul_aa, executed 236 times, duration 0.04785361s, send bytes 3550208 recv bytes 3550208
[2024-07-20 09:58:35.946] [info] [api.cc:214] - a2b, executed 20 times, duration 0.022759137s, send bytes 998400 recv bytes 998400
[2024-07-20 09:58:35.946] [info] [api.cc:214] - b2a, executed 101 times, duration 0.018657123s, send bytes 147072 recv bytes 147072
[2024-07-20 09:58:35.946] [info] [api.cc:214] - concatenate, executed 16 times, duration 0.017694911s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - and_bb, executed 120 times, duration 0.015000716s, send bytes 422400 recv bytes 422400
[2024-07-20 09:58:35.946] [info] [api.cc:214] - reshape, executed 4841 times, duration 0.007828969s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pad, executed 15 times, duration 0.006348377s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - add_aa, executed 335 times, duration 0.004482565s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - extract_slice, executed 4740 times, duration 0.00437641s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - not_a, executed 85 times, duration 0.004254609s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - mmul_ap, executed 42 times, duration 0.004156978s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - xor_bb, executed 580 times, duration 0.001468896s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - add_ap, executed 202 times, duration 0.001353077s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - transpose, executed 287 times, duration 0.001180328s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - bitrev_b, executed 40 times, duration 0.000877519s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - and_bp, executed 360 times, duration 0.000723871s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - rshift_b, executed 360 times, duration 0.000653156s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - mul_ap, executed 210 times, duration 0.000577432s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - make_p, executed 515 times, duration 0.0004292s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - lshift_b, executed 120 times, duration 0.000221643s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - broadcast, executed 103 times, duration 0.000124417s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - add_pp, executed 40 times, duration 8.8349e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - xor_bp, executed 20 times, duration 6.7871e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - mul_pp, executed 20 times, duration 6.7668e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - not_p, executed 20 times, duration 4.5916e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - p2a, executed 2 times, duration 2.7477e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - reverse, executed 11 times, duration 1.0238e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:224] Link details: total send bytes 78802848, recv bytes 78802848, send actions 1273

There seems to be no difference here.

As for the number of tensordot calls, there are indeed 7 in one Winograd conv. Specifically:

The 6 fixed pre-defined 2D transformation matrices:

G   = { {1.0, 0.0, 0.0}, {0.5, 0.5, 0.5}, {0.5, -0.5, 0.5}, {0.0, 0.0, 1.0} };
G_T = { {1.0, 0.5, 0.5, 0.0}, {0.0, 0.5, -0.5, 0.0}, {0.0, 0.5, 0.5, 1.0} };
B   = { {1, 0, 0, 0}, {0, 1, -1, 1}, {-1, 1, 1, 0}, {0, 0, 0, -1} };
B_T = { {1, 0, -1, 0}, {0, 1, 1, 0}, {0, -1, 1, 0}, {0, 1, 0, -1} };
A   = { {1, 0}, {1, 1}, {1, -1}, {0, -1} };
A_T = { {1, 1, 1, 0}, {0, 1, -1, -1} };

// Transform 4D (h, w, inCh, outCh) kernel to the Winograd domain
auto U = hal::tensordot(ctx, G, hal::tensordot(ctx, kernel, G_T, {1}, {0}), {1}, {0});  // 2 tensordot here
// Transform 4D (N, H, W, inCh) input to the Winograd domain
auto V = hal::tensordot(ctx, B_T, hal::tensordot(ctx, expanded, B, {2}, {0}), {1}, {1});  // 2 tensordot here
// Do General Matmul
auto M = hal::tensordot(ctx, U, V, {1, 3}, {0, 2});  // 1 tensordot here
// Transform Winograd output back to regular output
auto output_tile = hal::tensordot(ctx, A_T, hal::tensordot(ctx, M, A, {2}, {0}), {1}, {1});  // 2 tensordot here

Therefore, there are 2+2+1+2 = 7 tensordot calls in total for one Winograd conv. This might be the reason for the high truncation and total communication costs.

I have also merged all the tensordot calls into one expression, but the profile didn't change.

Are there any feasible optimizations? Thanks! :)

anakinxc commented 4 months ago


Merging everything into one expression does not reduce the number of calls.

Trying to figure out how many truncations you actually need might be a better idea.
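A rough way to do that counting (my illustration, not part of the reply above): multiplying a secret fixed-point value by a public integer constant keeps the scale unchanged, so it needs no truncation; only products of two scaled operands do.

  scale(secret x)                      = 2^f
  scale(x * public integer constant)   = 2^f     -> no truncation needed
  scale(x * y), both at scale 2^f      = 2^(2f)  -> one truncation per result

Since B, B_T, A and A_T contain only 0 and ±1, the input and output transforms could in principle be realized without truncation, and the 0.5 entries of G can be folded into the offline kernel transform, leaving roughly one truncation per output element for the Winograd-domain product, comparable to a standard conv.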

warpoons commented 4 months ago

Thank you @anakinxc. As you previously said, matmul is suitable for 2D dot while tensordot supports >2D dot. An additional question: does matmul also introduce 1 trunc, just like tensordot?

warpoons commented 4 months ago

Thanks for the patient guidance and discussion. My problem has been solved, and this issue can be closed as completed. Thanks to all!