taichi-dev / taichi

Productive, portable, and performant GPU programming in Python.
https://taichi-lang.org
Apache License 2.0

Advanced optimization #656

Closed xumingkuan closed 2 years ago

xumingkuan commented 4 years ago

Concisely describe the proposed feature

With the new extensions introduced by #581, there is a lot of room to optimize the IR. I also found some feasible optimizations that are not directly related to the new extensions. For example, in this fragment of IR,

...
<f32 x1> $5 = alloca
if $26 {
  ...
} else {
  ...
}
if $26 {
  ...
} else {
  ...
}
<f32 x1> $83 = local load [ [$5[0]]] (the only statement about $5)
...

we could merge the two if's together, change $83 to const [0], and then delete $5.
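A minimal Python sketch of the merge step, assuming statements live in a flat list and that a condition is an SSA value, so nothing inside the first if can change what the second if tests. The `If` class and `merge_adjacent_ifs` are hypothetical names for illustration, not Taichi's API.

```python
from dataclasses import dataclass, field


@dataclass
class If:
    cond: str                                   # name of the condition value, e.g. "$26"
    true_body: list = field(default_factory=list)
    false_body: list = field(default_factory=list)


def merge_adjacent_ifs(block):
    """Merge consecutive If statements that test the same condition value."""
    merged = []
    for stmt in block:
        prev = merged[-1] if merged else None
        if isinstance(stmt, If) and isinstance(prev, If) and prev.cond == stmt.cond:
            # Same condition: append the bodies to the previous if.
            prev.true_body.extend(stmt.true_body)
            prev.false_body.extend(stmt.false_body)
        elif isinstance(stmt, If):
            # Copy so the input block is left untouched.
            merged.append(If(stmt.cond, list(stmt.true_body), list(stmt.false_body)))
        else:
            merged.append(stmt)
    return merged
```

Any statement between the two if's stops the merge in this sketch, which matches the "adjacent" restriction.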

A list of optimizations I have done and am going to do:

Additional comments

For benchmarking, we may want to introduce a temporary boolean variable as a switch for the optimizations.

Some nice slides: https://courses.cs.washington.edu/courses/cse401/08wi/lecture/opt-mark.v2.pdf

xumingkuan commented 4 years ago

@yuanming-hu please assign me. It seems that I can't assign myself...

yuanming-hu commented 4 years ago

Awesome!! This is vitally important for improving run-time performance & reducing compilation time. Thanks for taking charge of this.

archibate commented 4 years ago

Merge adjacent if's with identical condition

What if these if's contain statements with side effects, like x = x + 1? E.g.

if (cond) x++; if (cond) x++;

We want to obtain:

if (cond) { x++; x++; }

and the duplicated x++ can be dealt with in other lower passes.

Merge identical local loads if no statements between them modify the variable even if there are if's

What if the two local loads are in different blocks? E.g.

if (cond) { print 'yes'; x = local load 233; } else { print 'no'; x = local load 233; }

What if a statement appears once in the IR but runs multiple times? Should we optimize it? E.g.

while (cond) { x = local load 233 ... (no changes stored to 233) }

We may move this out of the while.

First, add an analysis pass to detect whether a block stores to an address.

xumingkuan commented 4 years ago

Merge adjacent if's with identical condition

What if these if's contain statements with side effects, like x = x + 1? E.g.

if (cond) x++; if (cond) x++;

We want to obtain:

if (cond) { x++; x++; }

and the duplicated x++ can be dealt with in other lower passes.

Exactly.

Merge identical local loads if no statements between them modify the variable even if there are if's

What if the two local loads are in different blocks? E.g.

if (cond) { print 'yes'; x = local load 233; } else { print 'no'; x = local load 233; }

This is non-trivial. We could find the code fragments common to the true branch and the false branch and move them outside the if, but I don't know whether it would make much difference.

What if a statement appears once in the IR but runs multiple times? Should we optimize it? E.g.

while (cond) { x = local load 233 ... (no changes stored to 233) }

We may move this out of the while.

If cond is false, does moving it out have side effects?

First, add an analysis pass to detect whether a block stores to an address.

To merge identical local loads when no statements between them modify the variable, this is not necessary: I think directly searching for modifications when we find a local load fits the code framework better. Maybe we can add this pass later if necessary.

archibate commented 4 years ago

If cond is false, does moving it out have side effects?

No, it's just a load that is never used; it will be optimized out by other lower passes.

archibate commented 4 years ago

How about first transforming:

if (cond) { print 'yes'; x = local load 233; } else { print 'no'; x = local load 233; }

into:

if (cond) print 'yes'; else print 'no'; if (cond) xxx; else xxx;

This works since cond is a constant IR value, and the second if can then be safely optimized out.

xumingkuan commented 4 years ago

How about first transforming:

if (cond) { print 'yes'; x = local load 233; } else { print 'no'; x = local load 233; }

into:

if (cond) print 'yes'; else print 'no'; if (cond) xxx; else xxx;

This works since cond is a constant IR value, and the second if can then be safely optimized out.

I just thought about a situation:

if (cond) {
  print 'yes';
  x = local load 233;
  print 'yes';
} else {
  print 'no';
  x = local load 233;
  print 'no';
}

I can't tell if the following is more efficient than the above:

if (cond) print 'yes'; else print 'no';
x = local load 233;
if (cond) print 'yes'; else print 'no';

(especially when the common code fragment is relatively short compared with the rest)

We can restrict this optimization to only the first statement and the last statement of the body of if.
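The restricted version (first and last statement only) could be sketched as follows. This is an illustrative Python model in which statements are plain strings compared by equality; `hoist_common_ends` is a hypothetical helper, and a real pass would need a structural comparison of IR nodes instead.

```python
def hoist_common_ends(true_body, false_body):
    """Pull an identical first and/or last statement out of both branches.

    Returns (before, new_true_body, new_false_body, after), where `before`
    goes in front of the if and `after` goes behind it.
    """
    before, after = [], []
    t, f = list(true_body), list(false_body)
    # Identical first statement: both branches run it first, so it can
    # run unconditionally before the if.
    if t and f and t[0] == f[0]:
        before.append(t.pop(0))
        f.pop(0)
    # Identical last statement: both branches end with it, so it can
    # run unconditionally after the if.
    if t and f and t[-1] == f[-1]:
        after.append(t.pop())
        f.pop()
    return before, t, f, after
```

Only one statement is taken from each end, so a shared statement in the middle of the branches (as in the print/load/print example) is deliberately left alone.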

xumingkuan commented 4 years ago

@yuanming-hu What do https://github.com/taichi-dev/taichi/blob/aa90e319be3b599085495f88b660f4e987a08134/taichi/ir/ir.h#L1637 mean?

May I just ignore them when merging two adjacent if's?

yuanming-hu commented 4 years ago

Quick answer for now: yes. I'll document this in greater detail later. You don't have to worry about that until we start doing vectorization.

xumingkuan commented 4 years ago

I just found a piece of IR:

<i32 x1> $8 = const [0]
...
if $19 {
  ...
  <i32 x1> $25 = const [0]
  ...
} else {
  ...
  <i32 x1> $40 = const [0]
  ...
}

I think we could optimize them all to $8. Currently void visit(ConstStmt*) only searches statements before the current statement within the same basic block, so $25 cannot find $8: they are not in the same basic block.

There are two ways to do this optimization:

  1. Search statements after the current statement (say $8) instead, and dive into container statements to replace them with $8.
  2. Search statements before the current statement (say $25), and do this recursively for parent blocks.

Which do you think is better?

yuanming-hu commented 4 years ago

I think 2 is better. At compile time it's hard to judge whether $25 or $40 will be after $8, but it's sure that $8 is before $25 and $40.
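Option 2 can be sketched as a backward search that climbs to the parent block when the current block has no match. This is a toy Python model; `Const`, `Block`, `index_in_parent`, and `find_earlier_const` are hypothetical names, not Taichi's classes.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Const:
    name: str
    value: int


@dataclass
class Block:
    stmts: list = field(default_factory=list)
    parent: Optional["Block"] = None
    index_in_parent: int = 0   # position of the containing statement in parent


def find_earlier_const(block, index, value):
    """Find a Const with `value` that is guaranteed to precede stmts[index]."""
    while block is not None:
        # Only statements strictly before the current position dominate it.
        for stmt in reversed(block.stmts[:index]):
            if isinstance(stmt, Const) and stmt.value == value:
                return stmt
        # Nothing here: continue before the container statement in the parent.
        block, index = block.parent, block.index_in_parent
    return None
```

Because the search only ever looks at statements before the current one (or before its enclosing container), every hit is known to execute first, which is the property pointed out in the reply above.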

xumingkuan commented 4 years ago

Should this pass (identical ConstStmt elimination) still be in BasicBlockSimplify? It won't operate within one basic block, so maybe I should implement it in Simplify?

yuanming-hu commented 4 years ago

Let's add a WholeKernelCSE (common subexpression elimination) pass then.

xumingkuan commented 4 years ago

For checking if the first statements (which can be container statements) in both branches of if are exactly the same, shall we add a function like bool same_statements(IRNode *root1, IRNode *root2) in ir.h and implement it using visitors in taichi/analysis/?

yuanming-hu commented 4 years ago

Very good question. I need to think about this a little bit. One very important IR functionality is to test if two IRNodes are equivalent. IRNode can be not only one statement but also a hierarchy. We might need to use some hashing here.

yuanming-hu commented 4 years ago

A few things to think about here

xumingkuan commented 4 years ago

There are 3 kinds of solutions I thought about. Denote the number of statements in the container IRNode we want to test by n (if it's not a container, then n = 1).

  1. Do nothing more when modifying statements. Then it takes O(n) time to find that two IRNode's are the same, and O(n) time in the worst case to find that two IRNode's are different. I think in most cases, we can find that two IRNode's are different in O(1).
  2. Spend O(depth) more time when modifying statements, where "depth" means the number of container statements directly or indirectly containing the modified statement. We can update the Binary DNA and its hash in O(1) for each container statement. (Note that even if we only set a boolean variable to tell whether the container statement is modified, it still takes O(1) for each container statement!) So we can find that two IRNode's are different in O(1) in expectation, but we still need O(n) time to find that two IRNode's are the same, because the length of a Binary DNA is Ω(n).
  3. Spend O(depth * log(n)) more time when modifying statements. Then we can find that two IRNode's are the same in O(log(n)) with some fancy data structures.

I prefer the 1st solution. I think it is unacceptable to spend O(depth) more time whenever modifying statements just to avoid the worst-case O(n) time to find that two IRNode's are different: we modify statements far more often than we check whether two IRNode's are equivalent.

If there is a stage that statements don't change anymore, we can build data structures for comparing IRNode's then.

yuanming-hu commented 4 years ago

Thanks for the detailed analysis. I agree with your decision and we should probably go with the 1st solution.

Meanwhile, a very easy-to-implement (and slightly hacky) way to test if two statements are equivalent:

This should work for most cases (assuming the print_ir pass is doing a correct job) and can probably be implemented within 20 LoC.

xumingkuan commented 4 years ago

Thanks for the hacky way, but I want to implement a reject-fast solution. I think most of the queries will be of different IRNode's.

xumingkuan commented 4 years ago

Maybe I can implement a visitor to visit one of the IRNode's, while storing the corresponding IRNode in the visitor class?

yuanming-hu commented 4 years ago

Sounds good. I champion your decision :-)

Maybe I can implement a visitor to visit one of the IRNode's, while storing the corresponding IRNode in the visitor class?

Right, you have to use one IRNode to guide the other.
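A rough Python sketch of the reject-fast idea, where one node guides the walk over the other and the comparison stops at the first mismatch. `Node` and `same_statements` here are hypothetical stand-ins; a real implementation would also have to map SSA ids defined inside one hierarchy to their counterparts in the other.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)   # eq=False: we write the comparison by hand below
class Node:
    kind: str                      # e.g. "if", "const", "local load"
    operands: tuple = ()           # operand names or literal values
    children: list = field(default_factory=list)   # nested statements


def same_statements(a, b):
    """Walk `a` and compare it against the corresponding part of `b`."""
    # Cheap local checks first, so different nodes are rejected in O(1).
    if a.kind != b.kind or a.operands != b.operands:
        return False
    if len(a.children) != len(b.children):
        return False
    # all() stops at the first pair that differs.
    return all(same_statements(x, y) for x, y in zip(a.children, b.children))
```

Equal hierarchies still cost O(n), as in solution 1, but most unequal pairs should fail on the first kind or operand check.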

xumingkuan commented 4 years ago

I wonder if this IR is valid:

<f32 x1> $238 = alloca
<f32 x1> $197 = alloca
<f32 x1> $239 : local store [$238 <- $197]
<f32 x1> $199 = ...
<f32 x1> $200 : local store [$197 <- $199]
<f32 x1> $242 = local load [ [$238[0]]]
<f32 x1> $218 = local load [ [$242[0]]]

It causes simplify.cpp to crash because the alloca here https://github.com/taichi-dev/taichi/blob/24e76a14e3ebfc4a8ee7cc2b36d44030a75e226a/taichi/transforms/simplify.cpp#L479 is not an AllocaStmt when we are visiting $218.

yuanming-hu commented 4 years ago

Good question. LocalLoad must take Allocas as inputs. $218 is invalid.
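The rule can be checked mechanically. The following toy Python checker models each statement as a `(name, op, operand)` tuple, which is an assumption made purely for illustration, and flags local loads whose input is not an alloca.

```python
def find_invalid_local_loads(stmts):
    """Return names of local loads whose operand is not an alloca.

    stmts: iterable of (name, op, operand) tuples; operand is None for alloca.
    """
    allocas = set()
    invalid = []
    for name, op, operand in stmts:
        if op == "alloca":
            allocas.add(name)
        elif op == "local load" and operand not in allocas:
            invalid.append(name)
    return invalid


# The fragment from the question, reduced to the statements that matter.
fragment = [
    ("$238", "alloca", None),
    ("$197", "alloca", None),
    ("$242", "local load", "$238"),
    ("$218", "local load", "$242"),   # loads from a load result, not an alloca
]
```

On this fragment the checker reports only $218, in line with the answer above.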

xumingkuan commented 4 years ago

So shall we add TI_ASSERT(...->is<AllocaStmt>()); to LocalAddress::var and LocalStoreStmt::ptr in their constructors?

xumingkuan commented 4 years ago

Oh no, it's causing assertion failure even in the initial IR.

yuanming-hu commented 4 years ago

So shall we add TI_ASSERT(...->is<AllocaStmt>()); to LocalAddress::var and LocalStoreStmt::ptr in their constructors?

Good idea.

Oh no, it's causing assertion failure even in the initial IR.

Could you share with me more details?

xumingkuan commented 4 years ago

Test case: test_ad_if.py test_ad_if_mutable

Part of the change set:

  LocalAddress(Stmt *var, int offset) : var(var), offset(offset) {
    std::cout << "local address" << std::endl;
    TI_ASSERT(var->is<AllocaStmt>());
  }
...
  void flatten(VecStatement &ret) override {
    std::cout << "from flatten" << std::endl;
    ret.push_back(std::make_unique<LocalLoadStmt>(
        LocalAddress(current_block->lookup_var(id), 0)));
    stmt = ret.back().get();
  }

Output:

Before preprocessing:
@ti.kernel
def func(i: ti.i32):
    t = x[i]
    if t > 0:
        y[i] = t
    else:
        y[i] = 2 * t

After preprocessing:
def func():
  i = ti.decl_scalar_arg(ti.i32)
  t = ti.expr_init(ti.subscript(x, i))
  if 1:
    __cond = ti.chain_compare([t, 0], ['Gt'])
    ti.core.begin_frontend_if(ti.Expr(__cond).ptr)
    ti.core.begin_frontend_if_true()
    ti.subscript(y, i).assign(t)
    ti.core.pop_scope()
    ti.core.begin_frontend_if_false()
    ti.subscript(y, i).assign(2 * t)
    ti.core.pop_scope()

[I 04/06/20 18:22:47.127] [compile_to_offloads.cpp:taichi::lang::irpass::compile_to_offloads::<lambda_a9f5d9347feda29776c658d0949d74f7>::operator ()@17] Initial IR:
==========
kernel {
  $0 = alloca @tmp4
  @tmp4 = gbl load #@tmp0[arg[0]]
  $2 = alloca @tmp5
  @tmp5 = @tmp4
  $4 = alloca @tmp6
  @tmp6 = 0
  $6 = alloca @tmp7
  @tmp7 = 1
  if (@tmp7 & (@tmp5 > @tmp6)) {
    #@tmp2[arg[0]] = @tmp4
  } else {
    #@tmp2[arg[0]] = (@tmp4 * 2)
  }
}
==========
from flatten
local address
local address
[E 04/06/20 18:22:47.129] [taichi/ir/ir.h:taichi::lang::LocalAddress::LocalAddress@1687] var->is<AllocaStmt>()

I'm still trying to find where the second "local address" comes from. Compiling ir.h takes minutes.

xumingkuan commented 4 years ago

Maybe I should do the assertion only when var != nullptr?

yuanming-hu commented 4 years ago

Maybe I should do the assertion only when var != nullptr?

I assume LocalAddress'es must not have null pointers, but it would be good to be defensive.

Actually, there's a piece of Windows debugging infrastructure we can add here: could you help integrate this piece of code into taichi/system/traceback.cpp? It will give you a stack of function calls. Currently on Windows you only get an error message instead of a call stack when things crash, which makes debugging hard. Feel free to open an issue/draft PR to track this.

#include <intrin.h>
#include <dbghelp.h>
#include <cstdarg>  // va_list, va_start, va_end (used by dbg::trace)
#include <cstdio>
#include <vector>
#include <string>
#include <sstream>

#include "taichi/platform/windows/windows.h"

#pragma comment(lib, "dbghelp.lib")

//  https://gist.github.com/rioki/85ca8295d51a5e0b7c56e5005b0ba8b4
//
//  Debug Helpers
//
// Copyright (c) 2015 - 2017 Sean Farrell <sean.farrell@rioki.org>
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
//

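// DBG_TRACE is used below but not defined in this excerpt; forwarding it to
// dbg::trace is an assumption.
#define DBG_TRACE(...) ::dbg::trace(__VA_ARGS__)

namespace dbg {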
inline void trace(const char *msg, ...) {
  char buff[1024];

  va_list args;
  va_start(args, msg);
  vsnprintf(buff, 1024, msg, args);

  OutputDebugStringA(buff);

  va_end(args);
}

inline std::string basename(const std::string &file) {
  // Use the string's own size type so the comparison with npos is exact.
  std::string::size_type i = file.find_last_of("\\/");
  if (i == std::string::npos) {
    return file;
  } else {
    return file.substr(i + 1);
  }
}

struct StackFrame {
  DWORD64 address;
  std::string name;
  std::string module;
  unsigned int line;
  std::string file;
};

inline std::vector<StackFrame> stack_trace() {
#if _WIN64
  DWORD machine = IMAGE_FILE_MACHINE_AMD64;
#else
  DWORD machine = IMAGE_FILE_MACHINE_I386;
#endif
  HANDLE process = GetCurrentProcess();
  HANDLE thread = GetCurrentThread();

  if (SymInitialize(process, NULL, TRUE) == FALSE) {
    DBG_TRACE(__FUNCTION__ ": Failed to call SymInitialize.");
    return std::vector<StackFrame>();
  }

  SymSetOptions(SYMOPT_LOAD_LINES);

  CONTEXT context = {};
  context.ContextFlags = CONTEXT_FULL;
  RtlCaptureContext(&context);

#if _WIN64
  STACKFRAME frame = {};
  frame.AddrPC.Offset = context.Rip;
  frame.AddrPC.Mode = AddrModeFlat;
  frame.AddrFrame.Offset = context.Rbp;
  frame.AddrFrame.Mode = AddrModeFlat;
  frame.AddrStack.Offset = context.Rsp;
  frame.AddrStack.Mode = AddrModeFlat;
#else
  STACKFRAME frame = {};
  frame.AddrPC.Offset = context.Eip;
  frame.AddrPC.Mode = AddrModeFlat;
  frame.AddrFrame.Offset = context.Ebp;
  frame.AddrFrame.Mode = AddrModeFlat;
  frame.AddrStack.Offset = context.Esp;
  frame.AddrStack.Mode = AddrModeFlat;
#endif

  bool first = true;

  std::vector<StackFrame> frames;
  while (StackWalk(machine, process, thread, &frame, &context, NULL,
                   SymFunctionTableAccess, SymGetModuleBase, NULL)) {
    StackFrame f = {};
    f.address = frame.AddrPC.Offset;

#if _WIN64
    DWORD64 moduleBase = 0;
#else
    DWORD moduleBase = 0;
#endif

    moduleBase = SymGetModuleBase(process, frame.AddrPC.Offset);

    char moduleBuff[MAX_PATH];
    if (moduleBase &&
        GetModuleFileNameA((HINSTANCE)moduleBase, moduleBuff, MAX_PATH)) {
      f.module = basename(moduleBuff);
    } else {
      f.module = "Unknown Module";
    }
#if _WIN64
    DWORD64 offset = 0;
#else
    DWORD offset = 0;
#endif
    char symbolBuffer[sizeof(IMAGEHLP_SYMBOL) + 255];
    PIMAGEHLP_SYMBOL symbol = (PIMAGEHLP_SYMBOL)symbolBuffer;
    symbol->SizeOfStruct = sizeof(IMAGEHLP_SYMBOL) + 255;
    symbol->MaxNameLength = 254;

    if (SymGetSymFromAddr(process, frame.AddrPC.Offset, &offset, symbol)) {
      f.name = symbol->Name;
    } else {
      DWORD error = GetLastError();
      DBG_TRACE(__FUNCTION__ ": Failed to resolve address 0x%X: %u\n",
                frame.AddrPC.Offset, error);
      f.name = "Unknown Function";
    }

    IMAGEHLP_LINE line;
    line.SizeOfStruct = sizeof(IMAGEHLP_LINE);

    DWORD offset_ln = 0;
    if (SymGetLineFromAddr(process, frame.AddrPC.Offset, &offset_ln, &line)) {
      f.file = line.FileName;
      f.line = line.LineNumber;
    } else {
      DWORD error = GetLastError();
      DBG_TRACE(__FUNCTION__ ": Failed to resolve line for 0x%X: %u\n",
                frame.AddrPC.Offset, error);
      f.line = 0;
    }

    if (!first) {
      frames.push_back(f);
    }
    first = false;
  }

  SymCleanup(process);

  return frames;
}

inline void handle_assert(const char *func, const char *cond) {
  std::stringstream buff;
  buff << func << ": Assertion '" << cond << "' failed! \n";
  buff << "\n";

  std::vector<StackFrame> stack = stack_trace();
  buff << "Callstack: \n";
  for (unsigned int i = 0; i < stack.size(); i++) {
    buff << "0x" << std::hex << stack[i].address << ": " << stack[i].name << "("
         << std::dec << stack[i].line << ") in " << stack[i].module << "\n";
  }

  // please replace with std::printf
  MessageBoxA(NULL, buff.str().c_str(), "Assert Failed", MB_OK | MB_ICONSTOP);
  abort();
}

}  // namespace dbg

xumingkuan commented 4 years ago

I just checked that var is nullptr in the second local address.

xumingkuan commented 4 years ago

I just located that

<f32 x1> $242 = local load [ [$238[0]]]
<f32 x1> $218 = local load [ [$242[0]]]

is introduced in make_adjoint... Debugging.

yuanming-hu commented 4 years ago

Sounds like here? https://github.com/taichi-dev/taichi/blob/d28533c503a1e4fd1101b549d63c6fb540618be2/taichi/transforms/make_adjoint.cpp#L516

xumingkuan commented 4 years ago

This is in BackupSSA and I printed it and found that all auto allocas are indeed allocas, at least at that place -- otherwise it should trigger assertion failure.

I suspect the problem is in MakeAdjoint. Please check #726 when you are available (the output is so long that I opened a new issue for it).

yuanming-hu commented 4 years ago

Sounds good. I'm occupied until 11:59 PM but I'll take a look after that time.

xumingkuan commented 4 years ago

Currently, the following $47 cannot be eliminated:

<i32 x1> $2 = alloca
if $22 {
  <i32 x1> $47 : local store [$2 <- $46]
}
(nothing related to $2)

This is because $47 doesn't know that $2 will never be loaded.

There are 5 cases like this in test_ad_if_mutable, so we can reduce the number of statements by at least 10 (eliminating local store & alloca).

Describe the solution you'd like (if any)

I want to implement a pass that analyzes allocas (for each alloca, do store forwarding and useless local store elimination), but I don't know if I should implement it in a new pass or in an existing pass.

(global tmp vars may be similar, but the Stmts are different so they can't be implemented together)

I find BasicBlockSimplify's functionality quite limited: I want to upgrade its 3 main functions (common subexpression elimination, store forwarding, useless local store elimination).
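A minimal Python sketch of the "useless local store elimination" part, assuming a toy IR of tuples: `("alloca", name)`, `("store", alloca, value)`, `("load", alloca)` and `("if", cond, true_body, false_body)`. The functions are hypothetical and deliberately conservative: one load anywhere keeps every store to that alloca.

```python
def loaded_allocas(block):
    """Collect every alloca that is loaded somewhere in the block tree."""
    found = set()
    for stmt in block:
        if stmt[0] == "load":
            found.add(stmt[1])
        elif stmt[0] == "if":
            found |= loaded_allocas(stmt[2]) | loaded_allocas(stmt[3])
    return found


def remove_dead_stores(block, live=None):
    """Drop allocas that are never loaded, together with their stores."""
    if live is None:
        live = loaded_allocas(block)
    out = []
    for stmt in block:
        if stmt[0] in ("alloca", "store") and stmt[1] not in live:
            continue   # never read back, so neither statement is needed
        if stmt[0] == "if":
            stmt = ("if", stmt[1],
                    remove_dead_stores(stmt[2], live),
                    remove_dead_stores(stmt[3], live))
        out.append(stmt)
    return out
```

Store forwarding would need ordering information that this sketch ignores; it only covers the case where the alloca is never loaded at all.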

xumingkuan commented 4 years ago

LocalLoadSearcher, LocalStoreSearcher, LocalStoreForwarder may be necessary for the upgraded common subexpression elimination/store forwarding/useless local store elimination passes. Shall we move them to analysis/?

yuanming-hu commented 4 years ago

Currently, the following $47 cannot be eliminated:

<i32 x1> $2 = alloca
if $22 {
  <i32 x1> $47 : local store [$2 <- $46]
}
(nothing related to $2)

This is because $47 doesn't know that $2 will never be loaded.

There are 5 cases like this in test_ad_if_mutable, so we can reduce the number of statements by at least 10 (eliminating local store & alloca).

Describe the solution you'd like (if any)

I want to implement a pass that analyzes allocas (for each alloca, do store forwarding and useless local store elimination), but I don't know if I should implement it in a new pass or in an existing pass.

(global tmp vars may be similar, but the Stmts are different so they can't be implemented together)

I find BasicBlockSimplify's functionality quite limited: I want to upgrade its 3 main functions (common subexpression elimination, store forwarding, useless local store elimination).

Thanks for spotting this. A new pass sounds better since no existing pass does this. Also I think the logic of this pass would be complex enough to justify the existence of itself.

LocalLoadSearcher, LocalStoreSearcher, LocalStoreForwarder may be necessary for the upgraded common subexpression elimination/store forwarding/useless local store elimination passes. Shall we move them to analysis/?

Sounds good!!

xumingkuan commented 4 years ago

I want to make use of AlgSimp::alg_is_one to eliminate $6 in this case:

<i32 x1> $5 = const [1]
$6 : while control nullptr, $5

(We can eliminate it even if mask is not nullptr, right?)

But should it be in the alg_simp pass?

archibate commented 4 years ago

alg_is_one

maybe alg_is_non_zero_constant :)

alg_simp pass?

I think this is about control flow, not the algebraic expression level, so maybe it's not really related?
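For illustration, here is a toy Python version of the elimination under discussion. It assumes that `while control` leaves the loop when its condition value is zero, so a condition that is a known non-zero constant makes the statement a no-op; the tuple encoding and the function name are hypothetical.

```python
def remove_trivial_while_control(stmts):
    """Drop `while control` statements whose condition is a non-zero constant.

    stmts: list of ("const", name, value) or ("while_control", cond_name)
    tuples; any other tuple is passed through unchanged.
    """
    consts = {}
    out = []
    for stmt in stmts:
        if stmt[0] == "const":
            consts[stmt[1]] = stmt[2]
        if stmt[0] == "while_control" and consts.get(stmt[1], 0) != 0:
            continue   # condition can never be false: never breaks
        out.append(stmt)
    return out
```

The constant itself is left in place here; if nothing else uses it, a separate dead-code pass would be expected to remove it.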

xumingkuan commented 4 years ago

maybe alg_is_non_zero_constant :)

Yes... It may be clearer if the type is u1. BTW what's the behavior of if 0.1 or while control ..., 0.1?

xumingkuan commented 4 years ago

I found

<i32 x1> $10 = const [1]
<i32 x1> $11 = cmp_gt $6 $9
<i32 x1> $12 = bit_and $10 $11

in some IRs, but it's hard to optimize if there are neither boolean types (u1) nor logical operations (logic_and).

Maybe another way to optimize it is to change this from expr_init(True) to expr_init(-1)... https://github.com/taichi-dev/taichi/blob/532ea3340e8c8201c97c768110be907038df7a17/python/taichi/lang/impl.py#L110
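The reason -1 is easier to optimize than 1 is a plain integer identity: `x & -1 == x` for every two's-complement integer, whereas `x & 1 == x` only when x is 0 or 1. A small Python sketch, where `simplify_bit_and` is a hypothetical helper and no claim is made about how Taichi encodes comparison results:

```python
def simplify_bit_and(const_value, operand):
    """Return `operand` if `bit_and const_value operand` is a no-op, else None.

    Without a boolean type the operand may hold any integer, so only the
    all-ones mask (-1) can be dropped unconditionally.
    """
    if const_value == -1:
        return operand
    return None


# -1 keeps every bit, for any value of x.
assert all((x & -1) == x for x in range(-8, 9))
# 1 keeps only the lowest bit, so it changes values other than 0 and 1.
assert (5 & 1) != 5
```

With a dedicated u1 type the `bit_and 1, x` form could be simplified as well, which is presumably why the reply below defers the fix.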

yuanming-hu commented 4 years ago

My feeling is that we should systematically fix this after we have u1 introduced...

xumingkuan commented 4 years ago

[Figure: benchmark20200422]

The geometric mean of the optimization factor on the number of statements among all tests is 1.068 now.
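For readers unfamiliar with the metric, a geometric mean over per-test factors can be computed as below. This is a generic sketch; it assumes each factor is a positive ratio of statement counts, and it is not the script that produced the number above.

```python
import math


def geometric_mean(factors):
    """Geometric mean of positive ratios, computed in log space."""
    if not factors:
        raise ValueError("need at least one factor")
    return math.exp(sum(math.log(f) for f in factors) / len(factors))
```

The geometric mean is the usual choice for ratios because a 2x improvement on one test and a 2x regression on another cancel out to 1, which an arithmetic mean would not show.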

yuanming-hu commented 4 years ago

Cool! I assume bigger means more optimized in the table. I'm curious which tests give you a factor < 0.75, and which give > 1.5?

(PS: it's almost always good to use xlabel and ylabel and title to make your plots easier to understand.)

xumingkuan commented 4 years ago

Tests with > 1.5 boost:

test_ad_if__test_ad_if : 1.5348837209302326
test_ad_if__test_ad_if_mutable : 2.0485436893203883
test_ad_if__test_ad_if_parallel : 1.9245283018867925
test_ad_if__test_ad_if_parallel_complex : 1.625
test_continue__test_kernel_continue : 1.5844155844155845

(test_ad_if__test_ad_if_mutable should have been optimized from 105 statements to 26 statements, but there are other kernels causing ~100 statements in total that can hardly be optimized.)

Tests that become much worse (< 0.75):

test_tensor_dimensionality__test_dimensionality : 0.7463768115942029
test_tensor_reflection__test_POT : 0.7272727272727273

yuanming-hu commented 4 years ago

Thanks for the report. The bad news is that we have overfit to the test_ad_if series; the good news is that there is still a lot of room to improve here...

xumingkuan commented 4 years ago

I just found that test_tensor_reflection__test_POT has no kernels in it. Figuring out what's wrong...

xumingkuan commented 4 years ago

For test_tensor_dimensionality, it is indeed optimized. Before:

kernel {
  $0 = offloaded range_for(0, 256) block_dim=adaptive {
    <i32 x1> $1 = const [0]
    <i32 x1> $2 = loop index 0
    <i32 x1> $3 = bit_extract($2 + 0, 7~8)
    <i32 x1> $4 = const [1]
    <i32 x1> $5 = mul $3 $4
    <i32 x1> $6 = add $1 $5
    <i32 x1> $7 = bit_extract($2 + 0, 6~7)
    <i32 x1> $8 = mul $7 $4
    <i32 x1> $9 = add $1 $8
    <i32 x1> $10 = bit_extract($2 + 0, 5~6)
    <i32 x1> $11 = mul $10 $4
    <i32 x1> $12 = add $1 $11
    <i32 x1> $13 = bit_extract($2 + 0, 4~5)
    <i32 x1> $14 = mul $13 $4
    <i32 x1> $15 = add $1 $14
    <i32 x1> $16 = bit_extract($2 + 0, 3~4)
    <i32 x1> $17 = mul $16 $4
    <i32 x1> $18 = add $1 $17
    <i32 x1> $19 = bit_extract($2 + 0, 2~3)
    <i32 x1> $20 = mul $19 $4
    <i32 x1> $21 = add $1 $20
    <i32 x1> $22 = bit_extract($2 + 0, 1~2)
    <i32 x1> $23 = mul $22 $4
    <i32 x1> $24 = add $1 $23
    <i32 x1> $25 = bit_extract($2 + 0, 0~1)
    <i32 x1> $26 = mul $25 $4
    <i32 x1> $27 = add $1 $26
    <i32 x1> $28 = add $6 $9
    <i32 x1> $29 = add $28 $12
    <i32 x1> $30 = add $29 $15
    <i32 x1> $31 = add $30 $18
    <i32 x1> $32 = add $31 $21
    <i32 x1> $33 = add $32 $24
    <i32 x1> $34 = add $33 $27
    <gen*x1> $35 = get root
    <i32 x1> $36 = linearized(ind {}, stride {})
    <gen*x1> $37 = [S0root][root]::lookup($35, $36) activate = false
    <gen*x1> $38 = get child [S0root->S1dense] $37
    <i32 x1> $39 = bit_extract($6 + 0, 0~1)
    <i32 x1> $40 = bit_extract($9 + 0, 0~1)
    <i32 x1> $41 = bit_extract($12 + 0, 0~1)
    <i32 x1> $42 = bit_extract($15 + 0, 0~1)
    <i32 x1> $43 = bit_extract($18 + 0, 0~1)
    <i32 x1> $44 = bit_extract($21 + 0, 0~1)
    <i32 x1> $45 = bit_extract($24 + 0, 0~1)
    <i32 x1> $46 = bit_extract($27 + 0, 0~1)
    <i32 x1> $47 = linearized(ind {$39, $40, $41, $42, $43, $44, $45, $46}, stride {2, 2, 2, 2, 2, 2, 2, 2})
    <gen*x1> $48 = [S1dense][dense]::lookup($38, $47) activate = false
    <i32*x1> $49 = get child [S1dense->S2place_i32] $48
    <i32 x1> $50 = atomic add($49, $34)
    <i32*x1> $51 = get child [S1dense->S3place_i32] $48
    <i32 x1> $52 = atomic add($51, $6)
  }
}

after:

kernel {
  $0 = offloaded range_for(0, 256) block_dim=adaptive {
    <i32 x1> $1 = loop index 0
    <i32 x1> $2 = bit_extract($1 + 0, 7~8)
    <i32 x1> $3 = bit_extract($1 + 0, 6~7)
    <i32 x1> $4 = bit_extract($1 + 0, 5~6)
    <i32 x1> $5 = bit_extract($1 + 0, 4~5)
    <i32 x1> $6 = bit_extract($1 + 0, 3~4)
    <i32 x1> $7 = bit_extract($1 + 0, 2~3)
    <i32 x1> $8 = bit_extract($1 + 0, 1~2)
    <i32 x1> $9 = bit_extract($1 + 0, 0~1)
    <i32 x1> $10 = add $2 $3
    <i32 x1> $11 = add $10 $4
    <i32 x1> $12 = add $11 $5
    <i32 x1> $13 = add $12 $6
    <i32 x1> $14 = add $13 $7
    <i32 x1> $15 = add $14 $8
    <i32 x1> $16 = add $15 $9
    <gen*x1> $17 = get root
    <i32 x1> $18 = const [0]
    <gen*x1> $19 = [S0root][root]::lookup($17, $18) activate = false
    <gen*x1> $20 = get child [S0root->S1dense] $19
    <i32 x1> $21 = bit_extract($2 + 0, 0~1)
    <i32 x1> $22 = bit_extract($3 + 0, 0~1)
    <i32 x1> $23 = bit_extract($4 + 0, 0~1)
    <i32 x1> $24 = bit_extract($5 + 0, 0~1)
    <i32 x1> $25 = bit_extract($6 + 0, 0~1)
    <i32 x1> $26 = bit_extract($7 + 0, 0~1)
    <i32 x1> $27 = bit_extract($8 + 0, 0~1)
    <i32 x1> $28 = bit_extract($9 + 0, 0~1)
    <i32 x1> $29 = const [2]
    <i32 x1> $30 = mul $27 $29
    <i32 x1> $31 = add $28 $30
    <i32 x1> $32 = const [4]
    <i32 x1> $33 = mul $26 $32
    <i32 x1> $34 = add $31 $33
    <i32 x1> $35 = const [8]
    <i32 x1> $36 = mul $25 $35
    <i32 x1> $37 = add $34 $36
    <i32 x1> $38 = const [16]
    <i32 x1> $39 = mul $24 $38
    <i32 x1> $40 = add $37 $39
    <i32 x1> $41 = const [32]
    <i32 x1> $42 = mul $23 $41
    <i32 x1> $43 = add $40 $42
    <i32 x1> $44 = const [64]
    <i32 x1> $45 = mul $22 $44
    <i32 x1> $46 = add $43 $45
    <i32 x1> $47 = const [128]
    <i32 x1> $48 = mul $21 $47
    <i32 x1> $49 = add $46 $48
    <gen*x1> $50 = [S1dense][dense]::lookup($20, $49) activate = false
    <i32*x1> $51 = get child [S1dense->S2place_i32] $50
    <i32 x1> $52 = atomic add($51, $16)
    <i32*x1> $53 = get child [S1dense->S3place_i32] $50
    <i32 x1> $54 = atomic add($53, $2)
  }
}

It's just that lowering linearize produces many statements.
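To see where the extra statements come from: in the "after" IR, the single linearized statement has become $49 = $28 + 2*$27 + 4*$26 + ... + 128*$21. The Python sketch below computes the same value; `linearize` is an illustrative function, not the Taichi pass.

```python
def linearize(indices, strides):
    """Combine per-axis indices into one linear index.

    The first index is the most significant one: each step multiplies the
    running value by the axis stride and adds the next index.
    """
    linear = 0
    for index, stride in zip(indices, strides):
        linear = linear * stride + index
    return linear
```

With eight axes of stride 2, the first index ends up with weight 128 and the last with weight 1, matching the constants 2 through 128 in the lowered IR. Written out as separate const, mul and add statements, that is 21 statements ($29 through $49) in place of one.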

xumingkuan commented 4 years ago

Well... the $21-$28 here is just the same as $2-$9, isn't it?

yuanming-hu commented 4 years ago

It's just that lowering linearize produces many statements.

I see :-) People sometimes use a cost model to assign, say, linearize a higher weight.
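A cost model of that kind can be as simple as a weighted statement count. In the Python sketch below, the weight table is entirely hypothetical; the value 21 for linearized is chosen only because the lowered form above expands to 21 statements.

```python
def weighted_cost(stmt_kinds, weights, default=1):
    """Sum per-statement weights; kinds without an entry cost `default`."""
    return sum(weights.get(kind, default) for kind in stmt_kinds)


# Hypothetical weights: one linearized counts as much as its lowered form.
EXAMPLE_WEIGHTS = {"linearized": 21}
```

Under such a model, the benchmark would compare weighted costs before and after optimization rather than raw statement counts, so lowering a heavy statement no longer looks like a regression.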

Well... the $21-$28 here is just the same as $2-$9, isn't it?

Right, we can add a special optimization for a bitextract that takes as input another bitextract.
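The special case can be stated as a general rule for nested extractions. The Python sketch below assumes `lo~hi` in the IR denotes the half-open bit range [lo, hi) and ignores the `+ 0` offset; `bit_extract` and `fuse_bit_extract` are illustrative helpers, not Taichi functions.

```python
def bit_extract(x, lo, hi):
    """Bits [lo, hi) of a non-negative integer x, shifted down to bit 0."""
    return (x >> lo) & ((1 << (hi - lo)) - 1)


def fuse_bit_extract(a, b, c, d):
    """Fuse bit_extract(bit_extract(x, a, b), c, d) into a single range.

    Returns (lo, hi) such that the nested form equals bit_extract(x, lo, hi),
    or None when the result is always 0.
    """
    width = b - a            # the inner result has only `width` meaningful bits
    if c >= width:
        return None          # the outer range starts above all of them
    return a + c, a + min(d, width)
```

For the case in the thread, `fuse_bit_extract(7, 8, 0, 1)` gives `(7, 8)`: extracting bit 0 from a one-bit result is the inner extraction itself, so $21 through $28 can be replaced by $2 through $9.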