Rather than using function cloning (i.e., static optimization) we can do things at runtime (i.e., dynamic optimization). We create an if-then-else structure around the entire forall
loop like this:
if hasOuterLoop && !invalidStructure {
  if staticCheck {
    if doInspector {
      inspector loop
    }
    executor loop
  }
  else {
    original forall
  }
}
else {
  original forall
}
For a given function that contains some number of forall loops, we associate it with two flags: one that says whether we have hit an outer serial loop on the way to the function and one that says whether we've hit an invalid enclosing structure on the way to the function. Each function will be given a unique ID, which will be used when checking these flags.
Our static analysis runs after resolution, and after the write/modification analysis. For any function that has a still-valid forall
loop, it will compute the call graph and look at each call path. For a given call path, we find the furthest enclosing serial loop and add a call right before that loop to turn on the hasOuterLoop
flag for the function being evaluated. Likewise, we look for the furthest enclosing invalid structure (if there is one) and add a call to turn on the invalidStructure
flag. Right after these loops, we reset the flags. So if we found the outer loop, we set the hasOuterLoop
flag to true right before the loop. When we eventually call into the function with the forall
, that flag will be set. Assuming we didn't have any invalid structures, that flag will be false, so we'll do the optimized loop. If we did find an invalid structure along the call path, we would set the invalidStructure
flag to true and we would fall back to executing the original forall. Note that these two flags are always initialized to false and true, respectively. So if the forall
ends up not being valid at compile time, the entire if-branch would be basically empty but we'd take the else branch anyways. I think we could just remove the if-then-else entirely if we wanted.
So here is the example from the first post, with an extra call to A()
at the end:
proc x() {
  forall { ... }
}
proc y() {
  x();
}
proc A() {
  y();
}
proc B() {
  for i in 0..#5 {
    y();
  }
}
A();
B();
A();
This would get transformed into something like this:
proc x() {
  if hasOuterLoop && !invalidStructure {
    opt forall { ... }
  }
  else {
    original forall
  }
}
proc y() {
  x();
}
proc A() {
  y();
}
proc B() {
  hasOuterLoop = true;
  for i in 0..#5 {
    y();
  }
  hasOuterLoop = false;
}
A();
B();
A();
So when we first call A()
, our two flags are false and true, respectively (default values). The call to A()
leads to the call to y()
, which then calls x()
. The if-check will resolve to false, so we do the original loop (this is what we want). Then we do the call to B()
, which will set the hasOuterLoop
flag to true right before the valid outer loop. So when we get to x()
, the if-check is true, so we do the optimized loop (as planned). Once we finish that outer loop in B()
, we turn the outer loop flag back to false. So when we do that last call to A()
, all the flags are set to their default, so we get the same behavior as that first call to A()
.
Furthermore, if x()
had multiple forall
loops inside of it, we just need a single set of hasOuterLoop
and invalidStructure
flags for the function. If we have an outer loop on the way to x()
for some call path, all loops in x()
will be valid w.r.t. that outer loop. Likewise, if we find an invalid structure on the path, all forall loops in the function are invalid.
An interesting case is when x()
has two forall
loops, one that is like we see above but the other with a direct outer loop around it. We have two options: (1) don't put the hasOuterLoop
check in the if, but just use the invalidStructure
check, or (2) just do what we do above and add the call hasOuterLoop=true
right before the outer loop. Both would accomplish the same thing, which is to say that we always have an outer loop for that forall
. Option (1) eliminates the need to check for it though. I would probably opt to go with (1) as it wouldn't be difficult (all of it can be done/determined at normalization).
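To make option (1) concrete, here is a sketch in the same pseudocode style as the structure at the top of this post (the loop and the placement of the foralls are just illustrative):
proc x() {
  // forall #1: no direct outer loop, so it needs the full check
  if hasOuterLoop && !invalidStructure {
    opt forall { ... }
  }
  else {
    original forall
  }
  // forall #2: has its own enclosing serial loop inside x(), so under
  // option (1) only the invalid-structure flag needs to be checked
  for i in 0..#5 {
    if !invalidStructure {
      opt forall { ... }
    }
    else {
      original forall
    }
  }
}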
I think this is a much cleaner approach and avoids all the nastiness of function cloning. We still need to do call graph construction and analysis after resolution, but we only need to do it once for a given function that contains at least one optimized forall
(no need to recompute it a bunch of times).
Working on the approach described above, which I will refer to as dynamic call path analysis. I am working off of a new branch for this in case things go south and we need an easy way to restart. It also allows me to save off the incremental steps. The branch is called dynamic-callpath-analysis
The first step is to create the Chapel module/code that will encapsulate two flags, one for the outer-loop check and one for the invalid-structure check. We can place these in the existing InspectorExecutor
module that we have.
We will determine during normalization how many functions we have that contain at least one potentially valid forall
loop and use that when creating our class that encapsulates the flags, each being an array of bool
s. The index used into the arrays will be the unique ID given to a function. We don't need to create multiple instances of this class, as it will be the same one used by all of the functions/loops in the program.
So what we do is keep track of the unique functions we see during normalization that have candidates, and assign them IDs. After we've finished looking at all of the forall
loops, we insert a call within the InspectorExecutor
module to update a global object that encapsulates the flags, telling it how many functions we have.
The global object will look something like this:
class DynamicCallPathFlags {
  var numFuncs : int;
  var D : domain(1) = {0..#numFuncs};
  var hasOuterLoop : [D] bool;
  var hasInvalidEnclosingStructure : [D] bool;
}
Then we will have our global instantiation of the class within the InspectorExecutor
module:
// default/dummy value of 1 passed in so we can call the set/unset methods initially
var gDynamicCallPathFlags = new DynamicCallPathFlags(numFuncs=1);
Next, we have the functions we use to set/unset the flags specifically in gDynamicCallPathFlags
:
proc setHasOuterLoop(funcID : int) {
  gDynamicCallPathFlags.hasOuterLoop[funcID] = true;
}
proc unsetHasOuterLoop(funcID : int) {
  gDynamicCallPathFlags.hasOuterLoop[funcID] = false;
}
proc setHasInvalidEnclosingStructure(funcID : int) {
  gDynamicCallPathFlags.hasInvalidEnclosingStructure[funcID] = true;
}
proc unsetHasInvalidEnclosingStructure(funcID : int) {
  gDynamicCallPathFlags.hasInvalidEnclosingStructure[funcID] = false;
}
We also have a function to update the number of functions we have to consider. The compiler inserts a call to this after normalization has finished processing the loops:
proc setNumFuncsForDynamicCallPathFlags(numFuncs : int) {
  gDynamicCallPathFlags.numFuncs = numFuncs;
  gDynamicCallPathFlags.D = {0..#numFuncs};
  gDynamicCallPathFlags.hasOuterLoop = false;
  gDynamicCallPathFlags.hasInvalidEnclosingStructure = false;
}
And then we will add the following calls to the above functions to the InspectorExecutor
module to make sure they will be around after resolution:
setHasOuterLoop(0);
unsetHasOuterLoop(0);
setHasInvalidEnclosingStructure(0);
unsetHasInvalidEnclosingStructure(0);
Finally, once we know how many functions we have to worry about, we go back and add the following call:
setNumFuncsForDynamicCallPathFlags(numFuncs);
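As a rough usage sketch (the function count and ID below are made up for illustration), the calls inserted by the compiler end up producing a runtime sequence like this:
// Normalization found, say, 3 candidate functions: resize the flag
// arrays and reset every flag to false.
setNumFuncsForDynamicCallPathFlags(3);
// Around an outer serial loop that eventually reaches the forall in
// the function with ID 2, the inserted calls bracket the loop:
setHasOuterLoop(2);
for i in 0..#5 {
  // calls that eventually reach the guarded forall in function 2
}
unsetHasOuterLoop(2);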
At the end of all of this, we will have a globally accessible instantiation of the DynamicCallPathFlags
class called gDynamicCallPathFlags
and we will have set it up so it has the correct number of flags for the functions we have determined to have potential candidates. We've also made dummy calls to all methods that we will need to call in post-resolution.
I've added this code to modules/internal/InspectorExecutor.chpl
and ensured that we can indeed call the set/unset functions in a global manner (I did these by hand). I also added the set/unset methods to the well-known functions list: https://github.com/thomasrolinger/chapel/commit/988fe4888c0ab637c6d0b125c0bf17d8b1f4fa42
EDIT: forgot that we also need functions to get the flags for a given function ID. We use these when building up the if-then-else structure around the optimized loops. Since both flags are checked together anyway, we can just have a simple isCallPathValid()
method, which returns true if we had an outer loop and we did not have any invalid enclosing structures. I added this in this commit: https://github.com/thomasrolinger/chapel/commit/56c1a604c7039e17e66b3f74c2123ef847ba1e06
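For reference, a minimal sketch of what that method could look like, assuming it just combines the two flags (the actual version is in the linked commit):
proc isCallPathValid(funcID : int) : bool {
  // valid only if we hit an outer serial loop and no invalid
  // enclosing structure on the way to this function
  return gDynamicCallPathFlags.hasOuterLoop[funcID] &&
         !gDynamicCallPathFlags.hasInvalidEnclosingStructure[funcID];
}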
The remaining work will now be within the compiler itself.
First, let's create the necessary data structures to accomplish the following during normalization:
1. Keep track of the functions that contain at least one forall to be optimized.
2. Map the FnSymbols from (1) to unique IDs that we will use when calling into the Chapel code.

We need to have such a mapping of FnSymbols to IDs so that our code transformations during normalization can insert the correct calls to check the appropriate flags. After that though, I don't believe we need to maintain the mapping. We can just store the ID within forall->optInfo and make sure that it follows the forall around via any cloning. During post-resolution, when we do the actual call path analysis for a given forall, we extract this ID and use it in any calls to the set/unset methods we insert.
So we're looking at something like this:
// top of forallOptimizations.cpp
std::map<FnSymbol *, int> gFunctionsWithCandidatesIE;
// in forall->optInfo in ForallStmt.h
int functionID;
When we process a given forall
during normalization, we see if it is in a function that is NOT the module init function. If it is in a "real" function, we look up the FnSymbol
in the map. If we don't find it, we add it and increment a global running count to associate with that function. If it does exist, we just grab the associated ID that is stored with it in the map. We then set forall->optInfo.functionID
to that ID. We ensure that this ID is copied when we clone. We then use this ID when we create the calls to check the flags.
Here is the commit that sets up the ID (also removed the prior cloning stuff that will no longer be relevant): https://github.com/thomasrolinger/chapel/commit/0c81c5c5892a12be5b429b2cf32f325a79dfcc4e
The next step is creating the if-then-else structure. This should be straightforward to do; just generate the call to isCallPathValid()
, passing in the forall->optInfo.functionID
. We put that into the conditional and check for gTrue
. The then-block will be the entire "thing" we have created right now with the inspector/executor loop. The else-block will be the original forall
(we make another clone of it).
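Putting that together with the pseudocode from the top of this post, the generated structure for a forall whose function ID is, say, 0 would look roughly like this:
if isCallPathValid(0) {
  if staticCheck {
    if doInspector {
      inspector loop
    }
    executor loop
  }
  else {
    original forall
  }
}
else {
  original forall
}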
We also add the call to setNumFuncsForDynamicCallPathFlags(numFuncs)
after we've processed all the forall
loops. We insert that at the end of the InspectorExecutor
module (added a pragma to the module so we can find it easily).
This is the commit that does that. I tested that the code structure is valid by forcing isCallPathValid
to always return true, which would make all of our tests run as usual. I also tested cases where the forall
is not in a function, so we do not need to create the if-then-else structure. Finally, I fixed a bug with how we look for candidate accesses, specifically when we have something like A[B[C[i]]]
: https://github.com/thomasrolinger/chapel/commit/3b4e779fdd31589bbf7617bab5597d8d313b8b33
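For context, that nested pattern looks like this in source form (a small, hypothetical example; the arrays and bounds are made up):
var C : [0..9] int;       // innermost index array
var B : [0..9] int;       // middle index array
var A : [0..9] int;       // data array accessed irregularly
var result : [0..9] int;
forall i in 0..9 {
  // candidate access of the form A[B[C[i]]]
  result[i] = A[B[C[i]]];
}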
Everything has been implemented and it seems to be working on all small and large test cases: Dynamic call path analysis implemented
Now the question is whether we can easily adopt this approach to handle invalid writes to array/domains that would invalidate the optimization. Such analysis would be tied to individual forall
loops, not just the function it is within.
I've added data structure changes to support this, but the entire analysis itself is not worth doing right now. Even if implemented, we will not be able to handle all cases (see below). Since it will take some effort, I'd rather work on something else right now.
For the data structure changes: we can do the below without actually changing how we currently do the write analysis (i.e., we aren't looking at things during call path analysis yet).
- Changed the modifications map to be std::map<Symbol *, std::map<CallExpr *, bool>>. The keys are the symbols which are written to. This also includes any aliases we find.
- Added an alias map, std::map<Symbol *, std::set<Symbol *>>. The keys in this map are symbols which have aliases to them (e.g., for ref x = y, we store y as the key and x as part of the corresponding set). If we find more refs "made from" y, we add them to the set.
- For something like ref FOO = A[0..2], our approach will find that a call_tmp is an alias to A (this is the thing that holds the slice of A) and that FOO is an alias to that call_tmp. What we really want to store is that FOO and call_tmp are aliases to A. That way, when we do our call path analysis with a given Symbol, we can get all of the aliases in one look-up, rather than trying to traverse the map (a small example follows the commit link below).

I also fixed a bug in our detectAliasesInCall function when it tried to find domain slices. I was lazy and didn't do it quite right, so it was finding more aliases than existed. It didn't seem to cause errors since those aliases were a bunch of call_tmps that weren't written to. Now we actually check for the right thing and the aliases we find seem more reasonable.
Here is the commit: https://github.com/thomasrolinger/chapel/commit/d244abe5d198e8bd43359e0fdf2590ccf93b1aa4
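To illustrate the slice-alias case from the list above, here is a small, hypothetical example of the relationship we want to record:
var A : [0..9] int;
// The slice is materialized through a compiler-introduced call_tmp,
// so both FOO and that call_tmp should be recorded as aliases of A.
ref FOO = A[0..2];
// A write through the alias is effectively a write to A, which is what
// the modifications map needs to capture.
FOO[0] = 42;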
Regarding the case that will not work for us, consider this code:
proc func(X,Y,Z) {
  forall ...
}
proc foo(A,B,C,D,E,F) {
  func(A,B,C);
  func(C,E,F);
}
for ... {
  foo(A1,B1,C1,D1,E1,F1);
  B2 = ..
  foo(A2, B2, ....);
}
The problem is when/where we turn on the flag that says we have an invalid write. If we do it right around that 2nd call to foo
, then we will NOT do the optimized forall
for both calls to func
, even though only the first one is affected by the invalid write. If we instead put the flag set/unset around the first call to func
, then the first call to foo
will not do the optimized forall
when it could/should.
I am keeping this post here in case I want to return to this at some point. The plan was roughly as follows:
1. We first rely on the existing write/modification analysis for the forall. Those invalid writes will completely cancel out the optimization, so we want to do that first to avoid redundant checks later.
2. The call path analysis itself runs in post-resolution, once we know the optimization is otherwise still valid for the forall. Furthermore, we need access to the relevant forall loop's modifications vectors we built up. Therefore, we will do the call path analysis within the executeAccessCallPostAnalysis() function.
3. Build up the sets of Symbols we want to track. We want two sets: one that contains globals (not passed into the function) and one that contains things passed into the function or defined in the function. The set of globals will not change as we process call paths. We start with the original Symbols we have (A, B, forallIteratorDom, etc.). We check whether they are formals to the function; if they are not, we add them to the global set, else we add them to the other set. We then look at any Symbol in the modifications map whose defPoint is within the function that the forall is within (this will include formals); those get added to the non-global set. By looking at all things in the modifications map whose defPoint is within the function, we will get any aliases to the Symbols that are relevant. Note that if the forall is not in a function, then the two sets will be identical.
4. If the forall is not in a function (it is in something like chpl_gen_main or the module init), we can simply do the approach we already have that finds the closest outer loop and then checks for modifications. We change that approach so it takes in the set of Symbols from (3) and looks them up in the modifications map, and we do not attempt to go across call sites. That way, we can reuse this function in future steps.
5. If the forall is in a function, check for an outer loop within the function. If we find one, look for any modifications to the Symbols we pulled out in (3) and see if any are within the outer loop. For every modification we find that is within the outer loop, we mark the corresponding bool for the modification (CallExpr) to false. If we don't find any, then there is nothing to do (all the bools are set to true already). In either case, we stop/return since we found and processed an outer loop.
6. If we did not find an outer loop within the function, we look at the Symbols in the non-global set from (3) and see whether they are formals to the function. If they are, we log the index position. This is so we can associate them with the actual argument at a call site.
7. At each call site, we map the formals back to the actual arguments and add those Symbols to the non-global set we have of relevant Symbols from (3). We then add to that set by looking for any Symbol from the modifications map whose defPoint is within the current function. Since we are using a set, we'll ignore any duplicates. It is true that some Symbols in this set are no longer relevant, like the original formals to the function whose call site we are processing. We could attempt to prune them from the set at this point, but I'll opt to just leave them as it won't matter w.r.t. correctness.
8. When we return from processing a call site, we remove the Symbols we added to the non-global set. That is so the next call site can be processed and won't be "tainted" by irrelevant Symbols from the other call site. So prior to calling this function recursively, we need to keep a map/vector of the things we just added so we can remove them when we return back.

I think we can resolve the above issue of call path analysis for invalid writes by placing the relevant flag(s) within the array/domain of interest. When we see an invalid write, we call the set function on that array/domain that is written to. When we get to the forall, we check the flags on the arrays/domains of interest. If any are set to true, we don't do the optimization. We turn the flag back off AFTER the outer loop. This is because while doing iterations of the outer loop, any optimization that uses that array/domain is compromised (it would have to re-run the inspector each time). NOTE: this is definitely not true; see the example at the end of this comment. This "ends" once we finish the loop. Then, a call could be made to the function using that array/domain (assuming it is in another loop). So something like this:
proc func(X,Y,Z) {
  forall ...
}
proc foo(A,B,C,D,E,F) {
  func(A,B,C);
  func(C,E,F);
}
for ... {
  foo(A1,B1,C1,D1,E1,F1);
  B2 = ..
  setInvalidWriteFlag(B2);
  foo(A2, B2, ....);
}
unsetInvalidWriteFlag(B2);
for ... {
  foo(A2, B2, ....); // this is OK
}
One nasty issue I do see here is that perhaps B2
is an alias that was made within the for
loop, so we don't have access to it outside of the for
. We may be able to get around this because we do have a mapping of aliases to the actual Symbol
. So we find the Symbol
that is reachable from outside of the for
loop.
But a more severe problem is that we could have this:
proc func(X,Y,Z) {
  forall ...
}
proc foo(A,B,C,D,E,F) {
  func(A,B,C);
  func(C,E,F);
  for i in ... {
    func(A,B,C);
  }
}
for ... {
  foo(A1,B1,C1,D1,E1,F1);
  B2 = ..
  setInvalidWriteFlag(B2);
  foo(A2, B2, ....);
}
unsetInvalidWriteFlag(B2);
When we set B2
's flag and then call foo()
, we will NOT do the optimization for the first call to func()
. The 2nd call is not affected, so we do it. But the third is in its own outer loop. In this case, the B2
modification is valid since it is outside of the closest outer loop for the call site. But we still have the flag set, so we wouldn't do the optimization.
With that said, I don't believe this is a complete solution.
Right now, we check to see if a given forall that would potentially be optimized via the inspector-executor is in a function which has call sites that are not in an outer serial loop. For such call sites, we clone the function and update the call sites. However, this does not address all cases. Consider the following:
The function A() does not provide an outer loop for the forall, so it would not be a good idea to do the optimization there. But the function B() does have an outer loop. The problem is that both A() and B() are tied to the same call site for the forall, which is y(). So our current approach would find that there is a valid outer loop and not do any cloning.