Open tfogal opened 7 months ago
Assigning to @rdspring1. Not intending to imply he should do everything here, but because we talked about this I want to be sure he sees + edits the above and ensures I did not forget anything.
@tfogal Your summary looks good!
Here are some other items from my notes:
- Add tests to `python_test/test_python_frontend.py`, since debugging the fusion cache through the CI is cumbersome. The tests would artificially lower the cache limit to three and check how the cache behaves when we reach the cache limit.
- Given the `vector<FusionCache*>` structure, we need to decouple the vector's indices from the fusion's id.
Thanks Ryan! I edited these into the main issue, to make it easy for someone to pick this up w/o following all our comments.
We did not implement an LRU cache eviction policy since it was not needed when we first started using the Python interface and we had other things to do. Therefore, the cache was designed to simply assert when the number of node entries reaches the max. The max was set at 8192 to pass our tests; in reality, 8192 is huge for a single model. The cases where we get in trouble are when we run back-to-back tests and the cache does not get reset.
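For reference, an LRU eviction policy along these lines could be sketched as below. This is a toy keyed on strings, not the real cache: the actual implementation would track trie `End` nodes and their schedules, and the class and method names here are assumptions for illustration.

```cpp
#include <cstddef>
#include <list>
#include <optional>
#include <string>
#include <unordered_map>

// Hypothetical LRU bookkeeping: a recency-ordered list plus an index into it.
class LruIndex {
 public:
  explicit LruIndex(std::size_t capacity) : capacity_(capacity) {}

  // Mark `key` as most recently used, inserting it if new. Returns the key
  // evicted to stay under capacity, if any.
  std::optional<std::string> touch(const std::string& key) {
    auto it = pos_.find(key);
    if (it != pos_.end()) {
      // Already cached: move to the front of the recency list.
      order_.splice(order_.begin(), order_, it->second);
      return std::nullopt;
    }
    order_.push_front(key);
    pos_[key] = order_.begin();
    if (order_.size() <= capacity_) {
      return std::nullopt;
    }
    // Over capacity: evict the least recently used entry (list back).
    std::string victim = order_.back();
    order_.pop_back();
    pos_.erase(victim);
    return victim;
  }

 private:
  std::size_t capacity_;
  std::list<std::string> order_;  // front = most recently used
  std::unordered_map<std::string, std::list<std::string>::iterator> pos_;
};
```

This replaces the current "assert at 8192" behavior with eviction, which is exactly what back-to-back test runs would need.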
Note that a lot of the testing is done at the C++ level: https://github.com/NVIDIA/Fuser/tree/main/csrc/python_frontend/test
When a user constructs too many fusions today, the fusion cache overflows and nvFuser throws up its hands and gives up. Any user could theoretically hit this, but we acutely feel the pain in CI, where an extensive test suite has grown the fusion cache to its maximum size. The temporary workaround of growing the cache (#1702) is not a true solution.
The salient elements of the cache are a pair of linked data structures: a trie of the fusion elements themselves (i.e. the ops), and a paired `vector<FusionSchedule*>` (`FusionCache::fusions_`) that holds the actual cached elements. Each node in the trie corresponds to the set of ops "up to" that node. For example, if there are two entries in the cache, corresponding to mul-add-reduction and mul-add-division, then the trie conceptually has a shared mul → add prefix that branches into a `reduction` node and a `division` node. There are `FusionSchedule`s attached to the trie at the `reduction` and `division` nodes (in the implementation, these are actually `RecordType::End` nodes). Every such `End` node then has an associated "index" (`TrieNode::fusion_id`) that indexes into the `FusionCache::fusions_` array.

A more straightforward implementation would be a hash table for the set of fusions. The hash table implementation may have additional overhead due to collisions, however: if we hash a fusion to, say, 42, we can't assume that the fusion at element 42 is the same as what we need---we then have to check each op in the fusion against each op in element 42 of the table, because multiple fusions might hash to 42. Such a check is O(n) in the list of ops of the fusion. Traversing the trie is also O(n) in the list of ops, but when we reach the end we know immediately whether we've found a match or whether it's not available.
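To make the trie structure concrete, here is a toy version, assuming each op is identified by a plain string (the real trie nodes hold record objects, and the names here are illustrative, not the actual nvFuser types):

```cpp
#include <cstddef>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <vector>

// Toy trie node: children keyed by op name, with a fusion id set only on
// "End" nodes (the nodes that have an attached cached fusion).
struct TrieNode {
  std::map<std::string, std::unique_ptr<TrieNode>> children;
  std::optional<std::size_t> fusion_id;
};

// Register a fusion: walk/extend the trie and mark the final node as an End.
void insert(TrieNode& root, const std::vector<std::string>& ops,
            std::size_t id) {
  TrieNode* node = &root;
  for (const auto& op : ops) {
    auto& child = node->children[op];
    if (!child) child = std::make_unique<TrieNode>();
    node = child.get();
  }
  node->fusion_id = id;
}

// Single O(n) walk over the ops; a miss is detected as soon as a child is
// absent, with no post-hoc op-by-op comparison as in the hash-table scheme.
std::optional<std::size_t> lookup(const TrieNode& root,
                                  const std::vector<std::string>& ops) {
  const TrieNode* node = &root;
  for (const auto& op : ops) {
    auto it = node->children.find(op);
    if (it == node->children.end()) return std::nullopt;
    node = it->second.get();
  }
  return node->fusion_id;
}
```

With the mul-add-reduction and mul-add-division fusions from the example above, both insertions share the mul → add prefix, and looking up just mul-add returns no id because that node is not an `End` node.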
Some things to investigate:
- A hash table is engineered for fast lookup; given how fast `std::unordered_map` is likely to be, maybe that approach isn't so bad. Prototype some examples with a lookup via both a trie and an `unordered_map` and see what kind of perf difference we are really talking about.
- Do we need both ids (`TrieNode::fusion_id_` and `FusionDefinition::id()`) to be the index of the fusion in `FusionCache::fusions_`?
- When we evict a `FusionSchedule` (i.e. remove its `End` record), this might leave dangling paths in the trie that never lead to an `End` record. We would need to garbage collect these dangling paths in addition to removing the `End` records.
- Instead of the `vector<FusionCache*>` data structure, could we have the `TrieNode` contain the `FusionCache*` directly? The pointer would only be non-null for a `RecordType::End` `TrieNode`, but we're only wasting one extra pointer per node, so this doesn't seem like much.
- We currently hold a `unique_ptr` for the `FusionCache` and pass around the "id" (index into `FusionCache::fusions_`) to clients that need to reference a cache entry. Should we instead use a `shared_ptr` and make clients that need to reference it (such as `FusionDefinition::id()`) hold a `weak_ptr`? This might more naturally model the kind of sharing that's actually occurring.
- If we keep the `vector<FusionCache*>`, we should at least decouple its "id" from the index in the vector.
- Add tests to `python_test/test_python_frontend.py` to artificially lower the cache limit and test that way---testing through CI is quite cumbersome!
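The `shared_ptr`/`weak_ptr` idea above could look roughly like this. The `FusionSchedule` and `Cache` types here are placeholders for illustration, not the actual nvFuser classes:

```cpp
#include <memory>
#include <vector>

// Placeholder for the real cached type.
struct FusionSchedule {
  int id;
};

// The cache is the owner via shared_ptr; clients hold weak_ptr, so after
// eviction they observe an expired handle instead of a dangling pointer or
// a stale index into the vector.
class Cache {
 public:
  std::shared_ptr<FusionSchedule> add(int id) {
    entries_.push_back(std::make_shared<FusionSchedule>(FusionSchedule{id}));
    return entries_.back();
  }

  // Eviction drops the owning reference; any client weak_ptrs expire.
  void evict_all() { entries_.clear(); }

 private:
  std::vector<std::shared_ptr<FusionSchedule>> entries_;
};
```

A client would call `weak_ptr::lock()` before each use: a non-null result means the entry is still cached, and a null result means it was evicted, which models the sharing relationship more honestly than passing around raw indices.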