So ideally we want text format & hybrid script to be two different printing methods for TIR which are both parseable?
yes. but as a path for progression, we can
What is hybrid bidirection? Bidirectional translation between text format & hybrid?
Sorry, i meant hybrid parsing and printing
Generally, it is OK to use the current ReprPrinter as long as we make some minor modifications. However, there are some details to be determined.
fn(%x : Tensor[(10, 10), float32], %y : Tensor[(10, 10), float32])
-> Tensor[(10, 10), float32] {
add(%x, %y)
}
Of course, TIR can do the same thing to show the parameter types. The problem is that a TIR function has no return type, nor even a return value: it directly changes buffers by mutating their elements. Here are two options for this:
I think neither is a perfect solution, but I have no other ideas. I would love to hear your opinions.
I'd like to add a kind of annotation to functions to show whether each is a Relay or TIR function. One possible syntax would be
@relay
fn(%x, %y) { add(%x, %y) }
@tir
fn(%x, %y) { X[0] = Y[0] }
There is no doubt that Relay functions and TIR functions have different statements and expressions (e.g. let-binding in Relay and BufferLoad in TIR). So I think we can ignore the differences in the function body but try to make the function syntax itself unified.
@spectrometerHBH @tqchen
`void` (which is represented by an empty tuple type, as in Swift) and `int` are both indicated by the `ret_type`. Like `inline void func` vs `extern "C" int func`, perhaps we could add a keyword (e.g. `primfn`) before `fn` to indicate a PrimFunc.
FYI, I am in the process of refactoring out the tensor field (`func`) in Realize and Provide and replacing them with BufferRealize, BufferStore and BufferLoad, so we don't need to consider supporting these nodes in the text format.
Related change in the upstream https://github.com/apache/incubator-tvm/pull/5372
A comparison between the repr format and the text format (which is problematic now):
[23:41:44] /home/spectrometer/tvm-upstream/src/te/schedule/bound.cc:128: not in feed graph consumer = compute(B, 0x1fb5cf0)
PrimFunc([A0, A1, C]) attrs={"tir.noalias": (bool)1, "global_symbol": "main"} {
// attr [B.v0] storage_scope = "global"
allocate B.v0[float32 * n]
// attr [B.v1] storage_scope = "global"
allocate B.v1[float32 * n]
for (i, 0, m) {
for (j, 0, n) {
B.v0[j] = (A0[((i*stride) + (j*stride))] + 2f)
B.v1[j] = (A0[((i*stride) + (j*stride))]*3f)
}
for (j, 0, n) {
C[((i*stride) + (j*stride))] = (A1[((i*stride) + (j*stride))] + B.v0[j])
}
}
}
[23:41:44] /home/spectrometer/tvm-upstream/src/te/schedule/bound.cc:128: not in feed graph consumer = compute(B, 0x1fb5cf0)
class Module:
PrimFunc(%A0: Buffer([%m, %n], "float32", name=%A01),
%A1: Buffer([%m, %n], "float32", name=%A11),
%C: Buffer([%m, %n], "float32", name=%C1)) {
attr [%B.v0] "storage_scope" = "global" {
allocate(%B.v0, "float32", [%n]) if "bool"(1) {
attr [%B.v1] "storage_scope" = "global" {
allocate(%B.v1, "float32", [%n]) if "bool"(1) {
for (%i, 0, %m) "serial" {
for (%j, 0, %n) "serial" {
%B.v0[%j] = (%A0[((%i*%stride) + (%j*%stride1))] + "float32"(2)) if "bool"(1)
%B.v1[%j] = (%A0[((%i*%stride) + (%j*%stride1))]*"float32"(3)) if "bool"(1)
}
for (%j1, 0, %n) "serial" {
%C[((%i*%stride2) + (%j1*%stride3))] = (%A1[((%i*%stride4) + (%j1*%stride5))] + %B.v0[%j1]) if "bool"(1)
}
}
}
}
}
}
}
For readability, the name_hint of a Var is useful. But `B.v0` may not be a legal identifier for the parser since it contains `.`, and the user may give an arbitrary name.
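As an illustration (a toy sketch, not TVM's actual printer), one way a printer could map arbitrary name_hints such as `B.v0` to legal, unique identifiers:

```python
import re

def legalize(name_hint, taken):
    """Map an arbitrary name_hint to a legal identifier, e.g. 'B.v0' -> 'B_v0'.

    Replaces characters outside [a-zA-Z0-9_] and appends a numeric suffix on
    collision, since two distinct Vars may share a name_hint.
    """
    base = re.sub(r"[^a-zA-Z0-9_]", "_", name_hint) or "v"
    if base[0].isdigit():
        base = "_" + base
    name, i = base, 0
    while name in taken:
        i += 1
        name = f"{base}_{i}"
    taken.add(name)
    return name

taken = set()
print(legalize("B.v0", taken))  # B_v0
print(legalize("B_v0", taken))  # B_v0_1 (collides with the sanitized name above)
```

Note that after sanitizing, a distinct Var whose hint is already `B_v0` collides with the sanitized `B.v0`, which is why a uniquifying pass is needed on top of the character substitution.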
I'm not very clear on where Var is being used in the upstream AST, as far as I can see now (e.g. `stride` above, but I'm not clear what they are).
When printing a Var, the best way is to just give its name_hint, but in principle we have to determine its dtype as well. In most cases dtype doesn't matter; I'm not sure about this point.
We could convert `.` to `_` or other symbols.
Some additional thoughts after seeing the example:
`"float32"(1)` can be `1.0f`, or `float32(0)`.
The `%` prefix seems quite dense and confusing when used together with operators (e.g. `%`); perhaps we want to allow variable names without the prefix. A strawman, though I am not too happy with the syntax:
attr ("storage_scope", %Bv0) = "global"
For vars like `%m` and `%n`, should we print their Type/dtype? Their declaration is outside the AST.
%A0: Buffer([%m, %n], "float32", name=%A0_1),
%A1: Buffer([%m, %n], "float32", name=%A1_1),
%C: Buffer([%m, %n], "float32", name=%C_1)
One way is to print the Type the first time we encounter the Var:
%A0: Buffer([%m : int32, %n : int32], "float32", name=%A0_1),
%A1: Buffer([%m, %n], "float32", name=%A1_1),
%C: Buffer([%m, %n], "float32", name=%C_1)
Another way is to declare first their defs somewhere else.
The third way is to put it in meta.
I think all the ways are somewhat strange.
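A minimal sketch of the first-encounter style (hypothetical, not the actual implementation): the printer tracks which Vars have already appeared and annotates the type only on the first occurrence.

```python
class FirstEncounterPrinter:
    """Print a var as '%name: dtype' the first time it is seen, '%name' afterwards."""

    def __init__(self):
        self.seen = set()

    def print_var(self, name, dtype):
        if name in self.seen:
            return f"%{name}"
        self.seen.add(name)
        return f"%{name}: {dtype}"

p = FirstEncounterPrinter()
print(p.print_var("m", "int32"))  # %m: int32
print(p.print_var("m", "int32"))  # %m
```

The corresponding parser would do the reverse: a `name: type` occurrence declares the Var, and later bare occurrences look it up in scope.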
And for vars like `stride` in `%A0[((%i*%stride) + (%j*%stride1))]`, it looks more like a placeholder for an unknown buffer shape. Does their Type/dtype matter?
I think first-encounter, or declaring the defs somewhere else, makes sense. Formally, we should not put buffer_map in the argument list (that makes Buffer look like a type, but Buffers are not types); we might want some section like C++'s initializer list (the corresponding hybrid-script part is the buffer bind).
It seems that Relay's text printer doesn't support printing PrimType and PointerType. Will this be fixed in the upstream?
I think it will be up to us to fix them :)
An inevitable problem in the text format is how to print buffer_map (`Map<tir::Var, Buffer>`). I will expand my discussion based on this problem. cc @tqchen @Hzfengsy
PrimFunc(%A0: handle, %A1: handle, %C: handle)
%C: Buffer(%C_1, [%m, %n], float32, [%stride, %stride_1], %C_2)
%A0: Buffer(%A0_1, [%m, %n], float32, [%stride_2, %stride_3], %A0_2)
%A1: Buffer(%A1_1, [%m, %n], float32, [%stride_4, %stride_5], %A1_2) {
attr [%B_v0] "storage_scope" = "global" {
allocate(%B_v0, float32, [%n]) {
attr [%B_v1] "storage_scope" = "global" {
allocate(%B_v1, float32, [%n]) {
for (%i, 0, %m) {
for (%j, 0, %n) {
%B_v0[%j] = (%A0_1[((%i*%stride_2) + (%j*%stride_3))] + float32(2))
%B_v1[%j] = (%A0_1[((%i*%stride_2) + (%j*%stride_3))]*float32(3))
}
for (%j_1, 0, %n) {
%C_1[((%i*%stride) + (%j_1*%stride_1))] = (%A1_1[((%i*%stride_4) + (%j_1*%stride_5))] + %B_v0[%j_1])
}
}
}
}
}
}
}
If we avoid using META in printing, then the text format is fully writable and modifiable to the user, as long as the user follows the syntax rules.
It's really long to print all Buffer info, which hurts readability.
Clearer for buffer_map printing. But `%m`, `%n`, `%stride` are put into META along with the Buffer, so for their usages below we have to print `meta[Var][0]`, which makes the usage hard to read.
I further propose several ways to deal with the cons of C1.
When printing, we will give each Var we encounter a unique name_hint. To be able to print `stride` instead of `meta[Var][0]`, we can store a `Map<std::string, Var>` into META for all the vars appearing in Buffer declarations.
When parsing, we look up the name map when we encounter `stride` to retrieve the right Var from META.
Cons: hurts the writability of the text format, since buffer declarations lie in META.
Consider such a scenario in the general case: a Var is declared first, but its usage is going to be put into META along with some other non-printable node containing it. To print correctly, we have to collect all vars that are going to be put into META; C1.1 is a special case handling Vars in Buffers. If in the future we want to put more nodes into META, we have to keep this problem in mind.
But we can do this more aggressively to alleviate the problem entirely: just scan all the Vars in the AST and store a name map for them in META. Then for every Var occurrence we can simply print a name_hint, no matter how the Var is used.
Pros: simple and works generally. Note that it also works if we don't put the Buffer into META. Cons: hurts writability the most, since all the Var declarations lie in META.
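A toy sketch of the aggressive variant (the names and structure here are assumptions, not upstream code): assign every Var a unique printable name up front, and keep the name-to-Var map that would play the role of the `Map<std::string, Var>` stored in META.

```python
class Var:
    """Toy stand-in for tir::Var: identity matters, name_hints may repeat."""
    def __init__(self, name_hint):
        self.name_hint = name_hint

def assign_unique_names(all_vars):
    # The resulting dict mimics the META name map: every occurrence prints the
    # unique name, and the parser retrieves the Var object by that name.
    name_map, counts = {}, {}
    for v in all_vars:
        n = counts.get(v.name_hint, 0)
        counts[v.name_hint] = n + 1
        name = v.name_hint if n == 0 else f"{v.name_hint}_{n}"
        name_map[name] = v
    return name_map

s1, s2 = Var("stride"), Var("stride")
names = assign_unique_names([s1, s2])
print(sorted(names))  # ['stride', 'stride_1']
```

A real implementation would also have to guard against a generated name such as `stride_1` colliding with an actual name_hint `stride_1`; this sketch omits that.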
I will land a version that avoids using META then.
Does GetType always return PrimType or PointerType? Is there an API to convert a Type to a dtype?
Type to dtype: https://github.com/apache/incubator-tvm/blob/master/include/tvm/tir/op.h#L70 (note that it is always safe to construct a Var with a Type, because Type is more fine-grained). It is fine to print a Var without a Type (only with a dtype) but load it back containing a Type.
We cannot always deduce Type from dtype, though.
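A toy model (not the actual `op.h` helper) of why the conversion is one-way: projecting a Type down to a dtype collapses every PointerType to `handle`, so the pointee type cannot be recovered from the dtype alone.

```python
def type_to_dtype(ty):
    """Project a toy Type down to a dtype string.

    PrimType keeps its dtype; any PointerType collapses to 'handle',
    discarding the pointee type -- hence dtype -> Type is not invertible.
    """
    kind, payload = ty
    if kind == "prim":
        return payload            # e.g. ("prim", "float32") -> "float32"
    if kind == "pointer":
        return "handle"           # pointee type 'payload' is discarded
    raise ValueError(f"unknown type kind: {kind}")

print(type_to_dtype(("prim", "int32")))                  # int32
print(type_to_dtype(("pointer", ("prim", "float32"))))   # handle
print(type_to_dtype(("pointer", ("prim", "int8"))))      # handle, pointee lost
```

Since two different pointer types both print as `handle`, a parser that only sees the dtype cannot reconstruct the original PointerType, which is what "cannot always deduce Type from dtype" amounts to.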
BTW, if an identifier doesn't begin with `%`, the lexer will confuse it with the CNAME used in meta. So I will keep `%` for now.
I see; then we will need to at least print mod as `truncmod` instead of `%`.
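A small sketch (a hypothetical lexer, not the real one) of why `%` cannot serve as both the variable prefix and the mod operator, motivating printing mod as `truncmod`:

```python
import re

# '%' only introduces an identifier here; it is not in the operator set,
# so a bare '%' between expressions fails to lex.
TOKEN = re.compile(r"%?[A-Za-z_][A-Za-z_0-9]*|\d+|[+*/(),-]|\s+")

def lex(src):
    tokens, pos = [], 0
    while pos < len(src):
        m = TOKEN.match(src, pos)
        if m is None:
            raise SyntaxError(f"cannot lex {src[pos:]!r}")
        if not m.group().isspace():
            tokens.append(m.group())
        pos = m.end()
    return tokens

print(lex("truncmod(%a, 4)"))  # ['truncmod', '(', '%a', ',', '4', ')']
```

With this token set, `lex("%a % 4")` raises a SyntaxError at the bare `%`, while the `truncmod(...)` spelling lexes unambiguously.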
I've implemented a printer & parser for my current syntax. I'll attach the g4 file below for discussion in our next meeting. Note that it doesn't yet support META in Attr, and Attr (and DictAttrs) is the only place that will result in nodes being put into META.
Several points:
`buffers` stores the complete info of all the buffers that appear in the IR, but we can have more ideas on how to print it.
PrimFunc(%A0_1: handle, %A1_1: handle, %C_1: handle) -> ()
buffers={%A1: Buffer(%A1_2: handle, float32, [%m: int32, %n: int32], [%stride: int32, %stride_1: int32], 0, "global", 128, 1, "auto"),
%C: Buffer(%C_2: handle, float32, [%m, %n], [%stride_2: int32, %stride_3: int32], 0, "global", 128, 1, "auto"),
%A0: Buffer(%A0_2: handle, float32, [%m, %n], [%stride_4: int32, %stride_5: int32], 0, "global", 128, 1, "auto")}
buffer_map={%C_1: %C, %A0_1: %A0, %A1_1: %A1} {
attr [%B_v0: handle] "storage_scope" = "global" {
allocate(%B_v0, float32, [%n]) {
attr [%B_v1: handle] "storage_scope" = "global" {
allocate(%B_v1, float32, [%n]) {
for (%i: int32, 0, %m) {
for (%j: int32, 0, %n) {
%B_v0[%j] = (%A0_2[((%i*%stride_4) + (%j*%stride_5))] + float32(2))
%B_v1[%j] = (%A0_2[((%i*%stride_4) + (%j*%stride_5))]*float32(3))
}
for (%j_1: int32, 0, %n) {
%C_2[((%i*%stride_2) + (%j_1*%stride_3))] = (%A1_2[((%i*%stride) + (%j_1*%stride_1))] + %B_v0[%j_1])
}
}
}
}
}
}
}
grammar tir;
module
: function* METADATA?
;
function
: PRIMFUNC '(' varNode (',' varNode)* ')' '->' typeExpr
'buffers' '=' '{' buffer_list? '}'
'buffer_map' '=' '{' buffer_map_list? '}'
body
;
buffer_list
: IDENTIFIER ':' bufferNode (',' IDENTIFIER ':' bufferNode)*
;
buffer_map_list
: IDENTIFIER ':' IDENTIFIER (',' IDENTIFIER ':' IDENTIFIER)*
;
body
: '{' stmtNode* '}'
;
node
: primExprNode
| stmtNode
| rangeNode
| commReducerNode
| arrayNode
| itervarNode
| bufferNode
| meta
;
arrayNode
: '[' node (',' node)* ']'
;
stmtNode
: LET varNode '=' primExprNode body # letStmt
| ATTR '[' (varNode|itervarNode) ']' STRINGIMM '=' primExprNode body # attrStmt
| ASSERT '(' primExprNode ',' primExprNode ')' body # assertStmt
| varNode arrayNode '=' primExprNode (IF primExprNode)? # storeStmt
| ALLOCATE '(' varNode ',' DTYPE ',' arrayNode ')' (IF primExprNode)? body # allocateStmt
| FREE '(' varNode ')' # freeStmt
| REALIZE '(' IDENTIFIER ',' arrayNode ')' (IF primExprNode)? body # bufferRealizeStmt
| IF primExprNode body (ELSE body)? # ifThenElseStmt
| EVALUATE '(' primExprNode ')' # evaluateStmt
| FOR '(' varNode ',' primExprNode ',' primExprNode ')' STRINGIMM? body # forStmt
| PREFETCH '(' IDENTIFIER ',' arrayNode ')' # prefetchStmt
;
primExprNode
: CAST '(' DTYPE ',' primExprNode ')' # castExpr
| FLOORDIV '(' primExprNode ',' primExprNode ')' # floordivExpr
| FLOORMOD '(' primExprNode ',' primExprNode ')' # floormodExpr
| MIN '(' primExprNode ',' primExprNode ')' # minExpr
| MAX '(' primExprNode ',' primExprNode ')' # maxExpr
| SELECT '(' primExprNode ',' primExprNode ',' primExprNode ')' # selectExpr
| varNode arrayNode (IF primExprNode)? # loadExpr
| RAMP '(' primExprNode ',' primExprNode ',' INT ')' # rampExpr
| BROADCAST '(' primExprNode ',' INT ')' # broadcastExpr
| LET varNode '=' primExprNode IN primExprNode # letExpr
| CALL '(' DTYPE ',' STRINGIMM ',' arrayNode ',' STRINGIMM ',' INT ')' # callExpr
| SHUFFLE '(' arrayNode ',' arrayNode ')' # shuffleExpr
| REDUCE '(' commReducerNode ',' arrayNode ',' arrayNode ',' INT ')' (IF primExprNode)? # reduceExpr
| <assoc=right> op='!' primExprNode # notExpr
| src1 = primExprNode op=('*' | '/' | '%') src2 = primExprNode # binExpr
| src1 = primExprNode op=('+' | '-') src2 = primExprNode # binExpr
| src1 = primExprNode op=('<' | '>') src2 = primExprNode # binExpr
| src1 = primExprNode op=('<=' | '>=') src2 = primExprNode # binExpr
| src1 = primExprNode op=('==' | '!=') src2 = primExprNode # binExpr
| src1 = primExprNode op='&&' src2 = primExprNode # binExpr
| src1 = primExprNode op='||' src2 = primExprNode # binExpr
| varNode # varExpr
| meta # metaExpr
| immediate # immExpr
| '(' primExprNode ')' # parenExpr
;
varNode
: IDENTIFIER (':' typeExpr)?
;
commReducerNode
: COMMREDUCER '(' arrayNode ',' arrayNode ',' arrayNode ',' arrayNode ')'
;
bufferNode
: BUFFER '(' varNode ',' DTYPE ',' arrayNode ',' arrayNode ',' primExprNode ','
STRINGIMM ',' INT ',' INT ',' STRINGIMM ')'
;
rangeNode
: primExprNode ':' primExprNode
;
itervarNode
: ITERVAR '(' varNode ',' '[' rangeNode ']' ',' STRINGIMM ',' STRINGIMM ')'
;
immediate
: STRINGIMM # stringImm
| DTYPE '(' INT ')' # intImm
| DTYPE '(' FLOATIMM ')' # floatImm
| INT # int32Imm
| TRUELITERAL # trueLiteral
| FALSELITERAL # falseLiteral
;
meta
: 'meta' '[' CNAME ']' '[' INT ']'
;
// --- Type
typeExpr
: '(' ')' # tupleType
| DTYPE # primType
| 'Pointer' '(' typeExpr ')' # pointerType
;
METADATA: 'METADATA:' .*;
// --- Reserved words
MUL : '*' ;
DIV : '/' ;
ADD : '+' ;
SUB : '-' ;
MOD : '%' ;
LT : '<' ;
GT : '>' ;
LE : '<=' ;
GE : '>=' ;
EQ : '==' ;
NE : '!=' ;
AND : '&&' ;
OR : '||' ;
IN : 'in';
PRIMFUNC : 'PrimFunc';
IF : 'if';
ELSE : 'else';
TRUELITERAL : 'true';
FALSELITERAL : 'false';
LET : 'let';
ATTR : 'attr';
ASSERT : 'assert';
ALLOCATE : 'allocate';
FREE : 'free';
REALIZE : 'realize';
EVALUATE : 'evaluate';
FOR : 'for';
PREFETCH : 'prefetch';
CAST : 'cast';
FLOORDIV : 'floordiv';
FLOORMOD : 'floormod';
MIN : 'min';
MAX : 'max';
SELECT : 'select';
LOAD : 'load';
RAMP : 'ramp';
BROADCAST : 'broadcast';
CALL : 'call';
SHUFFLE : 'shuffle';
REDUCE : 'reduce';
BUFFER : 'Buffer';
COMMREDUCER : 'comm_reducer';
ITERVAR : 'IterVar';
fragment DIGIT
: [0-9]
;
fragment NAT
: DIGIT+
;
INT
: '-'? DIGIT+
;
fragment EXP
: [eE] [+\-]? NAT
;
FLOATIMM
: INT ('.' NAT)? EXP?
;
fragment LETTER
: [a-zA-Z]
;
STRINGIMM
: '"' ('\\n' | '\\\\' | '\\"' | .)*? '"'
;
DTYPE
: 'float' NAT ('x' NAT)?
| 'int' NAT ('x' NAT)?
| 'uint' NAT ('x' NAT)?
| 'bool'
| 'handle'
;
IDENTIFIER
: '%' [a-zA-Z]+ [a-zA-Z_0-9]*
;
CNAME : ('_'|LETTER) ('_'|LETTER|DIGIT)* ('.' CNAME)* ;
WHITESPACE
: [ \t\n\r]+ -> skip
;
I tried to put the var declarations ahead of `buffers`, but I don't think it looks very good:
vars = {int32: [%stride_3, %stride_2, %stride_4, %stride_1, %n, %stride_5, %m, %stride, %j]}
I think we'd better use the first-encounter declaration style.
To shorten the declaration of a buffer, using keyword parameters seems good:
buffers = {%A1: Buffer(%A1_2, float32, [%m : int32, %n : int32], [%stride_2 : int32, %stride_3 : int32]),
%C: Buffer(%C_2, float32, [%m, %n], [%stride_4 : int32, %stride_5 : int32]),
%A0: Buffer(%A0_2, float32, [%m, %n], [%stride : int32, %stride_1 : int32])}
How about something like `vars = [%stride_3: int32, %stride_0: int32]`?
First-encounter also looks OK. We could also introduce default types for indices (i32, and perhaps change to i64 in the future).
Example for Conv on GPU
primfn(A_1: handle, W_1: handle, B_1: handle) -> ()
attr = {"tir.noalias": bool(1), "global_symbol": "main"}
buffers = {B: Buffer(B_2: handle, float32, [14, 14, 512, 256], []),
W: Buffer(W_2: handle, float32, [3, 3, 256, 512], []),
A: Buffer(A_2: handle, float32, [14, 14, 256, 256], [])}
buffer_map = {B_1: B, A_1: A, W_1: W} {
attr [IterVar(blockIdx.z: int32, [(nullptr)], "ThreadIndex", "blockIdx.z")] "thread_extent" = 196;
attr [B.local: handle] "storage_scope" = "local";
allocate(B.local, float32, [64]) {
attr [Apad.shared: handle] "storage_scope" = "shared";
allocate(Apad.shared, float32, [512]) {
attr [W.shared: handle] "storage_scope" = "shared";
allocate(W.shared, float32, [512]) {
attr [Apad.shared.local: handle] "storage_scope" = "local";
allocate(Apad.shared.local, float32, [8]) {
attr [W.shared.local: handle] "storage_scope" = "local";
allocate(W.shared.local, float32, [8]) {
attr [IterVar(blockIdx.y: int32, [(nullptr)], "ThreadIndex", "blockIdx.y")] "thread_extent" = 8;
attr [IterVar(blockIdx.x: int32, [(nullptr)], "ThreadIndex", "blockIdx.x")] "thread_extent" = 4;
attr [IterVar(threadIdx.y: int32, [0:8], "ThreadIndex", "threadIdx.y")] "thread_extent" = 8;
attr [IterVar(threadIdx.x: int32, [0:8], "ThreadIndex", "threadIdx.x")] "thread_extent" = 8 {
for (ff.c.init: int32, 0, 4) {
for (nn.c.init: int32, 0, 4) {
B.local[((ff.c.init*4) + nn.c.init)] = float32(0)
B.local[(((ff.c.init*4) + nn.c.init) + 32)] = float32(0)
B.local[(((ff.c.init*4) + nn.c.init) + 16)] = float32(0)
B.local[(((ff.c.init*4) + nn.c.init) + 48)] = float32(0)
}
}
for (rc.outer: int32, 0, 32) {
for (ry: int32, 0, 3) {
for (rx: int32, 0, 3) {
for (ax3.inner.outer: int32, 0, 2) {
Apad.shared[ramp((((threadIdx.y*64) + (threadIdx.x*8)) + (ax3.inner.outer*4)), 1, 4)] = call(float32x4, "tvm_if_then_else", [((((1 <= (floordiv(blockIdx.z, 14) + ry)) and ((floordiv(blockIdx.z, 14) + ry) < 15)) and (1 <= (rx + floormod(blockIdx.z, 14)))) and ((rx + floormod(blockIdx.z, 14)) < 15)), load(float32x4, A_2[ramp((((((((((ry*917504) + (blockIdx.z*65536)) + (rx*65536)) + (rc.outer*2048)) + (threadIdx.y*256)) + (blockIdx.x*64)) + (threadIdx.x*8)) + (ax3.inner.outer*4)) - 983040), 1, 4)]), broadcast(float32(0), 4)], "pure_intrin", 0)
}
for (ax3.inner.outer_1: int32, 0, 2) {
W.shared[ramp((((threadIdx.y*64) + (threadIdx.x*8)) + (ax3.inner.outer_1*4)), 1, 4)] = load(float32x4, W_2[ramp((((((((ry*393216) + (rx*131072)) + (rc.outer*4096)) + (threadIdx.y*512)) + (blockIdx.y*64)) + (threadIdx.x*8)) + (ax3.inner.outer_1*4)), 1, 4)])
}
for (rc.inner: int32, 0, 8) {
for (ax3: int32, 0, 4) {
Apad.shared.local[ax3] = load(float32, Apad.shared[(((rc.inner*64) + (threadIdx.x*4)) + ax3)])
Apad.shared.local[(ax3 + 4)] = load(float32, Apad.shared[((((rc.inner*64) + (threadIdx.x*4)) + ax3) + 32)])
}
for (ax3_1: int32, 0, 4) {
W.shared.local[ax3_1] = load(float32, W.shared[(((rc.inner*64) + (threadIdx.y*4)) + ax3_1)])
W.shared.local[(ax3_1 + 4)] = load(float32, W.shared[((((rc.inner*64) + (threadIdx.y*4)) + ax3_1) + 32)])
}
for (ff.c: int32, 0, 4) {
for (nn.c: int32, 0, 4) {
B.local[((ff.c*4) + nn.c)] = (load(float32, B.local[((ff.c*4) + nn.c)]) + (load(float32, Apad.shared.local[nn.c])*load(float32, W.shared.local[ff.c])))
B.local[(((ff.c*4) + nn.c) + 32)] = (load(float32, B.local[(((ff.c*4) + nn.c) + 32)]) + (load(float32, Apad.shared.local[nn.c])*load(float32, W.shared.local[(ff.c + 4)])))
B.local[(((ff.c*4) + nn.c) + 16)] = (load(float32, B.local[(((ff.c*4) + nn.c) + 16)]) + (load(float32, Apad.shared.local[(nn.c + 4)])*load(float32, W.shared.local[ff.c])))
B.local[(((ff.c*4) + nn.c) + 48)] = (load(float32, B.local[(((ff.c*4) + nn.c) + 48)]) + (load(float32, Apad.shared.local[(nn.c + 4)])*load(float32, W.shared.local[(ff.c + 4)])))
}
}
}
}
}
}
for (ff.inner.inner.inner: int32, 0, 4) {
for (nn.inner.inner.inner: int32, 0, 4) {
B_2[(((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024)) + (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner)] = load(float32, B.local[((ff.inner.inner.inner*4) + nn.inner.inner.inner)])
B_2[((((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024)) + (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner) + 8192)] = load(float32, B.local[(((ff.inner.inner.inner*4) + nn.inner.inner.inner) + 32)])
B_2[((((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024)) + (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner) + 32)] = load(float32, B.local[(((ff.inner.inner.inner*4) + nn.inner.inner.inner) + 16)])
B_2[((((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024)) + (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner) + 8224)] = load(float32, B.local[(((ff.inner.inner.inner*4) + nn.inner.inner.inner) + 48)])
}
}
}
}
}
}
}
}
}
GEMM on CPU
primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
attr = {"tir.noalias": bool(1), "global_symbol": "main"}
buffers = {C: Buffer(C_2: handle, float32, [1024, 1024], []),
B: Buffer(B_2: handle, float32, [1024, 1024], []),
A: Buffer(A_2: handle, float32, [1024, 1024], [])}
buffer_map = {C_1: C, A_1: A, B_1: B} {
attr [packedB: handle] "storage_scope" = "global";
allocate(packedB, float32x32, [32768]) {
for (x: int32, 0, 32) "parallel" {
for (y: int32, 0, 1024) {
packedB[ramp(((x*32768) + (y*32)), 1, 32)] = load(float32x32, B_2[ramp(((y*1024) + (x*32)), 1, 32)])
}
}
for (x.outer: int32, 0, 32) "parallel" {
attr [C.global: handle] "storage_scope" = "global";
allocate(C.global, float32, [1024]) {
for (y.outer: int32, 0, 32) {
for (x.c.init: int32, 0, 32) {
C.global[ramp((x.c.init*32), 1, 32)] = broadcast(float32(0), 32)
}
for (k.outer: int32, 0, 256) {
for (x.c: int32, 0, 32) {
C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32), 1, 32)]) + (broadcast(load(float32, A_2[(((x.outer*32768) + (x.c*1024)) + (k.outer*4))]), 32)*load(float32x32, packedB[ramp(((y.outer*32768) + (k.outer*128)), 1, 32)])))
C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32), 1, 32)]) + (broadcast(load(float32, A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) + 1)]), 32)*load(float32x32, packedB[ramp((((y.outer*32768) + (k.outer*128)) + 32), 1, 32)])))
C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32), 1, 32)]) + (broadcast(load(float32, A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) + 2)]), 32)*load(float32x32, packedB[ramp((((y.outer*32768) + (k.outer*128)) + 64), 1, 32)])))
C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32), 1, 32)]) + (broadcast(load(float32, A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) + 3)]), 32)*load(float32x32, packedB[ramp((((y.outer*32768) + (k.outer*128)) + 96), 1, 32)])))
}
}
for (x.inner: int32, 0, 32) {
for (y.inner: int32, 0, 32) {
C_2[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] = load(float32, C.global[((x.inner*32) + y.inner)])
}
}
}
}
}
}
}
TensorCore for Conv
primfn(A_1: handle, W_1: handle, Conv_1: handle) -> ()
attr = {"tir.noalias": bool(1), "global_symbol": "main"}
buffers = {W: Buffer(W_2: handle, float16, [3, 3, 16, 32, 16, 16], []),
A: Buffer(A_2: handle, float16, [16, 14, 14, 16, 16, 16], []),
Conv: Buffer(Conv_2: handle, float32, [16, 14, 14, 32, 16, 16], [])}
buffer_map = {A_1: A, Conv_1: Conv, W_1: W} {
attr [IterVar(blockIdx.z: int32, [(nullptr)], "ThreadIndex", "blockIdx.z")] "thread_extent" = 196;
attr [Conv.wmma.accumulator: handle] "storage_scope" = "wmma.accumulator";
allocate(Conv.wmma.accumulator, float32, [2048]) {
attr [Apad.shared: handle] "storage_scope" = "shared";
allocate(Apad.shared, float16, [12288]) {
attr [W.shared: handle] "storage_scope" = "shared";
allocate(W.shared, float16, [12288]) {
attr [Apad.shared.wmma.matrix_a: handle] "storage_scope" = "wmma.matrix_a";
allocate(Apad.shared.wmma.matrix_a, float16, [512]) {
attr [W.shared.wmma.matrix_b: handle] "storage_scope" = "wmma.matrix_b";
allocate(W.shared.wmma.matrix_b, float16, [1024]) {
attr [IterVar(blockIdx.x: int32, [(nullptr)], "ThreadIndex", "blockIdx.x")] "thread_extent" = 2;
attr [IterVar(blockIdx.y: int32, [(nullptr)], "ThreadIndex", "blockIdx.y")] "thread_extent" = 4;
attr [IterVar(threadIdx.y: int32, [(nullptr)], "ThreadIndex", "threadIdx.y")] "thread_extent" = 4;
attr [IterVar(threadIdx.z: int32, [(nullptr)], "ThreadIndex", "threadIdx.z")] "thread_extent" = 2 {
for (n.c.init: int32, 0, 2) {
for (o.c.init: int32, 0, 4) {
eval(call("tvm_fill_fragment", [Conv.wmma.accumulator, 16, 16, 16, ((n.c.init*4) + o.c.init), float32(0)], handle, "intrin", 0))
}
}
for (ic.outer: int32, 0, 8) {
for (kh: int32, 0, 3) {
for (ax2: int32, 0, 3) {
for (ax3: int32, 0, 2) {
for (ax4.ax5.fused.outer: int32, 0, 8) {
attr [IterVar(threadIdx.x: int32, [(nullptr)], "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
Apad.shared[((((((threadIdx.y*3072) + (threadIdx.z*1536)) + (ax2*512)) + (ax3*256)) + (ax4.ax5.fused.outer*32)) + threadIdx.x)] = call("tvm_if_then_else", [((((1 <= (floordiv(blockIdx.z, 14) + kh)) and ((floordiv(blockIdx.z, 14) + kh) < 15)) and (1 <= (ax2 + floormod(blockIdx.z, 14)))) and ((ax2 + floormod(blockIdx.z, 14)) < 15)), load(float16, A_2[(((((((((((blockIdx.x*6422528) + (threadIdx.y*1605632)) + (threadIdx.z*802816)) + (kh*57344)) + (blockIdx.z*4096)) + (ax2*4096)) + (ic.outer*512)) + (ax3*256)) + (ax4.ax5.fused.outer*32)) + threadIdx.x) - 61440)]), float16(0)], float16, "pure_intrin", 0)
}
}
}
for (ax1: int32, 0, 3) {
for (ax2_1: int32, 0, 2) {
attr [IterVar(threadIdx.x, [(nullptr)], "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
W.shared[ramp((((((ax1*4096) + (ax2_1*2048)) + (threadIdx.y*512)) + (threadIdx.z*256)) + (threadIdx.x*8)), 1, 8)] = load(float16x8, W_2[ramp(((((((((kh*393216) + (ax1*131072)) + (ic.outer*16384)) + (ax2_1*8192)) + (blockIdx.y*2048)) + (threadIdx.y*512)) + (threadIdx.z*256)) + (threadIdx.x*8)), 1, 8)])
}
}
for (ic.inner: int32, 0, 2) {
for (kw: int32, 0, 3) {
for (ax0: int32, 0, 2) {
eval(call("tvm_load_matrix_sync", [Apad.shared.wmma.matrix_a, 16, 16, 16, ax0, call("tvm_access_ptr", [call("type_annotation", [], float16, "pure_intrin", 0), Apad.shared, ((((threadIdx.y*3072) + (ax0*1536)) + (kw*512)) + (ic.inner*256)), 256, 1], handle, "intrin", 0), 16, "row_major"], handle, "intrin", 0))
}
for (ax3_1: int32, 0, 4) {
eval(call("tvm_load_matrix_sync", [W.shared.wmma.matrix_b, 16, 16, 16, ax3_1, call("tvm_access_ptr", [call("type_annotation", [], float16, "pure_intrin", 0), W.shared, ((((kw*4096) + (ic.inner*2048)) + (threadIdx.z*1024)) + (ax3_1*256)), 256, 1], handle, "intrin", 0), 16, "row_major"], handle, "intrin", 0))
}
for (n.c: int32, 0, 2) {
for (o.c: int32, 0, 4) {
eval(call("tvm_mma_sync", [Conv.wmma.accumulator, ((n.c*4) + o.c), Apad.shared.wmma.matrix_a, n.c, W.shared.wmma.matrix_b, o.c, Conv.wmma.accumulator, ((n.c*4) + o.c)], handle, "intrin", 0))
}
}
}
}
}
}
for (n.inner: int32, 0, 2) {
for (o.inner: int32, 0, 4) {
eval(call("tvm_store_matrix_sync", [Conv.wmma.accumulator, 16, 16, 16, ((n.inner*4) + o.inner), call("tvm_access_ptr", [call("type_annotation", [], float32, "pure_intrin", 0), Conv_2, (((((((blockIdx.x*12845056) + (threadIdx.y*3211264)) + (n.inner*1605632)) + (blockIdx.z*8192)) + (blockIdx.y*2048)) + (threadIdx.z*1024)) + (o.inner*256)), 256, 2], handle, "intrin", 0), 16, "row_major"], handle, "intrin", 0))
}
}
}
}
}
}
}
}
}
So far we have a text printer for Relay, which allows us to print an IRModule into the text format. On the TIR side, we still rely on the ReprPrinter.
This issue is for upgrading the text printer so that we can print an IRModule that includes PrimFunc (tir::Function in the upstream) in the text format. This will help us enhance the demo experience.
Possible Design Points
Ideally we want to land a version in the mainline in about two to three weeks. @spectrometerHBH please see if it is possible for you and @Hzfengsy to coordinate on a format and land a version; then we can pull it back to TensorIR.