Hi @anakinxc, thanks for your quick response!
Currently I have implemented the Winograd conv class in Flax in the frontend of SPU, but the communication improvement is far from expectation.
Should I directly re-implement the Winograd convolution in the libspu/kernel/hlo/convolution.cc you mentioned here?
If so, for the model definition such as ResNet18 in the frontend of SPU, should I retain the definition of conv layers as regular nn.conv or something else?
Thanks.
Should I directly re-implement the Winograd convolution in the libspu/kernel/hlo/convolution.cc you mentioned here?
yes
If so, for the model definition such as ResNet18 in the frontend of SPU, should I retain the definition of conv layers as regular nn.conv or something else?
regular conv should be fine
Hi @anakinxc. Thanks for your responses and for keeping this issue active.
In recent days, I have been trying to implement the C++ backend op for Winograd in SPU. But there are still some difficulties because of my unfamiliarity with C++. So I sincerely request your guidance and assistance. Thanks.
First of all, many apologies if these are trivial questions.
As previously pointed out, I have re-implemented the convolution in libspu/kernel/hlo/convolution.cc, closely following the python/torch version here. The modified convolution.cc is:
#include "libspu/kernel/hlo/convolution.h"
#include "libspu/core/value.h"
#include "libspu/kernel/hal/polymorphic.h"
#include "libspu/kernel/hal/shape_ops.h"
#include "libspu/kernel/hal/permute.h" // Newly added for Winograd transformation
namespace spu::kernel::hlo {
// This is an optimized conv2D with im2col
spu::Value Convolution2D(SPUContext *ctx, const spu::Value &input,
const spu::Value &kernel,
const ConvolutionConfig &config,
const Shape &result_shape) {
SPU_ENFORCE(!input.isComplex() && !kernel.isComplex());
// input : (N, H, W, C)
// kernel : (h, w, C, O)
// output : (N, hh,ww,O), where hh=(H-h)/sh+1, ww=(W-w)/sw+1
// Alias input dimensions.
auto N = input.shape()[0];
auto H = input.shape()[1];
auto W = input.shape()[2];
auto C = input.shape()[3];
auto h = kernel.shape()[0];
auto w = kernel.shape()[1];
SPU_ENFORCE_EQ(kernel.shape()[2], C, "input/kernel channel mismatch");
auto O = kernel.shape()[3];
SPU_ENFORCE_EQ(result_shape[0], N, "result batch mismatch");
auto hh = result_shape[1];
auto ww = result_shape[2];
SPU_ENFORCE_EQ(result_shape[3], O, "result filters mismatch");
SPU_ENFORCE_EQ(config.window_strides.size(), 2U);
int64_t sh = config.window_strides[0];
int64_t sw = config.window_strides[1];
SPU_ENFORCE_EQ(hh, (H - h) / sh + 1);
SPU_ENFORCE_EQ(ww, (W - w) / sw + 1);
// For 3x3 kernel size, use Winograd convolution
if (h == 3 && w == 3) {
// Parameters
auto r = h; // kernel size, which is fixed to 3 here
auto m = 2; // output tile size, which is fixed to 2 here
auto a = m + r -1; // input tile size, which is fixed to 4 here
auto T = hal::ceil((H - r + 1) / m); // number of tiles along height/weight
auto P = N * T * T; // total number of tiles
auto tile_stride = a - r + 1; // stride for extracting input tiles
// Define Winograd transformation matrices
auto G = { {1.0, 0.0, 0.0},
{0.5, 0.5, 0.5},
{0.5, -0.5, 0.5},
{0.0, 0.0, 1.0} };
auto G_T = { {1.0, 0.5, 0.5, 0.0},
{0.0, 0.5, -0.5, 0.0},
{0.0, 0.5, 0.5, 1.0} };
auto B = { {1, 0, 0, 0},
{0, 1, -1, 1},
{-1, 1, 1, 0},
{0, 0, 0, -1} };
auto B_T = { {1, 0, -1, 0},
{0, 1, 1, 0},
{0, -1, 1, 0},
{0, 1, 0, -1} };
auto A = { {1, 0},
{1, 1},
{1, -1},
{0, -1}};
auto A_T = { {1, 1, 1, 0},
{0, 1, -1, -1}};
// Transform kernel to Winograd domain
auto U = hal::matmul(ctx, G, hal::matmul(ctx, kernel, G_T));
// Transform input to Winograd domain
Value expanded;
{
std::vector<spu::Value> tiled_input;
// Separate the input into axa tiles
for (int64_t x = 0; x <= H - a; x += tile_stride) {
for (int64_t y = 0; y <= W - a; y += tile_stride) {
auto tile = hal::slice(ctx, input, {0, x, y, 0}, {N, x + a, y + a, C}, {});
tiled_input.emplace_back(hal::reshape(ctx, tile, {1, N, a, a, C}));
}
}
auto stacked = hal::concatenate(ctx, tiled_input, 0);
expanded = hal::reshape(ctx, stacked, {P, a, a, C});
}
auto V = hal::matmul(ctx, B_T, hal::matmul(ctx, expanded, B));
// Transform Winograd input and kernel to GEMM format
U = hal::permute(ctx, U, {0, 1, 3, 2}); // U (a, a, C, O) -> (a, a, O, C)
V = hal::permute(ctx, V, {1, 2, 3, 0}); // V (P, a, a, C) -> (a, a, C, P)
// Do GEMM
auto M = hal::matmul(ctx, U, V); // M (a, a, O, P)
M = hal::permute(ctx, M, {3, 0, 1, 2}); // M (a, a, O, P) -> (P, a, a, O)
// Transform Winograd output to regular output
auto output_tile = hal::matmul(ctx, A_T, hal::matmul(ctx, M, A)); // output_tile (P, 2, 2, O)
auto out_size = H - r + 1;
auto result = hal::reshape(ctx, hal::permute(ctx, hal::reshape(ctx, output_tile, {N, T, T, m, m, O}), {0, 1, 3, 2, 4, 5}), {N, out_size, out_size, O});
return result;
} else {
// Fallback, use im2col + dot to implement convolution
// expand the image according to the kernel size.
// assumption:
// - padding is erased by some compiler pass.
// - input : NxHxWxC
// - kernel : hxwxCxO
Value expanded;
{
std::vector<spu::Value> images;
for (int64_t x = 0; x <= H - h; x += sh) {
for (int64_t y = 0; y <= W - w; y += sw) {
auto window =
hal::slice(ctx, input, {0, x, y, 0}, {N, x + h, y + w, C}, {});
images.emplace_back(hal::reshape(ctx, window, {N, 1, h, w, C}));
}
}
auto stacked = hal::concatenate(ctx, images, 1);
SPU_ENFORCE_EQ(stacked.shape()[1], hh * ww);
expanded = hal::reshape(ctx, stacked, {N, hh * ww, h * w, C});
}
// TODO(jint): the below method is much slower then the code above, consider
// to use slice+reshape+concat to rewrite expandWindow.
//
// std::vector<std::pair<int64_t, int64_t>> padding(4, {0, 0});
// auto expanded = expandWindow(ctx, input, // input
// {N, h, w, C}, // window_shape
// {1, sh, sw, 1}, // strides
// padding);
// Now expanded shape is (N, hh*ww, h*w, C)
SPU_ENFORCE_EQ(expanded.shape()[0], N);
SPU_ENFORCE_EQ(expanded.shape()[1], hh * ww);
SPU_ENFORCE_EQ(expanded.shape()[2], h * w);
SPU_ENFORCE_EQ(expanded.shape()[3], C);
// Reshape it to (N, hh, ww, h, w, C)
expanded = hal::reshape(ctx, expanded, {N, hh, ww, h, w, C});
// Contract on h, w, C
// expanded: (N, hh, ww, h, w, C)
// kernel: (h, w, C, O)
// result: (N, hh, ww, O)
auto result = hal::tensordot(ctx, expanded, kernel, {3, 4, 5}, {0, 1, 2});
SPU_ENFORCE_EQ(result.shape()[0], N);
SPU_ENFORCE_EQ(result.shape()[1], hh);
SPU_ENFORCE_EQ(result.shape()[2], ww);
SPU_ENFORCE_EQ(result.shape()[3], O);
return result;
}
}
} // namespace spu::kernel::hlo
Please note that I only convert the 3x3 convs into the Winograd version and the others are retained as regular convs. The input/output tile sizes are fixed to 4 and 2, respectively.
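For reference, the per-tile computation these matrices implement is the standard Winograd F(2x2, 3x3) form, where d is a 4x4 input tile and g is the 3x3 filter (stated here only as background):

$$Y = A^\top \big[ (G\, g\, G^\top) \odot (B^\top d\, B) \big] A$$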
This implementation is intended as a one-to-one port of the torch version, but it didn't work. I have some issues:
- I defined the Winograd transformation matrices simply with auto and brace-enclosed initializer lists; I don't know if this is correct?
- Winograd conv with GEMM requires the permute function (tensor.permute in torch) to transform the Winograd input and kernel to GEMM format; I don't know if my use of permute is correct?
- The regular conv uses hal::tensordot for the contraction; should I use hal::tensordot or hal::matmul for the input/kernel transformation and the GEMM?
SPU is a grand and complex project, so it is quite difficult for me to fully understand it. If possible, could you please provide the modifications required for the above code in as much detail as possible? Thank you very much.
Hi @warpoons
Nice work! Thanks
- For A, B and G, you need to build a proper Value with the constant api.
- tensordot is a more general version of dot, which supports > 2D inputs. If you only need 2D dot, matmul is good enough.
- Hadamard product is just an element-wise multiply of two tensors, right? Just use the regular mul api and it should be good enough.
Hi @warpoons, SPU transpose does the same thing as torch.permute.
Hi @anakinxc. Thanks for your response!
for A B G, you need to build a proper Value with constant api.
Could you please kindly provide a specific example of how to define just one of the Winograd transformation matrices A, B, G? For example, how to define:
G = { {1.0, 0.0, 0.0},
{0.5, 0.5, 0.5},
{0.5, -0.5, 0.5},
{0.0, 0.0, 1.0} };
Thanks!
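For context, a minimal sketch of how one such matrix could be built with the constant api; the hal::constant call below simply mirrors the revised code that appears later in this thread, so treat its exact signature as an assumption rather than a verified reference:

#include "libspu/kernel/hal/constants.h" // for hal::constant

// Sketch: build G as a 4x3 public constant from a flat, row-major coefficient list.
std::vector<float> G_coefficients = { 1.0, 0.0, 0.0,
                                      0.5, 0.5, 0.5,
                                      0.5, -0.5, 0.5,
                                      0.0, 0.0, 1.0 };
auto G = spu::kernel::hal::constant(ctx, G_coefficients, input.dtype(), {4, 3});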
tensordot is a more general version of dot, which supports > 2D inputs. If you only need 2D dot, matmul is good enough.
Does this mean that tensordot supports both matmul and higher-dimensional multiplications? Can I entirely use tensordot instead of matmul to avoid potential incompatibilities?
SPU transpose does the same thing as torch.permute.
Hi @tpppppub. Thanks for your response!
Would you mind taking a look at whether the following usage of transpose is correct? I searched the entire spu library but did not find an instructive example of using transpose. Thanks!
// Transform Winograd input and kernel to General Matmul format
U = hal::transpose(ctx, U, {0, 1, 3, 2}); // U (a, a, C, O) -> (a, a, O, C)
V = hal::transpose(ctx, V, {1, 2, 3, 0}); // V (P, a, a, C) -> (a, a, C, P)
Hi @anakinxc @tpppppub. Would you mind taking a look at my doubts? Thanks! :)
Your understanding of transpose is correct. Besides, the following tips are for your consideration.
Hi @tpppppub. Thanks for your quick and detailed response! I have modified the convolution.cc as follows:
#include "libspu/kernel/hlo/convolution.h"
#include "libspu/core/value.h"
#include "libspu/kernel/hal/polymorphic.h"
#include "libspu/kernel/hal/shape_ops.h"
#include "libspu/kernel/hal/constants.h" // Newly added for building winograd transformation matrices
namespace spu::kernel::hlo {
// This is an optimized conv2D with im2col
spu::Value Convolution2D(SPUContext *ctx, const spu::Value &input,
const spu::Value &kernel,
const ConvolutionConfig &config,
const Shape &result_shape) {
SPU_ENFORCE(!input.isComplex() && !kernel.isComplex());
// input : (N, H, W, C)
// kernel : (h, w, C, O)
// output : (N, hh,ww,O), where hh=(H-h)/sh+1, ww=(W-w)/sw+1
// Alias input dimensions.
auto N = input.shape()[0];
auto H = input.shape()[1];
auto W = input.shape()[2];
auto C = input.shape()[3];
auto h = kernel.shape()[0];
auto w = kernel.shape()[1];
SPU_ENFORCE_EQ(kernel.shape()[2], C, "input/kernel channel mismatch");
auto O = kernel.shape()[3];
SPU_ENFORCE_EQ(result_shape[0], N, "result batch mismatch");
auto hh = result_shape[1];
auto ww = result_shape[2];
SPU_ENFORCE_EQ(result_shape[3], O, "result filters mismatch");
SPU_ENFORCE_EQ(config.window_strides.size(), 2U);
int64_t sh = config.window_strides[0];
int64_t sw = config.window_strides[1];
SPU_ENFORCE_EQ(hh, (H - h) / sh + 1);
SPU_ENFORCE_EQ(ww, (W - w) / sw + 1);
// For 3x3 kernel size, use Winograd convolution
if (h == 3 && w == 3) {
// Parameters
auto r = h; // kernel size, which is fixed to 3 here
auto m = 2; // output tile size, which is fixed to 2 here
auto a = m + r -1; // input tile size, which is fixed to 4 here
auto T = hal::ceil(ctx, (H - r + 1) / m); // number of tiles along height/weight
auto P = N * T * T; // total number of tiles
auto tile_stride = a - r + 1; // stride for extracting input tiles
// Define Winograd transformation matrices
std::vector<float> G_coefficients = { 1.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5, -0.5, 0.5, 0.0, 0.0, 1.0 };
std::vector<float> G_T_coefficients = { 1.0, 0.5, 0.5, 0.0, 0.0, 0.5, -0.5, 0.0, 0.0, 0.5, 0.5, 1.0 };
std::vector<float> B_coefficients = { 1, 0, 0, 0, 0, 1, -1, 1, -1, 1, 1, 0, 0, 0, 0, -1 };
std::vector<float> B_T_coefficients = { 1, 0, -1, 0, 0, 1, 1, 0, 0, -1, 1, 0, 0, 1, 0, -1 };
std::vector<float> A_coefficients = { 1, 0, 1, 1, 1, -1, 0, -1};
std::vector<float> A_T_coefficients = { 1, 1, 1, 0, 0, 1, -1, -1};
auto G = spu::kernel::hal::constant(ctx, G_coefficients, input.dtype(), {4, 3});
auto G_T = spu::kernel::hal::constant(ctx, G_T_coefficients, input.dtype(), {3, 4});
auto B = spu::kernel::hal::constant(ctx, B_coefficients, input.dtype(), {4, 4});
auto B_T = spu::kernel::hal::constant(ctx, B_T_coefficients, input.dtype(), {4, 4});
auto A = spu::kernel::hal::constant(ctx, A_coefficients, input.dtype(), {4, 2});
auto A_T = spu::kernel::hal::constant(ctx, A_T_coefficients, input.dtype(), {2, 4});
// Transform kernel to Winograd domain
auto U = hal::matmul(ctx, G, hal::matmul(ctx, kernel, G_T));
// Transform input to Winograd domain
Value expanded;
{
std::vector<spu::Value> tiled_input;
// Separate the input into axa tiles
for (int64_t x = 0; x <= H - a; x += tile_stride) {
for (int64_t y = 0; y <= W - a; y += tile_stride) {
auto tile = hal::slice(ctx, input, {0, x, y, 0}, {N, x + a, y + a, C}, {});
tiled_input.emplace_back(hal::reshape(ctx, tile, {1, N, a, a, C}));
}
}
auto stacked = hal::concatenate(ctx, tiled_input, 0);
expanded = hal::reshape(ctx, stacked, {P, a, a, C});
}
auto V = hal::matmul(ctx, B_T, hal::matmul(ctx, expanded, B));
// Transform Winograd input and kernel to General Matmul format
U = hal::transpose(ctx, U, {0, 1, 3, 2}); // U (a, a, C, O) -> (a, a, O, C)
V = hal::transpose(ctx, V, {1, 2, 3, 0}); // V (P, a, a, C) -> (a, a, C, P)
// Do General Matmul
auto M = hal::matmul(ctx, U, V); // M (a, a, O, P)
M = hal::transpose(ctx, M, {3, 0, 1, 2}); // M (a, a, O, P) -> (P, a, a, O)
// Transform Winograd output to regular output
auto output_tile = hal::matmul(ctx, A_T, hal::matmul(ctx, M, A)); // output_tile (P, 2, 2, O)
auto out_size = H - r + 1;
auto result = hal::reshape(ctx, hal::transpose(ctx, hal::reshape(ctx, output_tile, {N, T, T, m, m, O}), {0, 1, 3, 2, 4, 5}), {N, out_size, out_size, O});
return result;
} else {
// Fallback, use im2col + dot to implement convolution
// expand the image according to the kernel size.
// assumption:
// - padding is erased by some compiler pass.
// - input : NxHxWxC
// - kernel : hxwxCxO
Value expanded;
{
std::vector<spu::Value> images;
for (int64_t x = 0; x <= H - h; x += sh) {
for (int64_t y = 0; y <= W - w; y += sw) {
auto window =
hal::slice(ctx, input, {0, x, y, 0}, {N, x + h, y + w, C}, {});
images.emplace_back(hal::reshape(ctx, window, {N, 1, h, w, C}));
}
}
auto stacked = hal::concatenate(ctx, images, 1);
SPU_ENFORCE_EQ(stacked.shape()[1], hh * ww);
expanded = hal::reshape(ctx, stacked, {N, hh * ww, h * w, C});
}
// TODO(jint): the below method is much slower then the code above, consider
// to use slice+reshape+concat to rewrite expandWindow.
//
// std::vector<std::pair<int64_t, int64_t>> padding(4, {0, 0});
// auto expanded = expandWindow(ctx, input, // input
// {N, h, w, C}, // window_shape
// {1, sh, sw, 1}, // strides
// padding);
// Now expanded shape is (N, hh*ww, h*w, C)
SPU_ENFORCE_EQ(expanded.shape()[0], N);
SPU_ENFORCE_EQ(expanded.shape()[1], hh * ww);
SPU_ENFORCE_EQ(expanded.shape()[2], h * w);
SPU_ENFORCE_EQ(expanded.shape()[3], C);
// Reshape it to (N, hh, ww, h, w, C)
expanded = hal::reshape(ctx, expanded, {N, hh, ww, h, w, C});
// Contract on h, w, C
// expanded: (N, hh, ww, h, w, C)
// kernel: (h, w, C, O)
// result: (N, hh, ww, O)
auto result = hal::tensordot(ctx, expanded, kernel, {3, 4, 5}, {0, 1, 2});
SPU_ENFORCE_EQ(result.shape()[0], N);
SPU_ENFORCE_EQ(result.shape()[1], hh);
SPU_ENFORCE_EQ(result.shape()[2], ww);
SPU_ENFORCE_EQ(result.shape()[3], O);
return result;
}
}
} // namespace spu::kernel::hlo
After modifying the convolution.cc file, I ran the command bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/flax_resnet18/2pc.json up but unexpectedly got errors like:
ERROR: /home/warpoons/Desktop/spu-0.9.1b0/libspu/kernel/hlo/BUILD.bazel:152:15: Compiling libspu/kernel/hlo/convolution.cc failed: (Exit 1): gcc failed: error executing command (from target //libspu/kernel/hlo:convolution) /usr/bin/gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections ... (remaining 179 arguments skipped)
Use --sandbox_debug to see verbose messages from the sandbox and retain the sandbox build root for debugging
libspu/kernel/hlo/convolution.cc: In function 'spu::Value spu::kernel::hlo::Convolution2D(spu::SPUContext*, const spu::Value&, const spu::Value&, const spu::kernel::hlo::ConvolutionConfig&, const spu::Shape&)':
libspu/kernel/hlo/convolution.cc:62:41: error: invalid initialization of reference of type 'const spu::Value&' from expression of type 'long int'
62 | auto T = hal::ceil(ctx, (H - r + 1) / m); // number of tiles along height/weight
| ~~~~~~~~~~~~^~~
In file included from libspu/kernel/hlo/convolution.cc:17:
./libspu/kernel/hal/polymorphic.h:88:42: note: in passing argument 2 of 'spu::Value spu::kernel::hal::ceil(spu::SPUContext*, const spu::Value&)'
88 | Value ceil(SPUContext* ctx, const Value& in);
| ~~~~~~~~~~~~~^~
libspu/kernel/hlo/convolution.cc:124:30: error: invalid initialization of reference of type 'const spu::Shape&' from expression of type '<brace-enclosed initializer list>'
124 | expanded = hal::reshape(ctx, stacked, {P, a, a, C});
| ~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~
In file included from libspu/kernel/hlo/convolution.cc:18:
./libspu/kernel/hal/shape_ops.h:36:62: note: in passing argument 3 of 'spu::Value spu::kernel::hal::reshape(spu::SPUContext*, const spu::Value&, const spu::Shape&)'
36 | Value reshape(SPUContext* ctx, const Value& in, const Shape& to_shape);
| ~~~~~~~~~~~~~^~~~~~~~
libspu/kernel/hlo/convolution.cc:139:69: error: invalid initialization of reference of type 'const spu::Shape&' from expression of type '<brace-enclosed initializer list>'
139 | auto result = hal::reshape(ctx, hal::transpose(ctx, hal::reshape(ctx, output_tile, {N, T, T, m, m, O}), {0, 1, 3, 2, 4, 5}), {N, out_size, out_size, O});
| ~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In file included from libspu/kernel/hlo/convolution.cc:18:
./libspu/kernel/hal/shape_ops.h:36:62: note: in passing argument 3 of 'spu::Value spu::kernel::hal::reshape(spu::SPUContext*, const spu::Value&, const spu::Shape&)'
36 | Value reshape(SPUContext* ctx, const Value& in, const Shape& to_shape);
| ~~~~~~~~~~~~~^~~~~~~~
Target //examples/python/utils:nodectl failed to build
Use --verbose_failures to see the command lines of failed build steps.
INFO: Elapsed time: 116.662s, Critical Path: 47.96s
INFO: 418 processes: 15 internal, 403 linux-sandbox.
FAILED: Build did NOT complete successfully
ERROR: Build failed. Not running target
Would you mind taking a look at what went wrong there and at feasible solutions? Thanks! :)
libspu/kernel/hlo/convolution.cc:62:41: error: invalid initialization of reference of type 'const spu::Value&' from expression of type 'long int'
hal::ceil accepts the type spu::Value as argument, not builtin scalar types.
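For what it's worth, since T only depends on public shapes, it can also be computed with plain integer arithmetic instead of hal::ceil; a minimal sketch under the stride-1, no-padding assumption of the Winograd branch (not necessarily the exact fix adopted below):

// Number of m x m output tiles per spatial dimension for F(2x2, 3x3):
// ceil((H - r + 1) / m) via integer ceiling division.
const int64_t r = 3;
const int64_t m = 2;
const int64_t out = H - r + 1;        // valid output size with stride 1
const int64_t T = (out + m - 1) / m;  // equals ceil(out / m) for positive out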
libspu/kernel/hlo/convolution.cc:62:41: error: invalid initialization of reference of type 'const spu::Value&' from expression of type 'long int'
hal::ceil accepts the type spu::Value as argument, not builtin scalar types.
Hi @tpppppub. Thanks for pointing that out! In the modified version, I have used auto T = input.shape()[1] / 2 - 1; without ceil, and the SPU backend runtime can now be launched successfully. But when I tried to test the inference of ResNet18 with this new convolution.cc, the following errors happened:
- By running the flax_resnet_inference, I got:
INFO: Running command line: bazel-bin/examples/python/ml/flax_resnet18/flax_resnet_inference --config /home/warpoons/Desktop/spu-0.9.1b0/examples/python/ml/flax_resnet18/2pc.json Run on SPU ------ An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu. Traceback (most recent call last): File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/flax_resnet18/flax_resnet_inference.runfiles/spulib/examples/python/ml/flax_resnet18/flax_resnet_inference.py", line 142, in <module> run_on_spu(inputs, params) File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/flax_resnet18/flax_resnet_inference.runfiles/spulib/examples/python/ml/flax_resnet18/flax_resnet_inference.py", line 118, in run_on_spu out = infer(params, inputs) File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/flax_resnet18/flax_resnet_inference.runfiles/spulib/spu/utils/distributed_impl.py", line 693, in __call__ results = [future.result() for future in futures] File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/flax_resnet18/flax_resnet_inference.runfiles/spulib/spu/utils/distributed_impl.py", line 693, in <listcomp> results = [future.result() for future in futures] File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/concurrent/futures/_base.py", line 451, in result return self.__get_result() File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result raise self._exception File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/concurrent/futures/thread.py", line 58, in run result = self.fn(*self.args, **self.kwargs) File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/flax_resnet18/flax_resnet_inference.runfiles/spulib/spu/utils/distributed_impl.py", line 247, in run return self._call(self._stub.Run, fn, *args, **kwargs) File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/flax_resnet18/flax_resnet_inference.runfiles/spulib/spu/utils/distributed_impl.py", line 240, in _call raise Exception("remote exception", result) Exception: ('remote exception', Exception('Traceback (most recent call last):\n File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/utils/distributed_impl.py", line 326, in Run\n ret_objs = fn(self, *args, **kwargs)\n File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/utils/distributed_impl.py", line 589, in builtin_spu_run\n rt.run(spu_exec)\n File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/api.py", line 44, in run\n return self._vm.Run(executable.SerializeToString())\nRuntimeError: what: \n\t[Enforce fail at libspu/kernel/hal/ring.cc:230] (lhs.ndim() > 0 && lhs.ndim() <= 2). 
\nStacktrace:\n#0 spu::kernel::hal::f_mmul()+0x735cd124b610\n#1 spu::kernel::hal::(anonymous namespace)::dtypeBinaryDispatch<>()+0x735cd122d2e2\n#2 spu::kernel::hal::matmul()+0x735cd122df4d\n#3 spu::kernel::hlo::Convolution2D()+0x735cd11d11ee\n#4 spu::device::pphlo::dispatchOp<>()+0x735cd0ae146e\n#5 spu::device::pphlo::dispatchOp<>()+0x735cd0ae24b3\n#6 spu::device::pphlo::dispatchOp<>()+0x735cd0ae424a\n#7 spu::device::pphlo::dispatchOp<>()+0x735cd0ae65f4\n#8 spu::device::runBlock()+0x735cd0c66a3d\n#9 spu::device::runRegion()+0x735cd0c68ae3\n#10 spu::device::executeImpl()+0x735cd059f0e7\n#11 spu::RuntimeWrapper::Run()+0x735ccf92f4cc\n#12 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()+0x735ccf90121d\n#13 pybind11::cpp_function::dispatcher()+0x735ccf8f6006\n#14 cfunction_call+0x4fdc87\n\nstacktrace: \n#0 spu::kernel::hal::f_mmul()+0x735cd124b610\n#1 spu::kernel::hal::(anonymous namespace)::dtypeBinaryDispatch<>()+0x735cd122d2e2\n#2 spu::kernel::hal::matmul()+0x735cd122df4d\n#3 spu::kernel::hlo::Convolution2D()+0x735cd11d11ee\n#4 spu::device::pphlo::dispatchOp<>()+0x735cd0ae146e\n#5 spu::device::pphlo::dispatchOp<>()+0x735cd0ae24b3\n#6 spu::device::pphlo::dispatchOp<>()+0x735cd0ae424a\n#7 spu::device::pphlo::dispatchOp<>()+0x735cd0ae65f4\n#8 spu::device::runBlock()+0x735cd0c66a3d\n#9 spu::device::runRegion()+0x735cd0c68ae3\n#10 spu::device::executeImpl()+0x735cd059f0e7\n#11 spu::RuntimeWrapper::Run()+0x735ccf92f4cc\n#12 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()+0x735ccf90121d\n#13 pybind11::cpp_function::dispatcher()+0x735ccf8f6006\n#14 cfunction_call+0x4fdc87\n\n\n'))
- And the backend runtime reported something like:
[2024-07-16 17:19:27,187] [ForkServerProcess-1] Traceback (most recent call last): File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/utils/distributed_impl.py", line 326, in Run ret_objs = fn(self, *args, **kwargs) File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/utils/distributed_impl.py", line 589, in builtin_spu_run rt.run(spu_exec) File "/home/warpoons/.cache/bazel/_bazel_warpoons/580f3d68baaaad498fe83c6c3b519d10/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/api.py", line 44, in run return self._vm.Run(executable.SerializeToString()) RuntimeError: what: [Enforce fail at libspu/kernel/hal/ring.cc:230] (lhs.ndim() > 0 && lhs.ndim() <= 2). Stacktrace: #0 spu::kernel::hal::f_mmul()+0x735cd124b610 #1 spu::kernel::hal::(anonymous namespace)::dtypeBinaryDispatch<>()+0x735cd122d2e2 #2 spu::kernel::hal::matmul()+0x735cd122df4d #3 spu::kernel::hlo::Convolution2D()+0x735cd11d11ee #4 spu::device::pphlo::dispatchOp<>()+0x735cd0ae146e #5 spu::device::pphlo::dispatchOp<>()+0x735cd0ae24b3 #6 spu::device::pphlo::dispatchOp<>()+0x735cd0ae424a #7 spu::device::pphlo::dispatchOp<>()+0x735cd0ae65f4 #8 spu::device::runBlock()+0x735cd0c66a3d #9 spu::device::runRegion()+0x735cd0c68ae3 #10 spu::device::executeImpl()+0x735cd059f0e7 #11 spu::RuntimeWrapper::Run()+0x735ccf92f4cc #12 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()+0x735ccf90121d #13 pybind11::cpp_function::dispatcher()+0x735ccf8f6006 #14 cfunction_call+0x4fdc87
Would you mind taking a look at what went wrong here and at feasible solutions? Sorry for taking your time. Thanks! :)
matmul only supports <= 2D dot. Seems you are doing a 4D dot.
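To make the distinction concrete, this is the contraction pattern already used in the fallback branch of this file: hal::tensordot contracts the named axes of operands with more than two dimensions, which matmul cannot express (shapes taken from the comments in the code above):

// expanded: (N, hh, ww, h, w, C), kernel: (h, w, C, O) -> result: (N, hh, ww, O)
auto result = hal::tensordot(ctx, expanded, kernel, {3, 4, 5}, {0, 1, 2});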
Hi @anakinxc. Thanks for your reminder! I realized that tensordot should be used here instead of matmul. After several necessary modifications, the code has been successfully executed without any errors.
#include "libspu/kernel/hlo/convolution.h"
#include "libspu/core/value.h"
#include "libspu/kernel/hal/polymorphic.h"
#include "libspu/kernel/hal/shape_ops.h"
#include "libspu/kernel/hal/constants.h" // Newly added for building winograd transformation matrices
namespace spu::kernel::hlo {
// This is an optimized conv2D with im2col
spu::Value Convolution2D(SPUContext *ctx, const spu::Value &input,
const spu::Value &kernel,
const ConvolutionConfig &config,
const Shape &result_shape) {
SPU_ENFORCE(!input.isComplex() && !kernel.isComplex());
// input : (N, H, W, C)
// kernel : (h, w, C, O)
// output : (N, hh,ww,O), where hh=(H-h)/sh+1, ww=(W-w)/sw+1
// Alias input dimensions.
auto N = input.shape()[0];
auto H = input.shape()[1];
auto W = input.shape()[2];
auto C = input.shape()[3];
auto h = kernel.shape()[0];
auto w = kernel.shape()[1];
SPU_ENFORCE_EQ(kernel.shape()[2], C, "input/kernel channel mismatch");
auto O = kernel.shape()[3];
SPU_ENFORCE_EQ(result_shape[0], N, "result batch mismatch");
auto hh = result_shape[1];
auto ww = result_shape[2];
SPU_ENFORCE_EQ(result_shape[3], O, "result filters mismatch");
SPU_ENFORCE_EQ(config.window_strides.size(), 2U);
int64_t sh = config.window_strides[0];
int64_t sw = config.window_strides[1];
SPU_ENFORCE_EQ(hh, (H - h) / sh + 1);
SPU_ENFORCE_EQ(ww, (W - w) / sw + 1);
// For 3x3 kernel size with 1 stride, use Winograd convolution
if (h == 3 && w == 3 && config.window_strides[0] == 1 && config.window_strides[1] == 1) {
// Parameters
auto r = kernel.shape()[0]; // kernel size, which is fixed to 3 here
auto m = 2; // output tile size, which is fixed to 2 here
auto a = m + r -1; // input tile size, which is fixed to 4 here
auto T = input.shape()[1] / 2 - 1;
auto P = N * T * T; // total number of tiles
auto tile_stride = a - r + 1; // stride for extracting input tiles
// Define Winograd transformation matrices
std::vector<float> G_coefficients = { 1.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5, -0.5, 0.5, 0.0, 0.0, 1.0 };
std::vector<float> G_T_coefficients = { 1.0, 0.5, 0.5, 0.0, 0.0, 0.5, -0.5, 0.0, 0.0, 0.5, 0.5, 1.0 };
std::vector<float> B_coefficients = { 1, 0, 0, 0, 0, 1, -1, 1, -1, 1, 1, 0, 0, 0, 0, -1 };
std::vector<float> B_T_coefficients = { 1, 0, -1, 0, 0, 1, 1, 0, 0, -1, 1, 0, 0, 1, 0, -1 };
std::vector<float> A_coefficients = { 1, 0, 1, 1, 1, -1, 0, -1};
std::vector<float> A_T_coefficients = { 1, 1, 1, 0, 0, 1, -1, -1};
auto G = spu::kernel::hal::constant(ctx, G_coefficients, input.dtype(), {4, 3});
auto G_T = spu::kernel::hal::constant(ctx, G_T_coefficients, input.dtype(), {3, 4});
auto B = spu::kernel::hal::constant(ctx, B_coefficients, input.dtype(), {4, 4});
auto B_T = spu::kernel::hal::constant(ctx, B_T_coefficients, input.dtype(), {4, 4});
auto A = spu::kernel::hal::constant(ctx, A_coefficients, input.dtype(), {4, 2});
auto A_T = spu::kernel::hal::constant(ctx, A_T_coefficients, input.dtype(), {2, 4});
// Transform kernel to Winograd domain (matmul was wrong here; use tensordot instead)
auto U = hal::tensordot(ctx, G, hal::tensordot(ctx, kernel, G_T, {1}, {0}), {1}, {0});
U = hal::transpose(ctx, U, {0, 3, 1, 2});
// Transform input to Winograd domain
Value expanded;
{
std::vector<spu::Value> tiled_input;
// Separate the input into axa tiles
for (int64_t x = 0; x <= H - a; x += tile_stride) {
for (int64_t y = 0; y <= W - a; y += tile_stride) {
auto tile = hal::slice(ctx, input, {0, x, y, 0}, {N, x + a, y + a, C}, {});
tiled_input.emplace_back(hal::reshape(ctx, tile, {1, N, a, a, C}));
}
}
auto stacked = hal::concatenate(ctx, tiled_input, 0);
expanded = hal::reshape(ctx, stacked, {P, a, a, C});
}
auto V = hal::tensordot(ctx, B_T, hal::tensordot(ctx, expanded, B, {2}, {0}), {1}, {1});
V = hal::transpose(ctx, V, {1, 0, 3, 2});
// Transform Winograd input and kernel to General Matmul format
U = hal::transpose(ctx, U, {0, 1, 3, 2}); // U (a, a, C, O) -> (a, a, O, C)
V = hal::transpose(ctx, V, {1, 2, 3, 0}); // V (P, a, a, C) -> (a, a, C, P)
// Do General Matmul
auto M = hal::tensordot(ctx, U, V, {1, 3}, {0, 2});
M = hal::transpose(ctx, M, {3, 0, 2, 1}); // M -> (P, a, a, O)
// Transform Winograd output to regular output
auto output_tile = hal::tensordot(ctx, A_T, hal::tensordot(ctx, M, A, {2}, {0}), {1}, {1}); // output_tile (P, 2, 2, O)
output_tile = hal::transpose(ctx, output_tile, {2, 1, 0, 3});
auto out_size = H - r + 1;
auto result = hal::reshape(ctx, hal::transpose(ctx, hal::reshape(ctx, output_tile, {N, T, T, m, m, O}), {0, 1, 3, 2, 4, 5}), {N, out_size, out_size, O});
return result;
} else {
// Fallback, use im2col + dot to implement convolution
// expand the image according to the kernel size.
// assumption:
// - padding is erased by some compiler pass.
// - input : NxHxWxC
// - kernel : hxwxCxO
Value expanded;
{
std::vector<spu::Value> images;
for (int64_t x = 0; x <= H - h; x += sh) {
for (int64_t y = 0; y <= W - w; y += sw) {
auto window =
hal::slice(ctx, input, {0, x, y, 0}, {N, x + h, y + w, C}, {});
images.emplace_back(hal::reshape(ctx, window, {N, 1, h, w, C}));
}
}
auto stacked = hal::concatenate(ctx, images, 1);
SPU_ENFORCE_EQ(stacked.shape()[1], hh * ww);
expanded = hal::reshape(ctx, stacked, {N, hh * ww, h * w, C});
}
// TODO(jint): the below method is much slower then the code above, consider
// to use slice+reshape+concat to rewrite expandWindow.
//
// std::vector<std::pair<int64_t, int64_t>> padding(4, {0, 0});
// auto expanded = expandWindow(ctx, input, // input
// {N, h, w, C}, // window_shape
// {1, sh, sw, 1}, // strides
// padding);
// Now expanded shape is (N, hh*ww, h*w, C)
SPU_ENFORCE_EQ(expanded.shape()[0], N);
SPU_ENFORCE_EQ(expanded.shape()[1], hh * ww);
SPU_ENFORCE_EQ(expanded.shape()[2], h * w);
SPU_ENFORCE_EQ(expanded.shape()[3], C);
// Reshape it to (N, hh, ww, h, w, C)
expanded = hal::reshape(ctx, expanded, {N, hh, ww, h, w, C});
// Contract on h, w, C
// expanded: (N, hh, ww, h, w, C)
// kernel: (h, w, C, O)
// result: (N, hh, ww, O)
// std::cout << "The shape of expanded is: " << expanded.shape() << std::endl;
// std::cout << "The shape of kernel is: " << kernel.shape() << std::endl;
auto result = hal::tensordot(ctx, expanded, kernel, {3, 4, 5}, {0, 1, 2});
std::cout << "The shape of std conv result is: " << result.shape() << std::endl;
SPU_ENFORCE_EQ(result.shape()[0], N);
SPU_ENFORCE_EQ(result.shape()[1], hh);
SPU_ENFORCE_EQ(result.shape()[2], ww);
SPU_ENFORCE_EQ(result.shape()[3], O);
return result;
}
}
} // namespace spu::kernel::hlo
But the profile of ResNet18 on CIFAR-100 was still weird:
[2024-07-18 17:18:03.364] [info] [api.cc:165] [Profiling] SPU execution infer completed, input processing took 1.6291e-05s, execution took 0.442374948s, output processing took 3.8402e-05s, total time 0.442429641s.
[2024-07-18 17:18:03.364] [info] [api.cc:211] HLO profiling: total time 0.43840665800000006
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.convolution, executed 20 times, duration 0.173768735s, send bytes 47269376 recv bytes 47269376
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.rsqrt, executed 20 times, duration 0.089064589s, send bytes 2174400 recv bytes 2174400
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.less, executed 37 times, duration 0.06104458s, send bytes 2906624 recv bytes 2906624
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.multiply, executed 222 times, duration 0.060172622s, send bytes 3673792 recv bytes 3673792
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.reduce_window, executed 1 times, duration 0.038431921s, send bytes 2392064 recv bytes 2392064
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.pad, executed 15 times, duration 0.005943561s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.reduce, executed 30 times, duration 0.003742128s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.negate, executed 40 times, duration 0.002317673s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.add, executed 129 times, duration 0.001958798s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.dot, executed 1 times, duration 0.001353351s, send bytes 414496 recv bytes 414496
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.free, executed 612 times, duration 0.000157907s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.broadcast, executed 45 times, duration 0.000149092s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.364] [info] [api.cc:214] - pphlo.transpose, executed 21 times, duration 0.000109685s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - pphlo.reshape, executed 29 times, duration 7.4707e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - pphlo.constant, executed 30 times, duration 5.5505e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - pphlo.convert, executed 2 times, duration 3.7019e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - pphlo.reverse, executed 11 times, duration 2.4785e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:211] HAL profiling: total time 0.39880832400000005
[2024-07-18 17:18:03.365] [info] [api.cc:214] - f_tensordot, executed 20 times, duration 0.155717562s, send bytes 47269376 recv bytes 47269376
[2024-07-18 17:18:03.365] [info] [api.cc:214] - f_rsqrt, executed 20 times, duration 0.089025885s, send bytes 2174400 recv bytes 2174400
[2024-07-18 17:18:03.365] [info] [api.cc:214] - f_less, executed 41 times, duration 0.081797061s, send bytes 4741632 recv bytes 4741632
[2024-07-18 17:18:03.365] [info] [api.cc:214] - f_mul, executed 185 times, duration 0.046109298s, send bytes 2791424 recv bytes 2791424
[2024-07-18 17:18:03.365] [info] [api.cc:214] - mixed_mul, executed 37 times, duration 0.013188393s, send bytes 882368 recv bytes 882368
[2024-07-18 17:18:03.365] [info] [api.cc:214] - _mux, executed 4 times, duration 0.005991638s, send bytes 557056 recv bytes 557056
[2024-07-18 17:18:03.365] [info] [api.cc:214] - f_add, executed 283 times, duration 0.003328688s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - f_negate, executed 40 times, duration 0.002271631s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - f_mmul, executed 1 times, duration 0.00134875s, send bytes 414496 recv bytes 414496
[2024-07-18 17:18:03.365] [info] [api.cc:214] - seal, executed 2 times, duration 2.9418e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:211] MPC profiling: total time 0.4245964730000001
[2024-07-18 17:18:03.365] [info] [api.cc:214] - mmul_aa, executed 21 times, duration 0.12947814s, send bytes 47277568 recv bytes 47277568
[2024-07-18 17:18:03.365] [info] [api.cc:214] - msb_a2b, executed 41 times, duration 0.07886524s, send bytes 4741632 recv bytes 4741632
[2024-07-18 17:18:03.365] [info] [api.cc:214] - trunc_a, executed 326 times, duration 0.068162238s, send bytes 1693472 recv bytes 1693472
[2024-07-18 17:18:03.365] [info] [api.cc:214] - mul_aa, executed 236 times, duration 0.047058035s, send bytes 3550208 recv bytes 3550208
[2024-07-18 17:18:03.365] [info] [api.cc:214] - a2b, executed 20 times, duration 0.021607477s, send bytes 998400 recv bytes 998400
[2024-07-18 17:18:03.365] [info] [api.cc:214] - concatenate, executed 16 times, duration 0.018228241s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - b2a, executed 101 times, duration 0.018212783s, send bytes 147072 recv bytes 147072
[2024-07-18 17:18:03.365] [info] [api.cc:214] - and_bb, executed 120 times, duration 0.014937339s, send bytes 422400 recv bytes 422400
[2024-07-18 17:18:03.365] [info] [api.cc:214] - pad, executed 15 times, duration 0.005907683s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - add_aa, executed 335 times, duration 0.004702058s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - extract_slice, executed 4968 times, duration 0.004617872s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - not_a, executed 85 times, duration 0.003940147s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - reshape, executed 4936 times, duration 0.002173678s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - xor_bb, executed 580 times, duration 0.001375718s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - add_ap, executed 202 times, duration 0.00123122s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - bitrev_b, executed 40 times, duration 0.000840037s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - and_bp, executed 360 times, duration 0.000659155s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - rshift_b, executed 360 times, duration 0.000602192s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - mul_ap, executed 210 times, duration 0.000551214s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - transpose, executed 91 times, duration 0.000488421s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - make_p, executed 515 times, duration 0.000357092s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - lshift_b, executed 120 times, duration 0.000202838s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - broadcast, executed 103 times, duration 0.000112163s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - add_pp, executed 40 times, duration 8.4106e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - xor_bp, executed 20 times, duration 6.2188e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - mul_pp, executed 20 times, duration 6.1299e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - not_p, executed 20 times, duration 4.1154e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - p2a, executed 2 times, duration 2.6339e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:214] - reverse, executed 11 times, duration 1.0406e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:18:03.365] [info] [api.cc:224] Link details: total send bytes 58830752, recv bytes 58830752, send actions 1231
[2024-07-18 17:19:30.001] [info] [api.cc:165] [Profiling] SPU execution infer completed, input processing took 1.638e-05s, execution took 0.560124821s, output processing took 4.0183e-05s, total time 0.560181384s.
[2024-07-18 17:19:30.001] [info] [api.cc:211] HLO profiling: total time 0.5566890019999999
[2024-07-18 17:19:30.001] [info] [api.cc:214] - pphlo.convolution, executed 20 times, duration 0.288653808s, send bytes 67241472 recv bytes 67241472
[2024-07-18 17:19:30.001] [info] [api.cc:214] - pphlo.rsqrt, executed 20 times, duration 0.087015893s, send bytes 2174400 recv bytes 2174400
[2024-07-18 17:19:30.001] [info] [api.cc:214] - pphlo.less, executed 37 times, duration 0.061245766s, send bytes 2906624 recv bytes 2906624
[2024-07-18 17:19:30.001] [info] [api.cc:214] - pphlo.multiply, executed 222 times, duration 0.061008014s, send bytes 3673792 recv bytes 3673792
[2024-07-18 17:19:30.001] [info] [api.cc:214] - pphlo.reduce_window, executed 1 times, duration 0.039093234s, send bytes 2392064 recv bytes 2392064
[2024-07-18 17:19:30.001] [info] [api.cc:214] - pphlo.pad, executed 15 times, duration 0.00897712s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - pphlo.reduce, executed 30 times, duration 0.003606019s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - pphlo.negate, executed 40 times, duration 0.003026938s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - pphlo.add, executed 129 times, duration 0.001974644s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - pphlo.dot, executed 1 times, duration 0.001423771s, send bytes 414496 recv bytes 414496
[2024-07-18 17:19:30.001] [info] [api.cc:214] - pphlo.free, executed 612 times, duration 0.000163294s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - pphlo.broadcast, executed 45 times, duration 0.000151001s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - pphlo.transpose, executed 21 times, duration 0.000150943s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - pphlo.reshape, executed 29 times, duration 8.0588e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - pphlo.constant, executed 30 times, duration 5.3774e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - pphlo.convert, executed 2 times, duration 3.5771e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - pphlo.reverse, executed 11 times, duration 2.8424e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:211] HAL profiling: total time 0.5165712140000001
[2024-07-18 17:19:30.001] [info] [api.cc:214] - f_tensordot, executed 62 times, duration 0.273441432s, send bytes 67241472 recv bytes 67241472
[2024-07-18 17:19:30.001] [info] [api.cc:214] - f_rsqrt, executed 20 times, duration 0.086978077s, send bytes 2174400 recv bytes 2174400
[2024-07-18 17:19:30.001] [info] [api.cc:214] - f_less, executed 41 times, duration 0.08223788s, send bytes 4741632 recv bytes 4741632
[2024-07-18 17:19:30.001] [info] [api.cc:214] - f_mul, executed 185 times, duration 0.046790546s, send bytes 2791424 recv bytes 2791424
[2024-07-18 17:19:30.001] [info] [api.cc:214] - mixed_mul, executed 37 times, duration 0.013281779s, send bytes 882368 recv bytes 882368
[2024-07-18 17:19:30.001] [info] [api.cc:214] - _mux, executed 4 times, duration 0.006172611s, send bytes 557056 recv bytes 557056
[2024-07-18 17:19:30.001] [info] [api.cc:214] - f_add, executed 283 times, duration 0.003243557s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - f_negate, executed 40 times, duration 0.002978843s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - f_mmul, executed 1 times, duration 0.001418869s, send bytes 414496 recv bytes 414496
[2024-07-18 17:19:30.001] [info] [api.cc:214] - seal, executed 2 times, duration 2.762e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:211] MPC profiling: total time 0.542528678
[2024-07-18 17:19:30.001] [info] [api.cc:214] - trunc_a, executed 368 times, duration 0.158244564s, send bytes 18896672 recv bytes 18896672
[2024-07-18 17:19:30.001] [info] [api.cc:214] - mmul_aa, executed 21 times, duration 0.145741839s, send bytes 50046464 recv bytes 50046464
[2024-07-18 17:19:30.001] [info] [api.cc:214] - msb_a2b, executed 41 times, duration 0.079268476s, send bytes 4741632 recv bytes 4741632
[2024-07-18 17:19:30.001] [info] [api.cc:214] - mul_aa, executed 236 times, duration 0.048565551s, send bytes 3550208 recv bytes 3550208
[2024-07-18 17:19:30.001] [info] [api.cc:214] - a2b, executed 20 times, duration 0.021323257s, send bytes 998400 recv bytes 998400
[2024-07-18 17:19:30.001] [info] [api.cc:214] - b2a, executed 101 times, duration 0.017856445s, send bytes 147072 recv bytes 147072
[2024-07-18 17:19:30.001] [info] [api.cc:214] - concatenate, executed 16 times, duration 0.015862148s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - and_bb, executed 120 times, duration 0.014653324s, send bytes 422400 recv bytes 422400
[2024-07-18 17:19:30.001] [info] [api.cc:214] - pad, executed 15 times, duration 0.00893612s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - reshape, executed 4841 times, duration 0.007305119s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - not_a, executed 85 times, duration 0.004666616s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - add_aa, executed 335 times, duration 0.004651312s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - extract_slice, executed 4740 times, duration 0.004300958s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - mmul_ap, executed 42 times, duration 0.003888554s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - xor_bb, executed 580 times, duration 0.001343795s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - add_ap, executed 202 times, duration 0.001247496s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - transpose, executed 287 times, duration 0.001069815s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - bitrev_b, executed 40 times, duration 0.000835682s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - and_bp, executed 360 times, duration 0.00064363s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - rshift_b, executed 360 times, duration 0.000587136s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - mul_ap, executed 210 times, duration 0.000521721s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - make_p, executed 515 times, duration 0.000382466s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - lshift_b, executed 120 times, duration 0.000231394s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - broadcast, executed 103 times, duration 0.000116977s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - add_pp, executed 40 times, duration 8.2533e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - xor_bp, executed 20 times, duration 6.4375e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - mul_pp, executed 20 times, duration 6.0669e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - not_p, executed 20 times, duration 4.135e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - p2a, executed 2 times, duration 2.3309e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:214] - reverse, executed 11 times, duration 1.2047e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:19:30.001] [info] [api.cc:224] Link details: total send bytes 78802848, recv bytes 78802848, send actions 1273
It can be noted that the communication of mmul_aa and trunc_a increased, and the total communication also increased a lot, which contradicts the multiplication reduction expected from Winograd conv.
Additionally, I noted that the Winograd conv uses three hal::tensordot calls, two for the input/kernel transformations and one for the GEMM, whereas the regular conv uses hal::tensordot only once.
Remember that for a fixed model with known weights, the kernel transformation can be done offline. Therefore, I replaced the hal::tensordot in the kernel transformation with a constant all-ones matrix to simulate this situation. The profile is:
[2024-07-18 17:29:36.238] [info] [api.cc:165] [Profiling] SPU execution infer completed, input processing took 1.7796e-05s, execution took 0.438780767s, output processing took 3.7401e-05s, total time 0.438835964s.
[2024-07-18 17:29:36.238] [info] [api.cc:211] HLO profiling: total time 0.4354281060000001
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.convolution, executed 20 times, duration 0.172220784s, send bytes 43451904 recv bytes 43451904
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.rsqrt, executed 20 times, duration 0.089114364s, send bytes 2174400 recv bytes 2174400
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.less, executed 37 times, duration 0.060312438s, send bytes 2906624 recv bytes 2906624
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.multiply, executed 222 times, duration 0.058106095s, send bytes 3673792 recv bytes 3673792
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.reduce_window, executed 1 times, duration 0.038393581s, send bytes 2392064 recv bytes 2392064
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.pad, executed 15 times, duration 0.007641689s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.reduce, executed 30 times, duration 0.003542044s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.negate, executed 40 times, duration 0.002279123s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.add, executed 129 times, duration 0.001889133s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.dot, executed 1 times, duration 0.001350285s, send bytes 414496 recv bytes 414496
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.free, executed 612 times, duration 0.000147391s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.broadcast, executed 45 times, duration 0.000138275s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.transpose, executed 21 times, duration 0.000107958s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.reshape, executed 29 times, duration 7.3047e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.constant, executed 30 times, duration 5.1171e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.convert, executed 2 times, duration 3.5468e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - pphlo.reverse, executed 11 times, duration 2.526e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:211] HAL profiling: total time 0.39483937999999996
[2024-07-18 17:29:36.238] [info] [api.cc:214] - f_tensordot, executed 48 times, duration 0.154494692s, send bytes 43451904 recv bytes 43451904
[2024-07-18 17:29:36.238] [info] [api.cc:214] - f_rsqrt, executed 20 times, duration 0.089073318s, send bytes 2174400 recv bytes 2174400
[2024-07-18 17:29:36.238] [info] [api.cc:214] - f_less, executed 41 times, duration 0.081052562s, send bytes 4741632 recv bytes 4741632
[2024-07-18 17:29:36.238] [info] [api.cc:214] - f_mul, executed 185 times, duration 0.044366134s, send bytes 2791424 recv bytes 2791424
[2024-07-18 17:29:36.238] [info] [api.cc:214] - mixed_mul, executed 37 times, duration 0.012888425s, send bytes 882368 recv bytes 882368
[2024-07-18 17:29:36.238] [info] [api.cc:214] - _mux, executed 4 times, duration 0.006011205s, send bytes 557056 recv bytes 557056
[2024-07-18 17:29:36.238] [info] [api.cc:214] - f_add, executed 283 times, duration 0.003353158s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - f_negate, executed 40 times, duration 0.002226449s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:214] - f_mmul, executed 1 times, duration 0.001345725s, send bytes 414496 recv bytes 414496
[2024-07-18 17:29:36.238] [info] [api.cc:214] - seal, executed 2 times, duration 2.7712e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.238] [info] [api.cc:211] MPC profiling: total time 0.421531365
[2024-07-18 17:29:36.238] [info] [api.cc:214] - mmul_aa, executed 14 times, duration 0.107591987s, send bytes 40936960 recv bytes 40936960
[2024-07-18 17:29:36.238] [info] [api.cc:214] - trunc_a, executed 354 times, duration 0.083716544s, send bytes 4216608 recv bytes 4216608
[2024-07-18 17:29:36.239] [info] [api.cc:214] - msb_a2b, executed 41 times, duration 0.078090014s, send bytes 4741632 recv bytes 4741632
[2024-07-18 17:29:36.239] [info] [api.cc:214] - mul_aa, executed 236 times, duration 0.045529196s, send bytes 3550208 recv bytes 3550208
[2024-07-18 17:29:36.239] [info] [api.cc:214] - a2b, executed 20 times, duration 0.022105726s, send bytes 998400 recv bytes 998400
[2024-07-18 17:29:36.239] [info] [api.cc:214] - b2a, executed 101 times, duration 0.018533551s, send bytes 147072 recv bytes 147072
[2024-07-18 17:29:36.239] [info] [api.cc:214] - concatenate, executed 16 times, duration 0.018099175s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - and_bb, executed 120 times, duration 0.01441098s, send bytes 422400 recv bytes 422400
[2024-07-18 17:29:36.239] [info] [api.cc:214] - pad, executed 15 times, duration 0.007607729s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - add_aa, executed 335 times, duration 0.004777058s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - extract_slice, executed 4740 times, duration 0.004058945s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - not_a, executed 85 times, duration 0.003906662s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - reshape, executed 4799 times, duration 0.003217838s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - mmul_ap, executed 35 times, duration 0.002834911s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - xor_bb, executed 580 times, duration 0.001369924s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - add_ap, executed 202 times, duration 0.001247104s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - transpose, executed 252 times, duration 0.000901744s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - bitrev_b, executed 40 times, duration 0.000813418s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - and_bp, executed 360 times, duration 0.000654728s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - rshift_b, executed 360 times, duration 0.000631001s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - mul_ap, executed 210 times, duration 0.000514174s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - make_p, executed 515 times, duration 0.000340834s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - lshift_b, executed 120 times, duration 0.000188551s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - broadcast, executed 103 times, duration 0.000107943s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - add_pp, executed 40 times, duration 8.318e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - xor_bp, executed 20 times, duration 6.9109e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - mul_pp, executed 20 times, duration 5.5609e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - not_p, executed 20 times, duration 3.9903e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - p2a, executed 2 times, duration 2.3822e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:214] - reverse, executed 11 times, duration 1.0005e-05s, send bytes 0 recv bytes 0
[2024-07-18 17:29:36.239] [info] [api.cc:224] Link details: total send bytes 55013280, recv bytes 55013280, send actions 1252
The total comm is slightly decreased but still far from expectations. So confusing!
Would you mind taking a look at this? Are there any suggestions? BTW, I used 2pc.json with "protocol": "SEMI2K", "field": "FM64" in this test. Thanks! :)
Hi @warpoons
Please try to benchmark with experimental_disable_mmul_split set to True.
Looks like your implementation requires 7 tensordot calls where the original implementation requires only 1. Tensordot with non-integer inputs will introduce a truncation operation, which is expensive and needs communication.
One thing you can try is to reduce the number of truncations. For example, trunc(dot(x, y) + trunc(dot(y, z))) can be simplified into trunc(dot(x, y) + dot(y, z)).
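As a rough plain-NumPy sketch of why this helps (the FXP width and the toy vectors are made-up illustrations, not SPU internals): in fixed point every product of two scaled operands carries an extra 2^FXP factor, and each truncation, which in MPC is an interactive and communication-heavy step, removes one such factor; summing before truncating therefore halves the truncations in this example.

import numpy as np

FXP = 18                                    # hypothetical fixed-point fractional bits
enc = lambda x: np.round(x * 2**FXP).astype(np.int64)
dec = lambda x: x / 2**FXP

x, y, z = np.random.rand(4), np.random.rand(4), np.random.rand(4)
X, Y, Z = enc(x), enc(y), enc(z)

# Naive: truncate after every dot -> 2 truncations
naive = (X @ Y >> FXP) + (Y @ Z >> FXP)
# Fused: add the un-truncated products first, truncate once -> 1 truncation
fused = (X @ Y + Y @ Z) >> FXP

print(dec(naive), dec(fused), x @ y + y @ z)   # all three agree up to fixed-point rounding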
Hi @anakinxc. Thanks for your suggestion! I have benchmarked with "experimental_disable_mmul_split": true
in 2pc.json. And the profile of the Winograd conv is:
[2024-07-20 09:58:35.945] [info] [api.cc:165] [Profiling] SPU execution infer completed, input processing took 1.5933e-05s, execution took 0.580790941s, output processing took 3.9173e-05s, total time 0.580846047s.
[2024-07-20 09:58:35.946] [info] [api.cc:211] HLO profiling: total time 0.5768848610000001
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.convolution, executed 20 times, duration 0.302007006s, send bytes 67241472 recv bytes 67241472
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.rsqrt, executed 20 times, duration 0.091915749s, send bytes 2174400 recv bytes 2174400
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.less, executed 37 times, duration 0.064329796s, send bytes 2906624 recv bytes 2906624
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.multiply, executed 222 times, duration 0.062256372s, send bytes 3673792 recv bytes 3673792
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.reduce_window, executed 1 times, duration 0.039781986s, send bytes 2392064 recv bytes 2392064
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.pad, executed 15 times, duration 0.006391185s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.reduce, executed 30 times, duration 0.003279638s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.negate, executed 40 times, duration 0.002689619s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.add, executed 129 times, duration 0.002069318s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.dot, executed 1 times, duration 0.001498637s, send bytes 414496 recv bytes 414496
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.free, executed 612 times, duration 0.000175185s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.broadcast, executed 45 times, duration 0.000169835s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.transpose, executed 21 times, duration 0.000117636s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.reshape, executed 29 times, duration 7.9414e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.constant, executed 30 times, duration 5.7852e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.convert, executed 2 times, duration 4.0959e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pphlo.reverse, executed 11 times, duration 2.4674e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:211] HAL profiling: total time 0.536832459
[2024-07-20 09:58:35.946] [info] [api.cc:214] - f_tensordot, executed 62 times, duration 0.284671008s, send bytes 67241472 recv bytes 67241472
[2024-07-20 09:58:35.946] [info] [api.cc:214] - f_rsqrt, executed 20 times, duration 0.091869937s, send bytes 2174400 recv bytes 2174400
[2024-07-20 09:58:35.946] [info] [api.cc:214] - f_less, executed 41 times, duration 0.085656933s, send bytes 4741632 recv bytes 4741632
[2024-07-20 09:58:35.946] [info] [api.cc:214] - f_mul, executed 185 times, duration 0.046768409s, send bytes 2791424 recv bytes 2791424
[2024-07-20 09:58:35.946] [info] [api.cc:214] - mixed_mul, executed 37 times, duration 0.01449265s, send bytes 882368 recv bytes 882368
[2024-07-20 09:58:35.946] [info] [api.cc:214] - _mux, executed 4 times, duration 0.006170996s, send bytes 557056 recv bytes 557056
[2024-07-20 09:58:35.946] [info] [api.cc:214] - f_add, executed 283 times, duration 0.003050888s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - f_negate, executed 40 times, duration 0.002625307s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - f_mmul, executed 1 times, duration 0.001493948s, send bytes 414496 recv bytes 414496
[2024-07-20 09:58:35.946] [info] [api.cc:214] - seal, executed 2 times, duration 3.2383e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:211] MPC profiling: total time 0.5615995250000001
[2024-07-20 09:58:35.946] [info] [api.cc:214] - trunc_a, executed 368 times, duration 0.165186115s, send bytes 18896672 recv bytes 18896672
[2024-07-20 09:58:35.946] [info] [api.cc:214] - mmul_aa, executed 21 times, duration 0.152509485s, send bytes 50046464 recv bytes 50046464
[2024-07-20 09:58:35.946] [info] [api.cc:214] - msb_a2b, executed 41 times, duration 0.082573462s, send bytes 4741632 recv bytes 4741632
[2024-07-20 09:58:35.946] [info] [api.cc:214] - mul_aa, executed 236 times, duration 0.04785361s, send bytes 3550208 recv bytes 3550208
[2024-07-20 09:58:35.946] [info] [api.cc:214] - a2b, executed 20 times, duration 0.022759137s, send bytes 998400 recv bytes 998400
[2024-07-20 09:58:35.946] [info] [api.cc:214] - b2a, executed 101 times, duration 0.018657123s, send bytes 147072 recv bytes 147072
[2024-07-20 09:58:35.946] [info] [api.cc:214] - concatenate, executed 16 times, duration 0.017694911s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - and_bb, executed 120 times, duration 0.015000716s, send bytes 422400 recv bytes 422400
[2024-07-20 09:58:35.946] [info] [api.cc:214] - reshape, executed 4841 times, duration 0.007828969s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - pad, executed 15 times, duration 0.006348377s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - add_aa, executed 335 times, duration 0.004482565s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - extract_slice, executed 4740 times, duration 0.00437641s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - not_a, executed 85 times, duration 0.004254609s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - mmul_ap, executed 42 times, duration 0.004156978s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - xor_bb, executed 580 times, duration 0.001468896s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - add_ap, executed 202 times, duration 0.001353077s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - transpose, executed 287 times, duration 0.001180328s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - bitrev_b, executed 40 times, duration 0.000877519s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - and_bp, executed 360 times, duration 0.000723871s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - rshift_b, executed 360 times, duration 0.000653156s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - mul_ap, executed 210 times, duration 0.000577432s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - make_p, executed 515 times, duration 0.0004292s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - lshift_b, executed 120 times, duration 0.000221643s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - broadcast, executed 103 times, duration 0.000124417s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - add_pp, executed 40 times, duration 8.8349e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - xor_bp, executed 20 times, duration 6.7871e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - mul_pp, executed 20 times, duration 6.7668e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - not_p, executed 20 times, duration 4.5916e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - p2a, executed 2 times, duration 2.7477e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:214] - reverse, executed 11 times, duration 1.0238e-05s, send bytes 0 recv bytes 0
[2024-07-20 09:58:35.946] [info] [api.cc:224] Link details: total send bytes 78802848, recv bytes 78802848, send actions 1273
There seems to be no difference here.
And regarding the number of tensordot calls: there are indeed 7 tensordot calls in a Winograd conv. Specifically, they are:
For the 6 fixed pre-defined 2D transformation matrices:
G = { {1.0, 0.0, 0.0}, {0.5, 0.5, 0.5}, {0.5, -0.5, 0.5}, {0.0, 0.0, 1.0} };
G_T = { {1.0, 0.5, 0.5, 0.0}, {0.0, 0.5, -0.5, 0.0}, {0.0, 0.5, 0.5, 1.0} };
B = { {1, 0, 0, 0}, {0, 1, -1, 1}, {-1, 1, 1, 0}, {0, 0, 0, -1} };
B_T = { {1, 0, -1, 0}, {0, 1, 1, 0}, {0, -1, 1, 0}, {0, 1, 0, -1} };
A = { {1, 0}, {1, 1}, {1, -1}, {0, -1} };
A_T = { {1, 1, 1, 0}, {0, 1, -1, -1} };

// Transform 4D (h, w, inCh, OutCh) kernel to Winograd domain
auto U = hal::tensordot(ctx, G, hal::tensordot(ctx, kernel, G_T, {1}, {0}), {1}, {0});  // 2 tensordot here
// Transform 4D (N, H, W, inCh) input to Winograd domain
auto V = hal::tensordot(ctx, B_T, hal::tensordot(ctx, expanded, B, {2}, {0}), {1}, {1});  // 2 tensordot here
// Do general matmul
auto M = hal::tensordot(ctx, U, V, {1, 3}, {0, 2});  // 1 tensordot here
// Transform Winograd output back to regular output
auto output_tile = hal::tensordot(ctx, A_T, hal::tensordot(ctx, M, A, {2}, {0}), {1}, {1});  // 2 tensordot here
Therefore, there are in total 2+2+1+2 = 7 tensordot calls for one Winograd conv. This might be the reason for the high trunc and total comm.
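For intuition, here is a plain-NumPy check of one F(2x2, 3x3) tile for a single channel, mirroring the same 7-matmul chain; this is just the textbook algebra, not the SPU code above. Note that only M = U * V multiplies two secret tensors, while every other product is by a public constant matrix.

import numpy as np

G = np.array([[1, 0, 0], [0.5, 0.5, 0.5], [0.5, -0.5, 0.5], [0, 0, 1]])
B = np.array([[1, 0, 0, 0], [0, 1, -1, 1], [-1, 1, 1, 0], [0, 0, 0, -1]])
A = np.array([[1, 0], [1, 1], [1, -1], [0, -1]])

g = np.random.rand(3, 3)          # 3x3 kernel
d = np.random.rand(4, 4)          # 4x4 input tile

U = G @ g @ G.T                   # kernel transform      (2 matmuls)
V = B.T @ d @ B                   # input transform       (2 matmuls)
M = U * V                         # element-wise product  (the only secret * secret step)
Y = A.T @ M @ A                   # output transform      (2 matmuls)

# Direct 2x2 "valid" convolution (correlation) of d with g for comparison
ref = np.array([[np.sum(d[i:i+3, j:j+3] * g) for j in range(2)] for i in range(2)])
print(np.allclose(Y, ref))        # True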
I have also merged all the tensordot calls into one expression, but the profile didn't change.
Are there any feasible optimizations? Thanks! :)
Merging into one expression does not reduce the number of calls.
Trying to figure out how many truncations you actually need might be a better idea.
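One rough way to count them, under the simple fixed-point assumption that only products involving a fractional scale need rescaling (a sketch of the reasoning, not SPU's actual rules): the B/B_T and A/A_T transforms have integer entries, so they can stay at the original scale, while the G transform (which has 0.5 entries) and the secret-by-secret element-wise product U * V do need truncation. A minimal NumPy illustration for the B transform:

import numpy as np

FXP = 18
enc = lambda x: np.round(x * 2**FXP).astype(np.int64)

B = np.array([[1, 0, 0, 0], [0, 1, -1, 1], [-1, 1, 1, 0], [0, 0, 0, -1]])
d = np.random.rand(4, 4)
D = enc(d)

V = B.T @ D @ B                    # still at scale 2**FXP: no rescaling needed
print(np.allclose(V / 2**FXP, B.T @ d @ B, atol=1e-4))   # True, up to encoding error

Under that assumption, a Winograd conv could need far fewer than 7 truncations; for example, one could use 2*G (integer entries) and fold the resulting 1/4 factor into the final truncation after the element-wise product.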
Thank you @anakinxc. As you previously said, matmul is suitable for 2D dot while tensordot supports >2D dot. An additional question: does matmul also introduce 1 trunc, just like tensordot?
Thanks for the patient guidance and discussion. My problem has been solved. This issue can be closed as completed. Thanks to all!
Issue Type
Build/Install
Modules Involved
MPC protocol
Have you reproduced the bug with SPU HEAD?
Yes
Have you searched existing issues?
Yes
SPU Version
spu 0.9.0.dev20240311
OS Platform and Distribution
Ubuntu 18.04.6 LTS by WSL
Python Version
3.10
Compiler Version
GCC 11.3.0
Current Behavior?
Hi dear SPU team,
In recent weeks I have been working on using SPU to evaluate the private inference cost of Winograd-based ConvNets, which reduce the number of multiplications and should thus lead to a lighter comm cost.
BTW, using Winograd for more efficient PPML is gradually emerging and has been shown to be an effective way to improve comm efficiency, including:
But I found that the reduction of multiplications does NOT truly decrease the comm in SPU. In general, the respective comms of a single conv layer are: standard conv: 759296 bytes, Winograd-EWMM (element-wise matmul) conv: 6291456 bytes, Winograd-GEMM (general matmul) conv: 3964928 bytes. Apparently, the Winograd algorithm increases the comm by 8.28x and 5.22x, respectively. The increased comm mainly comes from the frequent invocation of truncations. For details please see here.
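For reference, the textbook arithmetic behind the expected saving (pure multiplication counts for one output tile; by itself this says nothing about MPC communication):

# Multiplication count for one output tile of F(2x2, 3x3)
m, r = 2, 3
direct = (m * m) * (r * r)        # 2x2 outputs, 9 mults each = 36
winograd = (m + r - 1) ** 2       # one 4x4 element-wise product = 16
print(direct, winograd, direct / winograd)   # 36 16 2.25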
So if I now want to modify the C++ backend op for Winograd in SPU, could you tell me where I should start and which piece I should look at in SPU? Thanks! (If possible, I am willing to contribute Winograd support to the SPU community.)
Standalone code to reproduce the issue
Relevant log output