Open zzzDavid opened 3 years ago
Following-up on this issue -- is there a proper fix planned?
The local import/declaration workaround seems to have additional restrictions. For example:
import heterocl as hcl

def f1(A, B):
    # move this block of code inside the do function and it segfaults.
    @hcl.def_([(10,), (10,), ()])
    def comp(A, B, x):
        with hcl.if_(A[x] > B[x]):
            hcl.return_(A[x])
        hcl.return_(B[x])

    def do(A, B, x):
        return comp(A, B, x)

    hcl.update(B, lambda x: do(A, B, x), "f1")

A = hcl.placeholder((10,), "A")
B = hcl.placeholder((10,), "B")
s = hcl.create_schedule([A, B], func=f1, name="main")
print(hcl.lower(s))
f = hcl.build(s)
As is, the code works. But move the def inside the do function (i.e., as local as it can get) and it generates a segfault. So it looks like @def_ functions can be neither at the top level nor at some lower/local level (speculating that this is because they get defined within a compute API context). This makes it difficult to define @def_'ed modular building blocks.
Our current support for @hcl.def_ is very preliminary. I plan to solve the top-level declaration issue first. For the inner- or lower-level declaration, you suspect correctly. Our original design is closer to C++, which also doesn't allow defining a function inside another function (unless you use a lambda function). I will think about adding the lower-level support later.
As this issue points out, using a Python decorator to specify function outlining can lead to a scope issue: when submodules defined in a different Python file are imported at the global level, the generated outlined function (a KernelDef IR node) is not included in the schedule. Besides, adding decorators requires modifying the algorithm specification, so it is not a decoupled customization.
Therefore, we propose a new API called s.outline() to specify which stages to outline in a decoupled way. Since this API specifies function outlining on the schedule, it doesn't have the scope issue.
The compiler uses the input stages to extract the subgraph to be outlined. When there’s one stage input, only the input stage is outlined as a function. For the outlined function, the compiler also infers its input and output arguments, builds the function body, and inserts a call operation into the caller function.
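The argument inference and call insertion can be sketched on a toy stage list in plain Python. This is a minimal illustration under stated assumptions, not the compiler's actual pass: the `Stage` class, `infer_signature`, `outline`, and the generated function names are hypothetical, each stage is assumed to write one buffer named after itself, and the outlined stages are assumed to be contiguous. The stage names are those of the 2MM example in this post.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Stage:
    name: str     # stage name; also the buffer the stage writes (assumption)
    reads: tuple  # names of the buffers the stage reads


def infer_signature(stages, outlined):
    """Return (function name, argument list) for the outlined stages."""
    args = []
    for st in stages:
        if st.name not in outlined:
            continue
        # Inputs first, then the buffer this stage writes; skip duplicates.
        for buf in st.reads + (st.name,):
            if buf not in args:
                args.append(buf)
    name = "Stage_" + "_".join(s.name for s in stages if s.name in outlined)
    return name, args


def outline(stages, outlined):
    """Replace the outlined stages in the caller with a single call."""
    name, args = infer_signature(stages, outlined)
    caller, emitted = [], False
    for st in stages:
        if st.name in outlined:
            if not emitted:
                caller.append(f"call @{name}({', '.join(args)})")
                emitted = True
        else:
            caller.append(f"stage {st.name}")
    return name, args, caller


two_mm = [
    Stage("out_AB", ("A", "B")),
    Stage("out_ABC", ("out_AB", "C")),
    Stage("E", ("out_ABC", "D")),
]

for line in outline(two_mm, {"out_AB", "out_ABC"})[2]:
    print(line)
```

Outlining a single stage yields one function per call to `outline`, while passing both stages yields one function whose arguments cover the inputs and the intermediate and output buffers of the whole subgraph.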
2MM performs two matrix multiplications followed by an element-wise addition. It has three stages: out_AB, out_ABC, and E.
We use s.outline() to outline out_AB and out_ABC as two functions:
import heterocl as hcl

# Sizes and data type consistent with the IR shown below
P, Q, R, S = 16, 22, 18, 24
dtype = hcl.Float(32)

A = hcl.placeholder((P, Q), "A")
B = hcl.placeholder((Q, R), "B")
C = hcl.placeholder((R, S), "C")
D = hcl.placeholder((P, S), "D")

def kernel_2mm(A, B, C, D):
    r = hcl.reduce_axis(0, Q, "r")
    out_AB = hcl.compute(
        (P, R),
        lambda x, y: hcl.sum(A[x, r] * B[r, y], axis=r, dtype=dtype),
        name="out_AB",
    )
    k = hcl.reduce_axis(0, R, "k")
    out_ABC = hcl.compute(
        (P, S),
        lambda x, y: hcl.sum(out_AB[x, k] * C[k, y], axis=k, dtype=dtype),
        name="out_ABC",
    )
    E = hcl.compute(
        D.shape,
        lambda x, y: (out_ABC[x, y] + D[x, y]),
        dtype=dtype,
        name="E",
    )
    return E

s = hcl.create_schedule([A, B, C, D], kernel_2mm)
s.outline(kernel_2mm.out_AB)
s.outline(kernel_2mm.out_ABC)
The IR before outlining:
module {
  func @top(%arg0: memref<16x22xf32>, %arg1: memref<22x18xf32>, %arg2: memref<18x24xf32>, %arg3: memref<16x24xf32>) -> memref<16x24xf32> {
    // Stage out_AB
    %0 = memref.alloc() : memref<16x18xf32>
    affine.for %arg4 = 0 to 16 {
      affine.for %arg5 = 0 to 18 {
        affine.for %arg6 = 0 to 22 {
          ...
        } {loop_name = "r"}
      } {loop_name = "y"}
    } {loop_name = "x", stage_name = "out_AB"}
    // Stage out_ABC
    %1 = memref.alloc() : memref<16x24xf32>
    affine.for %arg4 = 0 to 16 {
      affine.for %arg5 = 0 to 24 {
        affine.for %arg6 = 0 to 18 {
          ...
        } {loop_name = "k"}
      } {loop_name = "y"}
    } {loop_name = "x", stage_name = "out_ABC"}
    // Stage E
    %2 = memref.alloc() : memref<16x24xf32>
    affine.for %arg4 = 0 to 16 {
      affine.for %arg5 = 0 to 24 {
        ...
      } {loop_name = "y"}
    } {loop_name = "x", stage_name = "E"}
    return %2 : memref<16x24xf32>
  }
}
The IR after outlining:
module {
  func @Stage_out_AB(%arg0: memref<16x22xf32>, %arg1: memref<22x18xf32>, %arg3: memref<16x18xf32>) ->() {
    ...
  }
  func @Stage_out_ABC(%arg0: memref<16x18xf32>, %arg1: memref<18x24xf32>, %arg3: memref<16x24xf32>) ->() {
    ...
  }
  func @top(%arg0: memref<16x22xf32>, %arg1: memref<22x18xf32>, %arg2: memref<18x24xf32>, %arg3: memref<16x24xf32>) -> memref<16x24xf32> {
    // Stage out_AB
    %0 = memref.alloc() : memref<16x18xf32>
    call @Stage_out_AB(%arg0, %arg1, %0)
    // Stage out_ABC
    %1 = memref.alloc() : memref<16x24xf32>
    call @Stage_out_ABC(%0, %arg2, %1)
    // Stage E
    %2 = memref.alloc() : memref<16x24xf32>
    affine.for %arg4 = 0 to 16 {
      affine.for %arg5 = 0 to 24 {
        ...
      } {loop_name = "y"}
    } {loop_name = "x", stage_name = "E"}
    return %2 : memref<16x24xf32>
  }
}
Passing several stages to a single s.outline() call outlines the subgraph they form as one function:
s = hcl.create_schedule([A, B, C, D], kernel_2mm)
s.outline(kernel_2mm.out_AB, kernel_2mm.out_ABC)
The IR after outlining:
module {
  func @Stage_outAB_outABC(%arg0: memref<16x22xf32>, %arg1: memref<22x18xf32>, %arg3: memref<16x18xf32>, %arg4: memref<18x24xf32>, %arg5: memref<16x24xf32>) ->() {
    ...
  }
  func @top(%arg0: memref<16x22xf32>, %arg1: memref<22x18xf32>, %arg2: memref<18x24xf32>, %arg3: memref<16x24xf32>) -> memref<16x24xf32> {
    // Stage out_AB and out_ABC
    %0 = memref.alloc() : memref<16x18xf32>
    %1 = memref.alloc() : memref<16x24xf32>
    call @Stage_outAB_outABC(%arg0, %arg1, %0, %arg2, %1)
    // Stage E
    %2 = memref.alloc() : memref<16x24xf32>
    affine.for %arg4 = 0 to 16 {
      affine.for %arg5 = 0 to 24 {
        ...
      } {loop_name = "y"}
    } {loop_name = "x", stage_name = "E"}
    return %2 : memref<16x24xf32>
  }
}
Description
For function outlining with the @def_ decorator, a local import is required to generate backend code correctly. For HLS backends, HeteroCL will still run, but the outlined function is missing from the generated HLS code. For the LLVM backend, the HeteroCL code fails to run and throws a segmentation fault, which is difficult to debug.
Minimum Example
main.py
submodule.py
Running python main.py causes a segmentation fault with the LLVM backend. If we comment out the global import and uncomment the local import, we get the correct result.
Cause
When the import is done in global scope, the submodule function definition is not run during hcl.create_schedule, so the IR contains only the Call stmt and no KernelDef stmt.
Proposal: we can add an IR pass that checks whether every Call stmt has a corresponding KernelDef, to detect this problem.
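A check of this kind might look like the following sketch, written against a toy IR in plain Python. The `KernelDef` and `Call` classes here are hypothetical stand-ins that only borrow the names of the IR nodes mentioned above; they are not HeteroCL's data structures, and a real pass would walk the actual IR instead.

```python
from dataclasses import dataclass, field


@dataclass
class KernelDef:
    name: str
    body: list = field(default_factory=list)  # nested statements


@dataclass
class Call:
    name: str


def walk(stmts):
    """Yield every statement, descending into kernel bodies."""
    for s in stmts:
        yield s
        if isinstance(s, KernelDef):
            yield from walk(s.body)


def find_undefined_calls(stmts):
    """Return the sorted names of calls that have no matching KernelDef."""
    defined = {s.name for s in walk(stmts) if isinstance(s, KernelDef)}
    called = {s.name for s in walk(stmts) if isinstance(s, Call)}
    return sorted(called - defined)


# Global import case: the IR holds the call but not the definition.
broken_ir = [Call("comp")]
# Local import case: the definition is present.
good_ir = [KernelDef("comp"), Call("comp")]

print(find_undefined_calls(broken_ir))
print(find_undefined_calls(good_ir))
```

Reporting the undefined names as an error before code generation would replace the hard-to-debug segmentation fault with an actionable message.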