tlc-pack / tvm-tensorir

Apache License 2.0
8 stars · 0 forks

[TASK] TextPrinter Support for Unified TVM IRs #48

Closed tqchen closed 4 years ago

tqchen commented 4 years ago

So far we have a text printer for Relay, which allows us to print an IRModule into text format. On the TIR side, we still rely on the ReprPrinter.

This issue is for upgrading the text printer so that we can print an IRModule that includes PrimFuncs (tir::Function in the upstream) in text format. This will help us enhance the demo experience.

Possible Design Points

Ideally we want to land a version in the mainline in about two to three weeks. @spectrometerHBH please see if it is possible for you and @Hzfengsy to coordinate on a format and land a version; then we can pull it back to the TensorIR branch.

spectrometerHBH commented 4 years ago

So ideally we want text format & hybrid script to be two different printing methods for TIR which are both parseable?

tqchen commented 4 years ago

yes. but as a path for progression, we can

spectrometerHBH commented 4 years ago

What is hybrid bidirection? Bidirectional translation between text format & hybrid?

tqchen commented 4 years ago

Sorry, I meant hybrid parsing and printing.

Hzfengsy commented 4 years ago

Generally, it is OK to use the current ReprPrinter as long as we make some minor modifications. However, there are some details to be determined.

Function Syntax

  1. As we know, a Relay function has arg and return types.
    fn(%x : Tensor[(10, 10), float32], %y : Tensor[(10, 10), float32])
               -> Tensor[(10, 10), float32] {
        add(%x, %y)
    }

    Of course, TIR can do the same thing to show the parameter types. The problem is that a TIR function has neither a return type nor a return value; it directly changes buffers by mutating their elements. Here are two options for this:

    • Just print it without a return value. Pros: more natural for low-level code. Cons: the syntax is less unified, and the vars are no longer immutable in the "unified" syntax.
    • Change the behavior of TIR functions so that they have a return value and type. Pros: unified, and keeps the vars immutable. Cons: it is strange for low-level code to return a buffer, and a TIR function may have more than one output buffer or even mutate an input buffer.

I think neither is a perfect solution, but I have no other ideas. I would love to hear your opinions.
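
To make the trade-off concrete, here is a rough mock (hypothetical syntax, not a settled proposal) of how the two options could render the same in-place computation; option 2 uses a Relay-style arrow with an empty return type:

```
// Option 1: no return value, mutation only
fn(%X: handle, %Y: handle) { %X[0] = %Y[0] }

// Option 2: explicit (empty) return type for syntactic unification
fn(%X: handle, %Y: handle) -> () { %X[0] = %Y[0] }
```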

Annotation

I'd like to add a kind of annotation on functions to show whether each is a Relay or TIR function. One possible syntax would be

@relay
fn(%x, %y) { add(%x, %y) }
@tir
fn(%x, %y) { X[0] = Y[0] }

Function Body

There is no doubt that Relay functions and TIR functions have different stmts and exprs (e.g. let bindings in Relay and BufferLoad in TIR). So I think we can accept the differences in the function bodies but try to make the function signature itself unified.

@spectrometerHBH @tqchen

tqchen commented 4 years ago

FYI, I am in the process of refactoring out the tensor fields (func) in Realize and Provide and replacing them with BufferRealize, BufferStore, and BufferLoad, so we don't need to consider supporting those nodes in the text format.

tqchen commented 4 years ago

Related change in the upstream https://github.com/apache/incubator-tvm/pull/5372

spectrometerHBH commented 4 years ago

A comparison between the repr format and the text format (which is problematic now):

[23:41:44] /home/spectrometer/tvm-upstream/src/te/schedule/bound.cc:128: not in feed graph consumer = compute(B, 0x1fb5cf0)
PrimFunc([A0, A1, C]) attrs={"tir.noalias": (bool)1, "global_symbol": "main"} {
  // attr [B.v0] storage_scope = "global"
  allocate B.v0[float32 * n]
  // attr [B.v1] storage_scope = "global"
  allocate B.v1[float32 * n]
  for (i, 0, m) {
    for (j, 0, n) {
      B.v0[j] = (A0[((i*stride) + (j*stride))] + 2f)
      B.v1[j] = (A0[((i*stride) + (j*stride))]*3f)
    }
    for (j, 0, n) {
      C[((i*stride) + (j*stride))] = (A1[((i*stride) + (j*stride))] + B.v0[j])
    }
  }
}

[23:41:44] /home/spectrometer/tvm-upstream/src/te/schedule/bound.cc:128: not in feed graph consumer = compute(B, 0x1fb5cf0)
class Module:
  PrimFunc(%A0: Buffer([%m, %n], "float32", name=%A01),
           %A1: Buffer([%m, %n], "float32", name=%A11),
           %C: Buffer([%m, %n], "float32", name=%C1)) {
    attr [%B.v0] "storage_scope" = "global" {
      allocate(%B.v0, "float32", [%n]) if "bool"(1) {
        attr [%B.v1] "storage_scope" = "global" {
          allocate(%B.v1, "float32", [%n]) if "bool"(1) {
            for (%i, 0, %m) "serial" {
              for (%j, 0, %n) "serial" {
                %B.v0[%j] = (%A0[((%i*%stride) + (%j*%stride1))] + "float32"(2)) if "bool"(1)
                %B.v1[%j] = (%A0[((%i*%stride) + (%j*%stride1))]*"float32"(3)) if "bool"(1)
              }
              for (%j1, 0, %n) "serial" {
                %C[((%i*%stride2) + (%j1*%stride3))] = (%A1[((%i*%stride4) + (%j1*%stride5))] + %B.v0[%j1]) if "bool"(1)
              }
            }
          }
        }
      }
    }
  }

Several Problems

1. The name of Vars

For readability, the name_hint of a var is useful. But B.v0 may not be a legal identifier for the parser, since it contains a dot, and the user may give an arbitrary name.
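
A printer has to solve exactly this: legalize name_hints like B.v0 into identifiers the parser accepts, while keeping distinct Vars distinct. A minimal Python sketch (illustrative only, not TVM's actual implementation; the class name NameSupply is made up here):

```python
import re

class NameSupply:
    """Toy name allocator: legalizes a name_hint and keeps results unique."""

    def __init__(self):
        self._used = set()

    def fresh(self, name_hint):
        # Replace characters that are illegal in an identifier (e.g. '.') with '_'.
        name = re.sub(r"[^A-Za-z0-9_]", "_", name_hint)
        # Deduplicate by appending _1, _2, ... as needed.
        candidate, i = name, 0
        while candidate in self._used:
            i += 1
            candidate = f"{name}_{i}"
        self._used.add(candidate)
        return candidate
```

With this, B.v0 prints as B_v0, and a second Var hinted B.v0 would print as B_v0_1, in the same spirit as the B_v0/B_v1-style names in the examples below.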

2. dtype of Vars

I'm not entirely clear on where Var is used in the upstream AST. As far as I can see now, it is used as:

  1. vars in buffer shape declarations
  2. loop vars
  3. stride vars (stride above, though I'm not sure what they are)
  4. buffer vars
  5. iter vars (though IterVar seems to only appear in attr; otherwise we use the var inside the iter_var)

When printing a var, the simplest way is to just give its name_hint, but in principle we have to record its dtype as well. In most cases the dtype doesn't matter, though; I'm not sure about this point.

tqchen commented 4 years ago

Some additional thoughts after seeing the example.

A strawman (I am not too happy with this syntax):

attr ("storage_scope", %Bv0) = "global"

spectrometerHBH commented 4 years ago

For vars like %m and %n, should we print their Type/dtype? Their declarations are outside the AST.

%A0: Buffer([%m, %n], "float32", name=%A0_1),
%A1: Buffer([%m, %n], "float32", name=%A1_1),
%C: Buffer([%m, %n], "float32", name=%C_1)

One way is to print the Type the first time we encounter the Var:

%A0: Buffer([%m : int32, %n : int32], "float32", name=%A0_1),
%A1: Buffer([%m, %n], "float32", name=%A1_1),
%C: Buffer([%m, %n], "float32", name=%C_1)

Another way is to declare their defs somewhere else first.

A third way is to put them in META.

I think all of these ways are somewhat strange.

And for vars like stride

%A0[((%i*%stride) + (%j*%stride1))

They look more like placeholders for unknown buffer shapes. Does their Type/dtype matter?
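
The "print the Type at the first encounter" rule above is easy to state as a tiny stateful printer. A hedged Python sketch (the helper and its name are hypothetical, only meant to illustrate the rule):

```python
class FirstEncounterPrinter:
    """Annotate a var with its type only at its first occurrence."""

    def __init__(self):
        self._seen = set()

    def print_var(self, name, dtype):
        if name in self._seen:
            return f"%{name}"          # later uses: bare name
        self._seen.add(name)
        return f"%{name}: {dtype}"     # first use: carries the type
```

So a shape list prints as [%m: int32, %n: int32] the first time and as plain [%m, %n] afterwards, as in the example above.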

tqchen commented 4 years ago

I think the first-encounter approach, or declaring the defs somewhere else, makes sense. Formally, we should not put the buffer_map in the argument list (that would make Buffer a type, but Buffers are not types); we might want to have some section like C++'s initializer list (the corresponding hybrid script part is the buffer bind).

spectrometerHBH commented 4 years ago

It seems that Relay's text printer doesn't support printing PrimType and PointerType. Will this be fixed upstream?

tqchen commented 4 years ago

I think it will be up to us to fix them :)

spectrometerHBH commented 4 years ago

Discussion on printing Buffer/Var, and META usage & recovery

An inevitable problem in the text format is how to print the buffer_map (Map<tir::Var, Buffer>). I will expand my discussion based on this problem. cc @tqchen @Hzfengsy

PrimFunc(%A0: handle, %A1: handle, %C: handle)
  %C: Buffer(%C_1, [%m, %n], float32, [%stride, %stride_1], %C_2)
  %A0: Buffer(%A0_1, [%m, %n], float32, [%stride_2, %stride_3], %A0_2)
  %A1: Buffer(%A1_1, [%m, %n], float32, [%stride_4, %stride_5], %A1_2) {
  attr [%B_v0] "storage_scope" = "global" {
    allocate(%B_v0, float32, [%n])  {
      attr [%B_v1] "storage_scope" = "global" {
        allocate(%B_v1, float32, [%n])  {
          for (%i, 0, %m)  {
            for (%j, 0, %n)  {
              %B_v0[%j] = (%A0_1[((%i*%stride_2) + (%j*%stride_3))] + float32(2))
              %B_v1[%j] = (%A0_1[((%i*%stride_2) + (%j*%stride_3))]*float32(3))
            }
            for (%j_1, 0, %n)  {
              %C_1[((%i*%stride) + (%j_1*%stride_1))] = (%A1_1[((%i*%stride_4) + (%j_1*%stride_5))] + %B_v0[%j_1])
            }
          }
        }
      }
    }
  }
}  

C0. Print all Buffer information out

Pros:

If we avoid using META in printing, then the text format is fully writable and modifiable by the user, as long as the user follows the syntax rules.

Cons:

It's really long to print all Buffer info, which hurts readability.

C1. Put Buffer in META

Pros:

Clearer for buffer_map printing

Cons:

  1. The user cannot write the text format freely, since it relies on META.
  2. Vars like %m, %n, and %stride are put into META along with the Buffer, so for their usages below we have to print meta[Var][0], which makes the output hard to read.

I further propose several ways to deal with the cons of C1.

C1.1 Scan the Vars in Buffer in advance and store a name map for them into META

When printing, we give each Var we encounter a unique name_hint. To be able to print stride instead of meta[Var][0], we can store a Map<std::string, Var> into META for all the vars that appear in Buffer declarations. When parsing, we look up the name map when we encounter stride to retrieve the right Var from META. Cons: hurts the writability of the text format, since buffer declarations lie in META.

C1.2 Scan all the Vars in AST and store a name map for them

Consider this scenario in the general case: a Var is declared first, but one of its usages is going to be put in META along with some other non-printable node containing it. To print correctly, we have to collect all vars that are going to be put in META; C1.1 is a special case that handles the Vars in Buffers. If in the future we want to put more nodes into META, we have to keep this problem in mind.

But we can do this more aggressively to eliminate the problem entirely: just scan all the Vars in the AST and store a name map for them in META. Then for every Var occurrence, we can simply print a name_hint no matter how the Var is used.

Pros: simple and works generally; note that it also works if we don't put the Buffers into META. Cons: hurts writability the most, since all the Var declarations lie in META.
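
The pre-pass that C1.2 describes, collecting every Var in the AST into a Map<string, Var> before printing, can be sketched on a toy AST (tuples standing in for real IR nodes; this is illustrative, not TVM code):

```python
def collect_vars(node, acc=None):
    """Scan a toy expression tree and return {printed_name: var_node},
    i.e. the name map that would be stored into META."""
    if acc is None:
        acc = {}
    if node[0] == "var":
        # Record the var under its name; real code would also
        # deduplicate clashing name_hints here.
        acc.setdefault(node[1], node)
    else:
        for child in node[1:]:
            collect_vars(child, acc)
    return acc
```

After this pass, every occurrence can be printed as a bare name, and the parser can recover the Var object by looking the name up in the META-stored map.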

spectrometerHBH commented 4 years ago

I will land a version that avoids using META then.

spectrometerHBH commented 4 years ago

Does GetType always return PrimType or PointerType? Is there an API to convert a Type to a dtype?

tqchen commented 4 years ago

Type to dtype: https://github.com/apache/incubator-tvm/blob/master/include/tvm/tir/op.h#L70 (note that it is always safe to construct a Var from a Type, because Type is more fine-grained). It is fine to print a Var without a Type (only with a dtype) but load it back with a Type.

We cannot always deduce the Type from the dtype, though.

spectrometerHBH commented 4 years ago

BTW, if identifiers don't begin with %, the lexer will confuse them with the CNAME used in meta. So I will keep % for now.

tqchen commented 4 years ago

I see, we will need to at least print mod as truncmod instead of % then.
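
One way to realize this is for the expression printer to keep a small table of operators whose symbols collide with the lexer and emit those in call form instead of infix form. A hedged Python sketch (the names are made up for illustration):

```python
# Operators whose infix symbol is safe to print directly.
INFIX = {"+": " + ", "-": " - ", "*": " * "}
# Operators whose symbol clashes with the lexer ('%' starts identifiers),
# printed as named calls instead.
CALL_FORM = {"%": "truncmod", "//": "floordiv"}

def print_binop(op, lhs, rhs):
    if op in CALL_FORM:
        return f"{CALL_FORM[op]}({lhs}, {rhs})"
    return f"({lhs}{INFIX[op]}{rhs})"
```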

spectrometerHBH commented 4 years ago

I've implemented a printer & parser for my current syntax. I'll attach the g4 file below for discussion in our next meeting. Note that it doesn't support META in Attr yet, and Attr (and DictAttr) is the only place that results in nodes being put in META.

Several Points:

  1. load needs a dtype as a parameter of the constructor, and I haven't decided the syntax.
  2. I've tested several small test cases, but I have no clear idea of how to test it thoroughly. The way I can see now is to test it using ir_builder; maybe I can also try the operators provided in topi.
  3. buffers stores the complete info of all the buffers that appear in the IR, but we could have more ideas on how to print it.
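
One cheap way to test a printer & parser pair thoroughly is a round-trip property: parse(print(x)) should reproduce x structurally. A toy sketch of the idea for a fully parenthesized expression language (illustrative only; a real test would compare TIR ASTs produced via ir_builder or topi):

```python
def print_expr(node):
    """Print a toy AST: ('var', name) or (op, lhs, rhs)."""
    if node[0] == "var":
        return "%" + node[1]
    op, lhs, rhs = node
    return f"({print_expr(lhs)} {op} {print_expr(rhs)})"

def tokenize(text):
    return text.replace("(", " ( ").replace(")", " ) ").split()

def parse_expr(tokens):
    tok = tokens.pop(0)
    if tok == "(":
        lhs = parse_expr(tokens)
        op = tokens.pop(0)
        rhs = parse_expr(tokens)
        assert tokens.pop(0) == ")"
        return (op, lhs, rhs)
    return ("var", tok[1:])  # strip the leading '%'

def roundtrip_ok(node):
    return parse_expr(tokenize(print_expr(node))) == node
```

Any construct the printer emits but the parser mishandles fails this property immediately.
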

spectrometerHBH commented 4 years ago
PrimFunc(%A0_1: handle, %A1_1: handle, %C_1: handle) -> ()
  buffers={%A1: Buffer(%A1_2: handle, float32, [%m: int32, %n: int32], [%stride: int32, %stride_1: int32], 0, "global", 128, 1, "auto"),
           %C: Buffer(%C_2: handle, float32, [%m, %n], [%stride_2: int32, %stride_3: int32], 0, "global", 128, 1, "auto"),
           %A0: Buffer(%A0_2: handle, float32, [%m, %n], [%stride_4: int32, %stride_5: int32], 0, "global", 128, 1, "auto")}
  buffer_map={%C_1: %C, %A0_1: %A0, %A1_1: %A1} {
  attr [%B_v0: handle] "storage_scope" = "global" {
    allocate(%B_v0, float32, [%n])  {
      attr [%B_v1: handle] "storage_scope" = "global" {
        allocate(%B_v1, float32, [%n])  {
          for (%i: int32, 0, %m)  {
            for (%j: int32, 0, %n)  {
              %B_v0[%j] = (%A0_2[((%i*%stride_4) + (%j*%stride_5))] + float32(2))
              %B_v1[%j] = (%A0_2[((%i*%stride_4) + (%j*%stride_5))]*float32(3))
            }
            for (%j_1: int32, 0, %n)  {
              %C_2[((%i*%stride_2) + (%j_1*%stride_3))] = (%A1_2[((%i*%stride) + (%j_1*%stride_1))] + %B_v0[%j_1])
            }
          }
        }
      }
    }
  }
}

grammar tir;

module
    : function* METADATA?
    ;

function
    : PRIMFUNC '(' varNode (',' varNode)* ')' '->' typeExpr
      'buffers' '=' '{' buffer_list?  '}'
      'buffer_map' '=' '{' buffer_map_list? '}'
       body
    ;

buffer_list
    : IDENTIFIER ':' bufferNode (',' IDENTIFIER ':' bufferNode)*
    ;

buffer_map_list
    : IDENTIFIER ':' IDENTIFIER (',' IDENTIFIER ':' IDENTIFIER)*
    ;

body
    : '{' stmtNode* '}'
    ;

node
    : primExprNode
    | stmtNode
    | rangeNode
    | commReducerNode
    | arrayNode
    | itervarNode
    | bufferNode
    | meta
    ;

arrayNode
    : '[' node (',' node)* ']'
    ;

stmtNode
    : LET varNode '=' primExprNode body # letStmt
    | ATTR '[' (varNode|itervarNode) ']' STRINGIMM '=' primExprNode body # attrStmt
    | ASSERT '(' primExprNode ',' primExprNode ')' body # assertStmt
    | varNode arrayNode '=' primExprNode (IF primExprNode)? # storeStmt
    | ALLOCATE '(' varNode ',' DTYPE ',' arrayNode ')' (IF primExprNode)? body # allocateStmt
    | FREE '(' varNode ')' # freeStmt
    | REALIZE '(' IDENTIFIER ',' arrayNode ')' (IF primExprNode)? body # bufferRealizeStmt
    | IF primExprNode body (ELSE body)? # ifThenElseStmt
    | EVALUATE '(' primExprNode ')' # evaluateStmt
    | FOR '(' varNode ',' primExprNode ',' primExprNode ')' STRINGIMM? body # forStmt
    | PREFETCH '(' IDENTIFIER ',' arrayNode ')' # prefetchStmt
    ;

primExprNode
    : CAST '(' DTYPE ',' primExprNode ')' # castExpr
    | FLOORDIV '(' primExprNode ',' primExprNode ')' # floordivExpr
    | FLOORMOD '(' primExprNode ',' primExprNode ')' # floormodExpr
    | MIN '(' primExprNode ',' primExprNode ')' # minExpr
    | MAX '(' primExprNode ',' primExprNode ')' # maxExpr
    | SELECT '(' primExprNode ',' primExprNode ',' primExprNode ')' # selectExpr
    | varNode arrayNode (IF primExprNode)? # loadExpr
    | RAMP '(' primExprNode ',' primExprNode ',' INT ')' # rampExpr
    | BROADCAST '(' primExprNode ',' INT ')' # broadcastExpr
    | LET varNode '=' primExprNode IN primExprNode # letExpr
    | CALL '(' DTYPE ',' STRINGIMM ',' arrayNode ',' STRINGIMM ',' INT ')' # callExpr
    | SHUFFLE '(' arrayNode ',' arrayNode ')' # shuffleExpr
    | REDUCE '(' commReducerNode ',' arrayNode ',' arrayNode ',' INT ')' (IF primExprNode)? # reduceExpr
    | <assoc=right> op='!' primExprNode # notExpr
    | src1 = primExprNode op=('*' | '/' | '%') src2 = primExprNode # binExpr
    | src1 = primExprNode op=('+' | '-') src2 = primExprNode # binExpr
    | src1 = primExprNode op=('<' | '>') src2 = primExprNode # binExpr
    | src1 = primExprNode op=('<=' | '>=') src2 = primExprNode # binExpr
    | src1 = primExprNode op=('==' | '!=') src2 = primExprNode # binExpr
    | src1 = primExprNode op='&&' src2 = primExprNode # binExpr
    | src1 = primExprNode op='||' src2 = primExprNode # binExpr
    | varNode # varExpr
    | meta # metaExpr
    | immediate # immExpr
    | '(' primExprNode ')' # parenExpr
    ;

varNode
    : IDENTIFIER (':' typeExpr)?
    ;

commReducerNode
    : COMMREDUCER '(' arrayNode ',' arrayNode ',' arrayNode ',' arrayNode ')'
    ;

bufferNode
    : BUFFER '(' varNode ',' DTYPE ',' arrayNode ',' arrayNode ',' primExprNode ','
      STRINGIMM ',' INT ',' INT ',' STRINGIMM ')'
    ;

rangeNode
    : primExprNode ':' primExprNode
    ;

itervarNode
    : ITERVAR '(' varNode ',' '[' rangeNode ']' ',' STRINGIMM ',' STRINGIMM ')'
    ;

immediate
    : STRINGIMM # stringImm
    | DTYPE '(' INT ')' # intImm
    | DTYPE '(' FLOATIMM ')' # floatImm
    | INT # int32Imm
    | TRUELITERAL # trueLiteral
    | FALSELITERAL # falseLiteral
    ;

meta
    : 'meta' '[' CNAME ']' '[' INT ']'
    ;

// --- Type
typeExpr
  : '(' ')'                                                                # tupleType
  | DTYPE                                                                  # primType
  | 'Pointer' '(' typeExpr ')'                                             # pointerType
  ;

METADATA: 'METADATA:' .*;

// --- Reserved words

MUL : '*' ;
DIV : '/' ;
ADD : '+' ;
SUB : '-' ;
MOD : '%' ;
LT : '<' ;
GT : '>' ;
LE : '<=' ;
GE : '>=' ;
EQ : '==' ;
NE : '!=' ;
AND : '&&' ;
OR : '||' ;

IN : 'in';
PRIMFUNC : 'PrimFunc';
IF : 'if';
ELSE : 'else';
TRUELITERAL : 'true';
FALSELITERAL : 'false';
LET : 'let';
ATTR : 'attr';
ASSERT : 'assert';
ALLOCATE : 'allocate';
FREE : 'free';
REALIZE : 'realize';
EVALUATE : 'evaluate';
FOR : 'for';
PREFETCH : 'prefetch';
CAST : 'cast';
FLOORDIV : 'floordiv';
FLOORMOD : 'floormod';
MIN : 'min';
MAX : 'max';
SELECT : 'select';
LOAD : 'load';
RAMP : 'ramp';
BROADCAST : 'broadcast';
CALL : 'call';
SHUFFLE : 'shuffle';
REDUCE : 'reduce';
BUFFER : 'Buffer';
COMMREDUCER : 'comm_reducer';
ITERVAR : 'IterVar';

fragment DIGIT
    :   [0-9]
    ;

fragment NAT
    :   DIGIT+
    ;

INT
    :   '-'? DIGIT+
    ;

fragment EXP
    :   [eE] [+\-]? NAT
    ;

FLOATIMM
    :   INT ('.' NAT)? EXP?
    ;

fragment LETTER
    :   [a-zA-Z]
    ;

STRINGIMM
    :   '"' ('\\n' | '\\\\' | '\\"' | .)*? '"'
    ;

DTYPE
    : 'float' NAT ('x' NAT)?
    | 'int' NAT ('x' NAT)?
    | 'uint' NAT ('x' NAT)?
    | 'bool'
    | 'handle'
    ;

IDENTIFIER
    :   '%' [a-zA-Z] + [a-zA-Z_0-9]*
    ;

CNAME : ('_'|LETTER) ('_'|LETTER|DIGIT)* ('.' CNAME)* ;

WHITESPACE
    :   [ \t\n\r]+ -> skip
    ;

spectrometerHBH commented 4 years ago

I tried to put the var declarations ahead of the buffers, but I don't think it looks very good.

vars = {int32: [%stride_3, %stride_2, %stride_4, %stride_1, %n, %stride_5, %m, %stride, %j]}

I think we'd better use the first encounter declaration style.

To shorten the declaration of a buffer, using keyword params seems good:

  buffers = {%A1: Buffer(%A1_2, float32, [%m : int32, %n : int32], [%stride_2 : int32, %stride_3 : int32]),
           %C: Buffer(%C_2, float32, [%m, %n], [%stride_4 : int32, %stride_5 : int32]),
           %A0: Buffer(%A0_2, float32, [%m, %n], [%stride : int32, %stride_1 : int32])}

tqchen commented 4 years ago

How about something like vars = [%stride_3 : int32, %stride_0 : int32]?

The first-encounter approach also looks OK. We could also introduce a default type for indices (i32, and perhaps change to i64 in the future).

spectrometerHBH commented 4 years ago

Example for Conv on GPU

primfn(A_1: handle, W_1: handle, B_1: handle) -> ()
  attr = {"tir.noalias": bool(1), "global_symbol": "main"}
  buffers = {B: Buffer(B_2: handle, float32, [14, 14, 512, 256], []),
           W: Buffer(W_2: handle, float32, [3, 3, 256, 512], []),
           A: Buffer(A_2: handle, float32, [14, 14, 256, 256], [])}
  buffer_map = {B_1: B, A_1: A, W_1: W} {
  attr [IterVar(blockIdx.z: int32, [(nullptr)], "ThreadIndex", "blockIdx.z")] "thread_extent" = 196;
  attr [B.local: handle] "storage_scope" = "local";
  allocate(B.local, float32, [64])  {
    attr [Apad.shared: handle] "storage_scope" = "shared";
    allocate(Apad.shared, float32, [512])  {
      attr [W.shared: handle] "storage_scope" = "shared";
      allocate(W.shared, float32, [512])  {
        attr [Apad.shared.local: handle] "storage_scope" = "local";
        allocate(Apad.shared.local, float32, [8])  {
          attr [W.shared.local: handle] "storage_scope" = "local";
          allocate(W.shared.local, float32, [8])  {
            attr [IterVar(blockIdx.y: int32, [(nullptr)], "ThreadIndex", "blockIdx.y")] "thread_extent" = 8;
            attr [IterVar(blockIdx.x: int32, [(nullptr)], "ThreadIndex", "blockIdx.x")] "thread_extent" = 4;
            attr [IterVar(threadIdx.y: int32, [0:8], "ThreadIndex", "threadIdx.y")] "thread_extent" = 8;
            attr [IterVar(threadIdx.x: int32, [0:8], "ThreadIndex", "threadIdx.x")] "thread_extent" = 8 {
              for (ff.c.init: int32, 0, 4) {
                for (nn.c.init: int32, 0, 4) {
                  B.local[((ff.c.init*4) + nn.c.init)] = float32(0)
                  B.local[(((ff.c.init*4) + nn.c.init) + 32)] = float32(0)
                  B.local[(((ff.c.init*4) + nn.c.init) + 16)] = float32(0)
                  B.local[(((ff.c.init*4) + nn.c.init) + 48)] = float32(0)
                }
              }
              for (rc.outer: int32, 0, 32) {
                for (ry: int32, 0, 3) {
                  for (rx: int32, 0, 3) {
                    for (ax3.inner.outer: int32, 0, 2) {
                      Apad.shared[ramp((((threadIdx.y*64) + (threadIdx.x*8)) + (ax3.inner.outer*4)), 1, 4)] = call(float32x4, "tvm_if_then_else", [((((1 <= (floordiv(blockIdx.z, 14) + ry)) and ((floordiv(blockIdx.z, 14) + ry) < 15)) and (1 <= (rx + floormod(blockIdx.z, 14)))) and ((rx + floormod(blockIdx.z, 14)) < 15)), load(float32x4, A_2[ramp((((((((((ry*917504) + (blockIdx.z*65536)) + (rx*65536)) + (rc.outer*2048)) + (threadIdx.y*256)) + (blockIdx.x*64)) + (threadIdx.x*8)) + (ax3.inner.outer*4)) - 983040), 1, 4)]), broadcast(float32(0), 4)], "pure_intrin", 0)
                    }
                    for (ax3.inner.outer_1: int32, 0, 2) {
                      W.shared[ramp((((threadIdx.y*64) + (threadIdx.x*8)) + (ax3.inner.outer_1*4)), 1, 4)] = load(float32x4, W_2[ramp((((((((ry*393216) + (rx*131072)) + (rc.outer*4096)) + (threadIdx.y*512)) + (blockIdx.y*64)) + (threadIdx.x*8)) + (ax3.inner.outer_1*4)), 1, 4)])
                    }
                    for (rc.inner: int32, 0, 8) {
                      for (ax3: int32, 0, 4) {
                        Apad.shared.local[ax3] = load(float32, Apad.shared[(((rc.inner*64) + (threadIdx.x*4)) + ax3)])
                        Apad.shared.local[(ax3 + 4)] = load(float32, Apad.shared[((((rc.inner*64) + (threadIdx.x*4)) + ax3) + 32)])
                      }
                      for (ax3_1: int32, 0, 4) {
                        W.shared.local[ax3_1] = load(float32, W.shared[(((rc.inner*64) + (threadIdx.y*4)) + ax3_1)])
                        W.shared.local[(ax3_1 + 4)] = load(float32, W.shared[((((rc.inner*64) + (threadIdx.y*4)) + ax3_1) + 32)])
                      }
                      for (ff.c: int32, 0, 4) {
                        for (nn.c: int32, 0, 4) {
                          B.local[((ff.c*4) + nn.c)] = (load(float32, B.local[((ff.c*4) + nn.c)]) + (load(float32, Apad.shared.local[nn.c])*load(float32, W.shared.local[ff.c])))
                          B.local[(((ff.c*4) + nn.c) + 32)] = (load(float32, B.local[(((ff.c*4) + nn.c) + 32)]) + (load(float32, Apad.shared.local[nn.c])*load(float32, W.shared.local[(ff.c + 4)])))
                          B.local[(((ff.c*4) + nn.c) + 16)] = (load(float32, B.local[(((ff.c*4) + nn.c) + 16)]) + (load(float32, Apad.shared.local[(nn.c + 4)])*load(float32, W.shared.local[ff.c])))
                          B.local[(((ff.c*4) + nn.c) + 48)] = (load(float32, B.local[(((ff.c*4) + nn.c) + 48)]) + (load(float32, Apad.shared.local[(nn.c + 4)])*load(float32, W.shared.local[(ff.c + 4)])))
                        }
                      }
                    }
                  }
                }
              }
              for (ff.inner.inner.inner: int32, 0, 4) {
                for (nn.inner.inner.inner: int32, 0, 4) {
                  B_2[(((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024)) + (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner)] = load(float32, B.local[((ff.inner.inner.inner*4) + nn.inner.inner.inner)])
                  B_2[((((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024)) + (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner) + 8192)] = load(float32, B.local[(((ff.inner.inner.inner*4) + nn.inner.inner.inner) + 32)])
                  B_2[((((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024)) + (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner) + 32)] = load(float32, B.local[(((ff.inner.inner.inner*4) + nn.inner.inner.inner) + 16)])
                  B_2[((((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024)) + (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner) + 8224)] = load(float32, B.local[(((ff.inner.inner.inner*4) + nn.inner.inner.inner) + 48)])
                }
              }
            }
          }
        }
      }
    }
  }
}

spectrometerHBH commented 4 years ago

GEMM on CPU

primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"tir.noalias": bool(1), "global_symbol": "main"}
  buffers = {C: Buffer(C_2: handle, float32, [1024, 1024], []),
           B: Buffer(B_2: handle, float32, [1024, 1024], []),
           A: Buffer(A_2: handle, float32, [1024, 1024], [])}
  buffer_map = {C_1: C, A_1: A, B_1: B} {
  attr [packedB: handle] "storage_scope" = "global";
  allocate(packedB, float32x32, [32768])  {
    for (x: int32, 0, 32) "parallel" {
      for (y: int32, 0, 1024) {
        packedB[ramp(((x*32768) + (y*32)), 1, 32)] = load(float32x32, B_2[ramp(((y*1024) + (x*32)), 1, 32)])
      }
    }
    for (x.outer: int32, 0, 32) "parallel" {
      attr [C.global: handle] "storage_scope" = "global";
      allocate(C.global, float32, [1024])  {
        for (y.outer: int32, 0, 32) {
          for (x.c.init: int32, 0, 32) {
            C.global[ramp((x.c.init*32), 1, 32)] = broadcast(float32(0), 32)
          }
          for (k.outer: int32, 0, 256) {
            for (x.c: int32, 0, 32) {
              C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32), 1, 32)]) + (broadcast(load(float32, A_2[(((x.outer*32768) + (x.c*1024)) + (k.outer*4))]), 32)*load(float32x32, packedB[ramp(((y.outer*32768) + (k.outer*128)), 1, 32)])))
              C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32), 1, 32)]) + (broadcast(load(float32, A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) + 1)]), 32)*load(float32x32, packedB[ramp((((y.outer*32768) + (k.outer*128)) + 32), 1, 32)])))
              C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32), 1, 32)]) + (broadcast(load(float32, A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) + 2)]), 32)*load(float32x32, packedB[ramp((((y.outer*32768) + (k.outer*128)) + 64), 1, 32)])))
              C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32), 1, 32)]) + (broadcast(load(float32, A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) + 3)]), 32)*load(float32x32, packedB[ramp((((y.outer*32768) + (k.outer*128)) + 96), 1, 32)])))
            }
          }
          for (x.inner: int32, 0, 32) {
            for (y.inner: int32, 0, 32) {
              C_2[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] = load(float32, C.global[((x.inner*32) + y.inner)])
            }
          }
        }
      }
    }
  }
}

spectrometerHBH commented 4 years ago

TensorCore for Conv

primfn(A_1: handle, W_1: handle, Conv_1: handle) -> ()
  attr = {"tir.noalias": bool(1), "global_symbol": "main"}
  buffers = {W: Buffer(W_2: handle, float16, [3, 3, 16, 32, 16, 16], []),
           A: Buffer(A_2: handle, float16, [16, 14, 14, 16, 16, 16], []),
           Conv: Buffer(Conv_2: handle, float32, [16, 14, 14, 32, 16, 16], [])}
  buffer_map = {A_1: A, Conv_1: Conv, W_1: W} {
  attr [IterVar(blockIdx.z: int32, [(nullptr)], "ThreadIndex", "blockIdx.z")] "thread_extent" = 196;
  attr [Conv.wmma.accumulator: handle] "storage_scope" = "wmma.accumulator";
  allocate(Conv.wmma.accumulator, float32, [2048])  {
    attr [Apad.shared: handle] "storage_scope" = "shared";
    allocate(Apad.shared, float16, [12288])  {
      attr [W.shared: handle] "storage_scope" = "shared";
      allocate(W.shared, float16, [12288])  {
        attr [Apad.shared.wmma.matrix_a: handle] "storage_scope" = "wmma.matrix_a";
        allocate(Apad.shared.wmma.matrix_a, float16, [512])  {
          attr [W.shared.wmma.matrix_b: handle] "storage_scope" = "wmma.matrix_b";
          allocate(W.shared.wmma.matrix_b, float16, [1024])  {
            attr [IterVar(blockIdx.x: int32, [(nullptr)], "ThreadIndex", "blockIdx.x")] "thread_extent" = 2;
            attr [IterVar(blockIdx.y: int32, [(nullptr)], "ThreadIndex", "blockIdx.y")] "thread_extent" = 4;
            attr [IterVar(threadIdx.y: int32, [(nullptr)], "ThreadIndex", "threadIdx.y")] "thread_extent" = 4;
            attr [IterVar(threadIdx.z: int32, [(nullptr)], "ThreadIndex", "threadIdx.z")] "thread_extent" = 2 {
              for (n.c.init: int32, 0, 2) {
                for (o.c.init: int32, 0, 4) {
                  eval(call("tvm_fill_fragment", [Conv.wmma.accumulator, 16, 16, 16, ((n.c.init*4) + o.c.init), float32(0)], handle, "intrin", 0))
                }
              }
              for (ic.outer: int32, 0, 8) {
                for (kh: int32, 0, 3) {
                  for (ax2: int32, 0, 3) {
                    for (ax3: int32, 0, 2) {
                      for (ax4.ax5.fused.outer: int32, 0, 8) {
                        attr [IterVar(threadIdx.x: int32, [(nullptr)], "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
                        Apad.shared[((((((threadIdx.y*3072) + (threadIdx.z*1536)) + (ax2*512)) + (ax3*256)) + (ax4.ax5.fused.outer*32)) + threadIdx.x)] = call("tvm_if_then_else", [((((1 <= (floordiv(blockIdx.z, 14) + kh)) and ((floordiv(blockIdx.z, 14) + kh) < 15)) and (1 <= (ax2 + floormod(blockIdx.z, 14)))) and ((ax2 + floormod(blockIdx.z, 14)) < 15)), load(float16, A_2[(((((((((((blockIdx.x*6422528) + (threadIdx.y*1605632)) + (threadIdx.z*802816)) + (kh*57344)) + (blockIdx.z*4096)) + (ax2*4096)) + (ic.outer*512)) + (ax3*256)) + (ax4.ax5.fused.outer*32)) + threadIdx.x) - 61440)]), float16(0)], float16, "pure_intrin", 0)
                      }
                    }
                  }
                  for (ax1: int32, 0, 3) {
                    for (ax2_1: int32, 0, 2) {
                      attr [IterVar(threadIdx.x, [(nullptr)], "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
                      W.shared[ramp((((((ax1*4096) + (ax2_1*2048)) + (threadIdx.y*512)) + (threadIdx.z*256)) + (threadIdx.x*8)), 1, 8)] = load(float16x8, W_2[ramp(((((((((kh*393216) + (ax1*131072)) + (ic.outer*16384)) + (ax2_1*8192)) + (blockIdx.y*2048)) + (threadIdx.y*512)) + (threadIdx.z*256)) + (threadIdx.x*8)), 1, 8)])
                    }
                  }
                  for (ic.inner: int32, 0, 2) {
                    for (kw: int32, 0, 3) {
                      for (ax0: int32, 0, 2) {
                        eval(call("tvm_load_matrix_sync", [Apad.shared.wmma.matrix_a, 16, 16, 16, ax0, call("tvm_access_ptr", [call("type_annotation", [], float16, "pure_intrin", 0), Apad.shared, ((((threadIdx.y*3072) + (ax0*1536)) + (kw*512)) + (ic.inner*256)), 256, 1], handle, "intrin", 0), 16, "row_major"], handle, "intrin", 0))
                      }
                      for (ax3_1: int32, 0, 4) {
                        eval(call("tvm_load_matrix_sync", [W.shared.wmma.matrix_b, 16, 16, 16, ax3_1, call("tvm_access_ptr", [call("type_annotation", [], float16, "pure_intrin", 0), W.shared, ((((kw*4096) + (ic.inner*2048)) + (threadIdx.z*1024)) + (ax3_1*256)), 256, 1], handle, "intrin", 0), 16, "row_major"], handle, "intrin", 0))
                      }
                      for (n.c: int32, 0, 2) {
                        for (o.c: int32, 0, 4) {
                          eval(call("tvm_mma_sync", [Conv.wmma.accumulator, ((n.c*4) + o.c), Apad.shared.wmma.matrix_a, n.c, W.shared.wmma.matrix_b, o.c, Conv.wmma.accumulator, ((n.c*4) + o.c)], handle, "intrin", 0))
                        }
                      }
                    }
                  }
                }
              }
              for (n.inner: int32, 0, 2) {
                for (o.inner: int32, 0, 4) {
                  eval(call("tvm_store_matrix_sync", [Conv.wmma.accumulator, 16, 16, 16, ((n.inner*4) + o.inner), call("tvm_access_ptr", [call("type_annotation", [], float32, "pure_intrin", 0), Conv_2, (((((((blockIdx.x*12845056) + (threadIdx.y*3211264)) + (n.inner*1605632)) + (blockIdx.z*8192)) + (blockIdx.y*2048)) + (threadIdx.z*1024)) + (o.inner*256)), 256, 2], handle, "intrin", 0), 16, "row_major"], handle, "intrin", 0))
                }
              }
            }
          }
        }
      }
    }
  }
}