thomasrolinger / chapel

a Productive Parallel Programming Language
https://chapel-lang.org

More robust handling for inspector-executor and multiple call sites #33

Closed thomasrolinger closed 2 years ago

thomasrolinger commented 2 years ago

Right now, we check to see if a given forall that would potentially be optimized via the inspector-executor is in a function which has call sites that are not in an outer serial loop. For such call sites, we clone the function and update the call sites.

However, this does not address all cases. Consider the following:

proc x() {
   forall { ... }
}

proc y() {
  x();
}

proc A() {
   y();
}

proc B() {
   for i in 0..#5 {
      y();
   }
}
A();
B();

The function A() does not provide an outer loop for the forall, so it would not be a good idea to do the optimization there. But the function B() does have an outer loop. The problem is that both A() and B() reach the forall through the same call site of x(), which is the one inside y(). So our current approach would find that there is a valid outer loop and not do any cloning.

thomasrolinger commented 2 years ago

Rather than using function cloning (i.e., static optimization) we can do things at runtime (i.e., dynamic optimization). We create an if-then-else structure around the entire forall loop like this:

if hasOuterLoop && !invalidStructure {
   if staticCheck {
      if doInspector {
         inspector loop
      }
      executor loop
   }
   else {
      original forall
   }
}
else {
   original forall
}

For a given function that contains some number of forall loops, we associate it with two flags: one that says whether we have hit an outer serial loop on the way to the function and one that says whether we've hit an invalid enclosing structure on the way to the function. Each function will be given a unique ID, which will be used when checking these flags.

Our static analysis runs after resolution and after the write/modification analysis. For any function that has a still-valid forall loop, it will compute the call graph and look at each call path. For a given call path, we find the furthest enclosing serial loop and add a call right before that loop to turn on the hasOuterLoop flag for the function being evaluated. Likewise, we look for the furthest enclosing invalid structure (if there is one) and add a call to turn on the invalidStructure flag. Right after these loops, we reset the flags.

So if we found the outer loop, we set the hasOuterLoop flag to true right before the loop. When we eventually call into the function with the forall, that flag will be set. Assuming we didn't have any invalid structures, the invalidStructure flag will be false, so we'll do the optimized loop. If we did find an invalid structure along the call path, we would set the invalidStructure flag to true and we would fall back to executing the original forall. Note that both flags are always initialized to false, so the check fails by default. If the forall ends up not being valid at compile time, the entire if-branch would be basically empty, but we'd take the else-branch anyway. I think we could just remove the if-then-else entirely in that case if we wanted.

So here is the example from the first post, with an extra call to A() at the end:

proc x() {
   forall { ... }
}

proc y() {
  x();
}

proc A() {
   y();
}

proc B() {
   for i in 0..#5 {
      y();
   }
}
A();
B();
A();

This would get transformed into something like this:

proc x() {
   if hasOuterLoop && !invalidStructure {
      opt forall { ... }
   }
   else {
      original forall
   }
}

proc y() {
  x();
}

proc A() {
   y();
}

proc B() {
   hasOuterLoop = true;
   for i in 0..#5 {
      y();
   }
   hasOuterLoop = false;
}
A();
B();
A();

So when we first call A(), both flags are false (their default values). The call to A() leads to the call to y(), which then calls x(). The if-check will resolve to false, so we do the original loop (this is what we want). Then we do the call to B(), which will set the hasOuterLoop flag to true right before the valid outer loop. So when we get to x(), the if-check is true, so we do the optimized loop (as planned). Once we finish that outer loop in B(), we turn the outer loop flag back to false. So when we do that last call to A(), all the flags are back at their defaults, and we get the same behavior as that first call to A().
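
To make the walkthrough concrete, here is a minimal C++ simulation of the transformed program. It is only a sketch: the flags are plain globals rather than per-function entries, both are assumed to start out false, and the names trace and runExample are hypothetical helpers that record which version of the forall each call to x() ended up running.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Stand-ins for the two runtime flags; both are assumed to default to false.
static bool hasOuterLoop = false;
static bool invalidStructure = false;

// Hypothetical helper: records which version of the forall each x() call ran.
static std::vector<std::string> trace;

void x() {
    if (hasOuterLoop && !invalidStructure) {
        trace.push_back("optimized");   // inspector/executor version
    } else {
        trace.push_back("original");    // unmodified forall
    }
}

void y() { x(); }

void A() { y(); }

void B() {
    hasOuterLoop = true;                // inserted right before the outer serial loop
    for (int i = 0; i < 5; ++i) {
        y();
    }
    hasOuterLoop = false;               // reset right after the loop
}

// Runs the call sequence A(); B(); A(); and returns the recorded trace.
std::vector<std::string> runExample() {
    trace.clear();
    A();
    B();
    A();
    assert(!hasOuterLoop);              // the flag is back at its default
    return trace;
}
```

Calling runExample() should yield one "original" entry for the first A(), five "optimized" entries for the loop in B(), and a final "original" entry for the last A().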

Furthermore, if x() had multiple forall loops inside of it, we just need a single set of hasOuterLoop and invalidStructure flags for the function. If we have an outer loop on the way to x() for some call path, all loops in x() will be valid w.r.t. that outer loop. Likewise, if we find an invalid structure on the path, all forall loops in the function are invalid.

An interesting case is when x() has two forall loops, one that is like we see above but the other with a direct outer loop around it. For the second forall we have two options: (1) don't put the hasOuterLoop check in the if, but just use the invalidStructure check, or (2) just do what we do above and add the call hasOuterLoop=true right before the outer loop. Both would accomplish the same thing, which is to say that we always have an outer loop for that forall. Option (1) eliminates the need to check for it though. I would probably opt to go with (1) as it wouldn't be difficult (all of it can be done/determined at normalization).

I think this is a much cleaner approach and avoids all the nastiness of function cloning. We still need to do call graph construction and analysis after resolution, but we only need to do it once for a given function that contains at least one optimized forall (no need to recompute it a bunch of times).

thomasrolinger commented 2 years ago

Working on the approach described above, which I will refer to as dynamic call path analysis. I am working off of a new branch for this in case things go south and we need an easy way to restart. It also allows me to save off the incremental steps. The branch is called dynamic-callpath-analysis

The first step is to create the Chapel module/code that will encapsulate two flags, one for the outer-loop check and one for the invalid-structure check. We can place these in the existing InspectorExecutor module that we have.

We will determine during normalization how many functions we have that contain at least one potentially valid forall loop and use that when creating our class that encapsulates the flags, each being an array of bools. The index used into the arrays will be the unique ID given to a function. We don't need to create multiple instances of this class, as it will be the same one used by all of the functions/loops in the program.

So what we do is keep track of the unique functions we see during normalization that have candidates, and assign them IDs. After we've finished looking at all of the forall loops, we insert a call within the InspectorExecutor module to update a global object that encapsulates the flags, telling it how many functions we have.

The global object will look something like this:

class DynamicCallPathFlags {
   var numFuncs : int;
   var D : domain(1) = {0..#numFuncs};
   var hasOuterLoop : [D] bool;
   var hasInvalidEnclosingStructure : [D] bool;
}

Then we will have our global instantiation of the class within the InspectorExecutor module:

// default/dummy value of 1 passed in so we can call the set/unset methods initially
var gDynamicCallPathFlags = new DynamicCallPathFlags(numFuncs=1);

Next, we have the functions we use to set/unset the flags specifically in gDynamicCallPathFlags:

proc setHasOuterLoop(funcID : int) {
   gDynamicCallPathFlags.hasOuterLoop[funcID] = true; 
}
proc unsetHasOuterLoop(funcID : int) { 
   gDynamicCallPathFlags.hasOuterLoop[funcID] = false; 
}
proc setHasInvalidEnclosingStructure(funcID : int) { 
   gDynamicCallPathFlags.hasInvalidEnclosingStructure[funcID] = true; 
}
proc unsetHasInvalidEnclosingStructure(funcID : int) { 
   gDynamicCallPathFlags.hasInvalidEnclosingStructure[funcID] = false; 
}

We also have a function to update the number of functions we have to consider. This is what the compiler inserts a call to after normalization is finished processing the loops:

proc setNumFuncsForDynamicCallPathFlags(numFuncs : int) {
   gDynamicCallPathFlags.numFuncs = numFuncs;
   gDynamicCallPathFlags.D = {0..#numFuncs};
   gDynamicCallPathFlags.hasOuterLoop = false;
   gDynamicCallPathFlags.hasInvalidEnclosingStructure = false;
}

And then we will add the following calls to the above functions to the InspectorExecutor module to make sure they will be around after resolution:

setHasOuterLoop(0);
unsetHasOuterLoop(0);
setHasInvalidEnclosingStructure(0);
unsetHasInvalidEnclosingStructure(0);

Finally, once we know how many functions we have to worry about, we go back and add the following call:

setNumFuncsForDynamicCallPathFlags(numFuncs);

At the end of all of this, we will have a globally accessible instantiation of the DynamicCallPathFlags class called gDynamicCallPathFlags and we will have set it up so it has the correct number of flags for the functions we have determined to have potential candidates. We've also made dummy calls to all methods that we will need to call in post-resolution.

I've added this code to modules/internal/InspectorExecutor.chpl and ensured that we can indeed call the set/unset functions in a global manner (I did these by hand). I also added the set/unset methods to the well-known functions list: https://github.com/thomasrolinger/chapel/commit/988fe4888c0ab637c6d0b125c0bf17d8b1f4fa42

EDIT: forgot that we also need functions to get the flags for a given function ID. We use these when building up the if-then-else structure around the optimized loops. Since both flags are always checked together, we can just have a simple isCallPathValid() method, which returns true if we had an outer loop and we did not have any invalid enclosing structures. I added this in this commit: https://github.com/thomasrolinger/chapel/commit/56c1a604c7039e17e66b3f74c2123ef847ba1e06
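
For readers who want to experiment with the flag logic outside of Chapel, the following is a rough C++ analogue of the DynamicCallPathFlags object together with the isCallPathValid() check described above. It is a sketch, not the code in InspectorExecutor.chpl: the struct layout, the setNumFuncs name, and the use of std::vector<bool> are assumptions made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative C++ analogue of the Chapel DynamicCallPathFlags class.
struct DynamicCallPathFlags {
    std::vector<bool> hasOuterLoop;
    std::vector<bool> hasInvalidEnclosingStructure;

    // Default of one function mirrors the dummy numFuncs=1 instantiation.
    explicit DynamicCallPathFlags(std::size_t numFuncs = 1)
        : hasOuterLoop(numFuncs, false),
          hasInvalidEnclosingStructure(numFuncs, false) {}

    // Mirrors setNumFuncsForDynamicCallPathFlags: resize and clear all flags.
    void setNumFuncs(std::size_t numFuncs) {
        hasOuterLoop.assign(numFuncs, false);
        hasInvalidEnclosingStructure.assign(numFuncs, false);
    }

    // Assumed semantics of isCallPathValid: an outer loop was seen on the
    // call path and no invalid enclosing structure was seen.
    bool isCallPathValid(std::size_t funcID) const {
        assert(funcID < hasOuterLoop.size());
        return hasOuterLoop[funcID] && !hasInvalidEnclosingStructure[funcID];
    }
};
```

With this model, a function's call path is only reported as valid between the point where its hasOuterLoop entry is set and the point where it is reset, and only if its invalid-structure entry stays false.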

thomasrolinger commented 2 years ago

The remaining work will now be within the compiler itself.

First, let's create the necessary data structures to accomplish the following during normalization:

  1. Count the number of functions that have at least one potentially valid forall to be optimized
  2. Create a mapping of said FnSymbols from (1) to unique IDs that we will use when calling into the Chapel code.

We need to have such a mapping of FnSymbols to IDs so that our code transformations during normalization can insert the correct calls to check the appropriate flags. After that though, I don't believe we need to maintain the mapping. We can just store the ID within forall->optInfo and make sure that follows the forall around via any cloning. During post resolution when we do the actual call path analysis for a given forall, we extract this ID and use that in any calls to the set/unset methods we insert.

So we're looking at something like this:

// top of forallOptimizations.cpp
std::map<FnSymbol *, int> gFunctionsWithCandidatesIE;
// in forall->optInfo in ForallStmt.h
int functionID;

When we process a given forall during normalization, we see if it is in a function that is NOT the module init function. If it is in a "real" function, we look up the FnSymbol in the map. If we don't find it, we add it and increment a global running count to associate with that function. If it does exist, we just grab the associated ID that is stored with it in the map. We then set forall->optInfo.functionID to that ID. We ensure that this ID is copied when we clone. We then use this ID when we create the calls to check the flags.
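
The look-up-or-assign step can be sketched as follows. This is an illustration under assumptions, not the code from the commit below: FnSymbol is reduced to a stand-in struct, and gNextFunctionID and getOrAssignFunctionID are hypothetical names for the running count and the helper.

```cpp
#include <cassert>
#include <map>

// Hypothetical stand-in for the compiler's FnSymbol; only identity matters.
struct FnSymbol {
    const char* name;
};

// Mirrors the map declared at the top of forallOptimizations.cpp.
static std::map<FnSymbol*, int> gFunctionsWithCandidatesIE;
static int gNextFunctionID = 0;   // hypothetical name for the running count

// Returns the ID for fn, assigning the next free ID the first time it is seen.
int getOrAssignFunctionID(FnSymbol* fn) {
    assert(fn != nullptr);
    auto it = gFunctionsWithCandidatesIE.find(fn);
    if (it != gFunctionsWithCandidatesIE.end()) {
        return it->second;        // already known: reuse the stored ID
    }
    int id = gNextFunctionID++;
    gFunctionsWithCandidatesIE[fn] = id;
    return id;
}
```

The returned value is what would be stored in forall->optInfo.functionID, and the final value of the running count is the number of functions passed to setNumFuncsForDynamicCallPathFlags.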

Here is the commit that sets up the ID (also removed the prior cloning stuff that will no longer be relevant): https://github.com/thomasrolinger/chapel/commit/0c81c5c5892a12be5b429b2cf32f325a79dfcc4e

thomasrolinger commented 2 years ago

The next step is creating the if-then-else structure. This should be straightforward to do; just generate the call to isCallPathValid(), passing in the forall->optInfo.functionID. We put that into the conditional and check for gTrue. The then-block will be the entire "thing" we have created right now with the inspector/executor loop. The else-block will be the original forall (we make another clone of it).

We also add the call to setNumFuncsForDynamicCallPathFlags(numFuncs) after we've processed all the forall loops. We insert that at the end of the InspectorExecutor module (added a pragma to the module so we can find it easily).

This is the commit that does that. I tested that the code structure is valid by forcing isCallPathValid to always return true, which makes all of our tests run as usual. I also tested cases where the forall is not in a function, so we do not need to create the if-then-else structure. Finally, I fixed a bug with how we look for candidate accesses, specifically when we have something like A[B[C[i]]]: https://github.com/thomasrolinger/chapel/commit/3b4e779fdd31589bbf7617bab5597d8d313b8b33

thomasrolinger commented 2 years ago

Everything has been implemented and it seems to be working on all small and large test cases: Dynamic call path analysis implemented

thomasrolinger commented 2 years ago

Now the question is whether we can easily adapt this approach to handle invalid writes to arrays/domains that would invalidate the optimization. Such analysis would be tied to individual forall loops, not just the function a loop is within.

I've added data structure changes to support this, but the entire analysis itself is not worth doing right now. Even if implemented, we will not be able to handle all cases (see below). Since it will take some effort, I'd rather work on something else right now.

Data Structure Changes

For the data structure changes: we can do the below without actually changing how we currently do the write analysis (i.e., we aren't looking at things during call path analysis yet).

  1. Change the modifications map to be std::map<Symbol *, std::map<CallExpr *, bool>>. The keys are the symbols which are written to. This also includes any aliases we find.
  2. We also want to store the aliases we find and map them to wherever they came from. We can do this with a std::map<Symbol *, std::set<Symbol *>>. The keys in this map are symbols which have aliases to them (e.g., for ref x = y, we store y as the key and x as part of the corresponding set). If we find more refs "made from" y, we add them to the set.
  3. After we find the aliases in (2), we can clean up the data structure so the chains are compressed. For example, if we have ref FOO = A[0..2], our approach will find that a call_tmp is an alias to A (this is the thing that holds the slice of A) and that FOO is an alias to that call_tmp. What we really want to store is that FOO and call_tmp are aliases to A. That way, when we do our call path analysis with a given Symbol, we can get all of the aliases in one look-up, rather than trying to traverse the map.
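
The chain compression in (3) amounts to taking the transitive closure of the alias map for each key. Below is a minimal sketch of that idea, with std::string standing in for Symbol * so the example is self-contained; the names Sym, AliasMap, and compressAliasChains are hypothetical and the real implementation may differ.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

using Sym = std::string;                        // stand-in for Symbol *
using AliasMap = std::map<Sym, std::set<Sym>>;  // symbol -> its aliases

// For every key, collect direct and indirect aliases so that a single
// look-up returns the whole chain.
AliasMap compressAliasChains(const AliasMap& direct) {
    AliasMap result;
    for (const auto& entry : direct) {
        std::set<Sym>& all = result[entry.first];
        std::vector<Sym> worklist(entry.second.begin(), entry.second.end());
        while (!worklist.empty()) {
            Sym s = worklist.back();
            worklist.pop_back();
            if (s == entry.first) continue;        // skip cycles back to the key
            if (!all.insert(s).second) continue;   // already collected
            auto it = direct.find(s);
            if (it != direct.end()) {
                worklist.insert(worklist.end(),
                                it->second.begin(), it->second.end());
            }
        }
        assert(all.count(entry.first) == 0);       // a symbol is not its own alias
    }
    return result;
}
```

For the ref FOO = A[0..2] example, a direct map of A to {call_tmp} and call_tmp to {FOO} compresses so that looking up A returns both call_tmp and FOO.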

I also fixed a bug in our detectAliasesInCall function when it tried to find domain slices. I was lazy and didn't do it quite right, so it was finding more aliases than existed. Didn't seem to cause errors since those aliases were a bunch of call_tmps that weren't written to. Now we actually check for the right thing and the aliases we find seem more reasonable.

Here is the commit: https://github.com/thomasrolinger/chapel/commit/d244abe5d198e8bd43359e0fdf2590ccf93b1aa4

Case We Cannot Handle

Regarding the case that will not work for us, consider this code:

proc func(X,Y,Z) {
    forall ...
}
proc foo(A,B,C,D,E,F) {
    func(A,B,C);
    func(C,E,F);
}

for ... {
    foo(A1,B1,C1,D1,E1,F1);
    B2 = ..
    foo(A2, B2, ....);
}

The problem is when/where we turn on the flag that says we have an invalid write. If we do it right around that 2nd call to foo, then we will NOT do the optimized forall for both calls to func, even though only the first one is affected by the invalid write. If we instead put the flag set/unset around the first call to func, then the first call to foo will not do the optimized forall when it could/should.

thomasrolinger commented 2 years ago

I am keeping this post here in case I want to return to this at some point

  1. Before we do this, we can still do our usual invalid write analysis for anything within the forall. Those invalid writes will completely cancel out the optimization. So we want to do that first to avoid redundant checks later.
  2. Once we do (1), we can move on to the actual call path analysis. Ideally, we'd merge this together with the existing analysis we have for outer loops. However, that is performed on a per-function basis while the analysis we need to do is per-forall. Furthermore, we need access to the relevant forall loop's modifications vectors we built up. Therefore, we will do the call path analysis within the executeAccessCallPostAnalysis() function.
  3. First, build up an initial set of the Symbols we want to track. We want two sets: one that contains globals (not passed into the function) and one that contains things passed into the function or defined in the function. The set of globals will not change as we process call paths. We start with the original Symbols we have (A, B, forallIteratorDom, etc.). We check whether they are formals to the function; if they are not, we add them to the global set, else we add them to the other set. We then look at any Symbol in the modifications map whose defPoint is within the function that the forall is within (this will include formals); those get added to the non-global set. By looking at all things in the modifications map whose defPoint is within the function, we will get any aliases to the Symbols that are relevant. Note that if the forall is not in a function, then the two sets will be identical.
  4. If the forall is not in a function (it is in something like chpl_gen_main or the module init), we can simply use the approach we already have, which finds the closest outer loop and then checks for modifications. We change that approach to take in the set of Symbols from (3) and look them up in the modifications map, and we do not attempt to go across call sites. That way, we can reuse this function in future steps.
  5. If the forall is in a function, check for an outer loop within the function. If we find one, look for any modifications to the Symbols we pulled out in (3) and see if any are within the outer loop. For every modification we find that is within the outer loop, we mark the corresponding bool for the modification (CallExpr) to false. If we don't find any, then there is nothing we do (all the bools are set to true already). In either case, we stop/return since we found and processed an outer loop.
  6. If we do not find an outer loop in the function, then we specifically look at the Symbols in the non-global set from (3) and see whether they are formals to the function. If they are, we log the index position. This is so we can associate them with the actual argument at a call site.
  7. Going off of (6), we look at each call site of the function. For a given call site, pull out the actual argument that corresponds to the indices we logged in (6) for the formals; add those Symbols to the non-global set we have of relevant Symbols from (3). We then add to that set by looking for any Symbol from the modifications map whose defPoint is within the current function. Since we are using a set, we'll ignore any duplicates. It is true that some Symbols in this set are no longer relevant, like the original formals to the function whose call site we are processing. We could attempt to prune them from the set at this point, but I'll opt to just leave them as it won't matter w.r.t. correctness.
  8. Now we repeat steps (4) - (7). This is a recursive process, where we stop/recurse back when we find an outer loop on a call-path. When we return back from a recursive call, we want to remove any Symbols we added to the non-global set. That is so the next call site can be processed and won't be "tainted" by irrelevant Symbols from the other call site. So prior to calling this function recursively, we need to keep a map/vector of the things we just added so we can remove them when we return back.
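
Steps (6) through (8) can be sketched with a heavily simplified model. Everything here is an assumption made for illustration: symbols are strings, a call site only records its caller, its actuals, and whether a serial loop in the caller encloses it, the call graph is assumed to be acyclic, and the modification checks from (5) are replaced by simply recording the tracked set whenever a path reaches an outer loop.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical model of a call site of some function.
struct CallSite {
    std::string caller;                 // function containing the call
    std::vector<std::string> actuals;   // actual arguments, by position
    bool inSerialLoop;                  // is the call inside a loop in caller?
};

struct Function {
    std::vector<std::string> formals;
    std::vector<CallSite> callSites;    // every place this function is called
};

using Program = std::map<std::string, Function>;

// Walks call paths upward from fn. tracked is the non-global set; whenever a
// path reaches an enclosing loop, the current set is appended to out.
void walkCallPaths(const Program& prog, const std::string& fn,
                   std::set<std::string>& tracked,
                   std::vector<std::set<std::string>>& out) {
    auto it = prog.find(fn);
    assert(it != prog.end());
    const Function& f = it->second;
    for (const CallSite& cs : f.callSites) {
        // Steps 6-7: map tracked formals to the actuals at this call site.
        std::vector<std::string> added;
        for (std::size_t i = 0; i < f.formals.size() && i < cs.actuals.size(); ++i) {
            if (tracked.count(f.formals[i]) && tracked.insert(cs.actuals[i]).second) {
                added.push_back(cs.actuals[i]);
            }
        }
        if (cs.inSerialLoop) {
            out.push_back(tracked);                        // outer loop found: stop
        } else if (prog.count(cs.caller)) {
            walkCallPaths(prog, cs.caller, tracked, out);  // step 8: recurse
        }
        // Undo this call site's additions so the next one is not tainted.
        for (const std::string& s : added) tracked.erase(s);
    }
}
```

The added vector plays the role of the "map/vector of the things we just added" from (8): it is what lets each call site start from the same tracked set.
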

thomasrolinger commented 2 years ago

I think we can resolve the above issue of call path analysis for invalid writes by placing the relevant flag(s) within the array/domain of interest. When we see an invalid write, we call the set function on that array/domain that is written to. When we get to the forall, we check the flags on the arrays/domains of interest. If any are set to true, we don't do the optimization. We turn the flag back off AFTER the outer loop. This is because while doing iterations of the outer loop, any optimization that uses that array/domain is compromised (it would have to re-run the inspector each time). NOTE: this is definitely not true; see the example at the end of this comment. This "ends" once we finish the loop. Then, a call could be made to the function using that array/domain (assuming it is in another loop). So something like this:

proc func(X,Y,Z) {
    forall ...
}
proc foo(A,B,C,D,E,F) {
    func(A,B,C);
    func(C,E,F);
}

for ... {
    foo(A1,B1,C1,D1,E1,F1);
    B2 = ..
    setInvalidWriteFlag(B2);
    foo(A2, B2, ....);
}
unsetInvalidWriteFlag(B2);

for ... {
   foo(A2, B2, ....); // this is OK 
}

One nasty issue I do see here is that perhaps B2 is an alias that was made within the for loop, so we don't have access to it outside of the for. We may be able to get around this because we do have a mapping of aliases to the actual Symbol. So we find the Symbol that is reachable from outside of the for loop.

But a more severe problem is that we could have this:

proc func(X,Y,Z) {
    forall ...
}
proc foo(A,B,C,D,E,F) {
    func(A,B,C);
    func(C,E,F);
    for i in ... {
       func(A,B,C);
    }
}

for ... {
    foo(A1,B1,C1,D1,E1,F1);
    B2 = ..
    setInvalidWriteFlag(B2);
    foo(A2, B2, ....);
}
unsetInvalidWriteFlag(B2);

When we set B2's flag and then call foo(), we will NOT do the optimization for the first call to func(). The 2nd call is not affected, so we do it. But the third is in its own outer loop. In this case, the B2 modification is valid since it is outside of the closest outer loop for the call site. But we still have the flag set, so we wouldn't do the optimization.
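
The over-conservative behavior can be reproduced with a small C++ simulation. This is only a sketch under assumptions: Array is a hypothetical stand-in for an array/domain carrying its own flag, the outer loop runs a single iteration, the inner loop in foo runs twice, and the hasOuterLoop check is left out so that only the invalid-write flag decides which version runs.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for an array/domain that carries its own flag.
struct Array {
    bool invalidWrite = false;
};

// true = the optimized forall ran, false = the original forall ran.
static std::vector<bool> trace;

void func(const Array& X, const Array& Y, const Array& Z) {
    trace.push_back(!(X.invalidWrite || Y.invalidWrite || Z.invalidWrite));
}

void foo(const Array& A, const Array& B, const Array& C,
         const Array& D, const Array& E, const Array& F) {
    (void)D;                        // unused, kept to match the example
    func(A, B, C);
    func(C, E, F);
    for (int i = 0; i < 2; ++i) {   // func's closest outer loop is this one
        func(A, B, C);
    }
}

// One iteration of the outer loop from the example, then the flag reset.
std::vector<bool> runExample() {
    Array A1, B1, C1, D1, E1, F1;
    Array A2, B2, C2, D2, E2, F2;
    trace.clear();
    for (int iter = 0; iter < 1; ++iter) {
        foo(A1, B1, C1, D1, E1, F1);
        B2.invalidWrite = true;     // setInvalidWriteFlag(B2) after "B2 = .."
        foo(A2, B2, C2, D2, E2, F2);
    }
    B2.invalidWrite = false;        // unsetInvalidWriteFlag(B2)
    assert(trace.size() == 8);      // two calls to foo, four foralls each
    return trace;
}
```

In the second call to foo, the two calls to func inside the inner loop fall back to the original forall even though the write to B2 happened outside that loop, which is the case described above.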

With that said, I don't believe this is a complete solution.