wcharczuk / go-incr

A Go implementation of Jane Street's "incremental".
MIT License

Performance significantly deteriorates with increasing number of observed nodes in the graph #12

Closed: selasijean closed this 7 months ago

selasijean commented 7 months ago

Noticed in the app that a batch of calculations runs very fast when the graph is empty, but with each subsequent batch the calculation times increase significantly. (Formulas in our models are so gnarly that a single batch of calculations can easily add >5000 observed nodes.)

Did some profiling and noticed that a significant, growing amount of time is spent in observeNodes.

[screenshot, 2024-02-11: CPU profile showing time spent in observeNodes]

The profile seemed to indicate that the map access in maybeAddObservedNode was a bottleneck, so I wrote some tests to confirm how the size of the observed-node set affects performance, and the results show it has a significant impact.

go maps back to haunt us 😭

    // cache memoizes constructed nodes by key so repeated calls for the same t return the same node.
    cache := map[string]incr.Incr[*int]{}

    w := func(bs *incr.BindScope, t int) incr.Incr[*int] {
        key := fmt.Sprintf("w-%d", t)
        if cached, ok := cache[key]; ok {
            return incr.WithinBindScope(bs, cached)
        }

        r := incr.Bind(incr.Root(), incr.Var(incr.Root(), "fakeformula"), func(bs *incr.BindScope, formula string) incr.Incr[*int] {
            out := 1
            return incr.Return(bs, &out)
        })

        r.Node().SetLabel(fmt.Sprintf("w(%d)", t))
        cache[key] = r
        return r
    }

    // usual definition for months_of_runway (elided); a hypothetical stand-in follows
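    // NOTE: the real months_of_runway is elided above. The following is a
    // purely hypothetical stand-in (an assumption, not the reporter's actual
    // formula) that mirrors the shape of w so the snippet compiles: a cached
    // Bind per time step t, deliberately disconnected from w's subgraph.
    months_of_runway := func(bs *incr.BindScope, t int) incr.Incr[*int] {
        key := fmt.Sprintf("runway-%d", t)
        if cached, ok := cache[key]; ok {
            return incr.WithinBindScope(bs, cached)
        }
        r := incr.Bind(incr.Root(), incr.Var(incr.Root(), "formula"), func(bs *incr.BindScope, _ string) incr.Incr[*int] {
            out := t
            return incr.Return(bs, &out)
        })
        r.Node().SetLabel(key)
        cache[key] = r
        return r
    }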

    t.Run("number of nodes and perf", func(t *testing.T) {
        ctx := context.Background() // ctx was undefined in the original snippet; assumes "context" is imported
        graph := incr.New(incr.GraphMaxRecomputeHeapHeight(1024))
        max_t := 50

        // baseline
        start := time.Now()

        for i := 0; i < max_t; i++ {
            o := months_of_runway(incr.Root(), i)
            incr.Observe(incr.Root(), graph, o)
        }

        graph.Stabilize(ctx)
        elapsed := time.Since(start)
        fmt.Printf("Baseline calculation of months of runway for t= %d to %d took %s\n", 0, max_t, elapsed)

        maxMultiplier := 10
        for k := 1; k <= maxMultiplier; k++ {

            graph = incr.New(incr.GraphMaxRecomputeHeapHeight(1024))

            num := 5000 * k
            // beef up node count in graph
            for i := 0; i < num; i++ {
                o := w(incr.Root(), i)
                incr.Observe(incr.Root(), graph, o)
            }

            start = time.Now()

            for i := 0; i < max_t; i++ {
                o := months_of_runway(incr.Root(), i)
                incr.Observe(incr.Root(), graph, o)
            }

            graph.Stabilize(ctx)
            elapsed = time.Since(start)
            fmt.Printf("Calculating months of runway for t= %d to %d took %s when prior_count(observed nodes) >%d\n", 0, max_t, elapsed, num)
        }

        // fail deliberately so `go test` prints the timing output above
        assert.Fail(t, "deliberate")
    })

Results

Baseline calculation of months of runway for t= 0 to 50 took 138.916167ms 
Calculating months of runway for t= 0 to 50 took 398.881042ms when prior_count(observed nodes) >5000
Calculating months of runway for t= 0 to 50 took 638.370625ms when prior_count(observed nodes) >10000
Calculating months of runway for t= 0 to 50 took 879.998375ms when prior_count(observed nodes) >15000
Calculating months of runway for t= 0 to 50 took 1.254103583s when prior_count(observed nodes) >20000
Calculating months of runway for t= 0 to 50 took 1.574210291s when prior_count(observed nodes) >25000
Calculating months of runway for t= 0 to 50 took 1.94521675s when prior_count(observed nodes) >30000
Calculating months of runway for t= 0 to 50 took 2.450169875s when prior_count(observed nodes) >35000
Calculating months of runway for t= 0 to 50 took 2.97191425s when prior_count(observed nodes) >40000
Calculating months of runway for t= 0 to 50 took 3.452771583s when prior_count(observed nodes) >45000
Calculating months of runway for t= 0 to 50 took 3.676525041s when prior_count(observed nodes) >50000
wcharczuk commented 7 months ago

A couple of things: first, can you pull latest main and try running on that? There have been a few fixes to how "changes" propagate through nodes that may help a little.

But I also want to push back a little on the diagnosis here. The issue isn't that the map access is slow (it's actually pretty fast!); it's that we're doing a lot of map accesses. I would use pprof / flamegraphs as a heuristic generally, but not as a diagnostic tool; it's more important to understand why we're spending so much time accessing the map, or more pointedly, why we're doing it so often.

In the test above, my sense is that you're spending so much time observing nodes because you're creating a lot of nodes, and specifically a lot of observers!

Calculating months of runway for t= 0 to 50 took 1.978492542s when prior_count(observed nodes) >40000
Graph node count=217592, observer count=40050

So at 217592 nodes total, we're spending a non-trivial amount of time "observing" new nodes. That makes sense; we have a lot of observers!

Remember that any time we "observe" a new node, we traverse the tree up from the observed node to all of its parents, then all of those parents' parents, and so on. Adding observers is, as a result, pretty expensive, and we should only do it when we know we need to get the value of a given node.
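To make that concrete, here is a minimal sketch of the traversal, with illustrative types rather than go-incr's actual code; the point is that observing one node costs time proportional to its entire reachable ancestry, and every visited node involves an observed-set map access (the mapaccess2 time in the profile):

    // Hypothetical sketch; not go-incr's actual implementation.
    type node struct {
        id      int
        parents []*node
    }

    // observe records n and every transitive parent in the observed set.
    func observe(observedSet map[int]*node, n *node) {
        if _, ok := observedSet[n.id]; ok {
            return // already observed via another path
        }
        observedSet[n.id] = n
        for _, p := range n.parents {
            observe(observedSet, p)
        }
    }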

I'm going to close this for now, but feel free to re-open if the following are happening:

selasijean commented 7 months ago

@wcharczuk

My intuition that there might be some optimization to be made here with this example is that:

  1. w is deliberately defined such that there is no connection between the subgraph created by w nodes and the subgraph created by nodes from months_of_runway.
  2. If 1 is true, it just seemed odd that the traversal from observing the new nodes created by months_of_runway (and there aren't a lot of them, only 50) was expensive, and the only cause I could think of was that map lookups got more expensive as the observed set grew.

Then I asked ChatGPT why runtime.mapaccess2 is slow, and it replied (😂 maybe it's hallucinating, but I thought it was interesting):

The performance of `runtime.mapaccess2` in Go, as with any internal runtime function, depends on various factors, including the specifics of the operation being performed, the characteristics of the data, and the overall design of the map structure itself. If you're finding `runtime.mapaccess2` to be slow, there are several potential reasons why this might be the case:

1. **Hash Collisions**: One key aspect of map performance is the quality of the hash function used for the keys. If many keys hash to the same value, this leads to collisions, which can slow down access times because the runtime has to search through all keys in the same hash bucket to find the correct one.

2. **Large Maps**: As the size of a map grows, the cost of accessing elements can increase, especially if the map's capacity isn't adequately sized from the start. This can lead to more frequent rehashing (resizing and redistributing of elements) as elements are added, which can impact performance during access if it coincides with these operations.

3. **Poor Key Distribution**: If the keys in the map are not well-distributed, certain buckets within the map's internal structure may become overloaded, leading to longer search times when accessing these buckets.

4. **Concurrency Issues**: Although map access itself is not inherently slow due to concurrency (since maps in Go are not safe for concurrent use without synchronization), the way a map is used in a concurrent context can affect performance. For example, if you're using locks or other synchronization mechanisms to safely access a map from multiple goroutines, the overhead of acquiring and releasing locks can significantly impact performance.

5. **Compiler and Runtime Optimizations**: The Go compiler and runtime are continuously optimized for better performance across versions. However, specific cases or patterns of map usage might not benefit as much from these optimizations, leading to slower performance in some scenarios.

6. **Operational Complexity**: The internal implementation of maps involves a trade-off between time and space complexity. The Go runtime optimizes for general use cases, but specific access patterns (like frequently accessing non-existent keys, which still requires a hash and bucket search) might be less efficient.

When experiencing slow performance with map accesses, consider the following optimizations:

- **Pre-sizing Maps**: If you know the number of elements a map will hold, initialize it with adequate capacity using `make(map[KeyType]ValueType, size)` to minimize rehashing.
- **Key Choice**: Use keys that distribute well, avoiding poor hash function performance. Custom types used as keys should have a well-defined and efficient `Hash` method.
- **Profile Your Application**: Use Go's profiling tools to understand where the bottlenecks are. This can help identify if `runtime.mapaccess2` is genuinely the source of slowness or if it's symptomatic of broader performance issues.
- **Review Map Usage**: Sometimes, the way a map is used isn't optimal for performance. For example, excessive checking for the existence of keys or iterating over maps can lead to performance issues.

It's worth noting that performance considerations can be highly context-dependent. In some cases, the inherent limitations of the map data structure might make it unsuitable for a particular use case, and an alternative data structure or storage mechanism might be more appropriate.

These are the questions I have:

  1. Does observing a node mark all intermediate parent nodes as observed and consequently add them to the set?
  2. Are we equating "being observed" to "being necessary"?

I've been reading https://lord.io/spreadsheets/. The bottom part of that article was interesting, which is why I thought this test was worth sharing.

A lot has changed, so maybe all of this is taken care of.

wcharczuk commented 7 months ago

> Does observing a node mark all intermediate parent nodes as observed and consequently add them to the set?

Yup, this is exactly how it works, and it's why observation (and unobservation) are expensive. The reason to do it this way is to know specifically which parts of a graph are reachable up from a given observer; instead of computing that over and over as we stabilize, you memoize it as a list of observers on each node. FWIW the Jane Street implementation does the same thing, just with a linked list.

> Are we equating "being observed" to "being necessary"?

Yes, that is generally how it works, but really it just boils down to "this node is associated with this graph", which only changes when we first observe a node (setting a graph reference on the node) and when it is fully unobserved (where we nil out the graph reference if there are no remaining observers).

The observers list acts as a ref count; we specifically keep it as a list, rather than just an integer count, so we know more accurately how to deal with situations where a node is reachable by the same observer through multiple "paths" in the graph.
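A hedged sketch of that design (illustrative types, not go-incr's actual API): the list keeps one entry per path by which an observer reaches the node, so removing one path doesn't prematurely release a node the same observer can still reach another way, which a plain integer count couldn't distinguish.

    type graph struct{}

    type observer struct{ id int }

    type gnode struct {
        graph     *graph      // set on first observe, nil'd on last unobserve
        observers []*observer // one entry per (observer, path); acts as a ref count
    }

    func (n *gnode) addObserver(g *graph, o *observer) {
        if n.graph == nil {
            n.graph = g // first observation: the node becomes "necessary"
        }
        n.observers = append(n.observers, o) // the same observer may appear twice
    }

    func (n *gnode) removeObserver(o *observer) {
        for i, existing := range n.observers {
            if existing == o {
                n.observers = append(n.observers[:i], n.observers[i+1:]...)
                break // remove a single path only
            }
        }
        if len(n.observers) == 0 {
            n.graph = nil // fully unobserved: no longer necessary
        }
    }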

selasijean commented 7 months ago

Hmm, I see. The node explosion from the kind of formulas folks write is looking crazy, and if we mark up all intermediate nodes, that's tough for perf. I have no idea yet how to deal with it on my end, but I'll think about it some more.

wcharczuk commented 7 months ago

Similarly, I'm looking into how to avoid extra work when unobserving with nested binds. I think there may be some semi-related issues where, if we unobserve a node that has a lot of nesting, it will effectively propagate the unobserve multiple times up the tree.

Specifically, unobserving happens in the above case when we "bind" and the returned node has a new id versus the existing bound node, which generally happens if you construct a node in the bind function with, e.g., Map(...) instead of returning a cached node.
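To illustrate the two patterns, here is a hedged sketch with hypothetical trigger and input nodes, assuming a Map with the same scope-first shape as the calls above: the first bind constructs a node with a fresh id every time it fires, forcing an unobserve of the previously bound node; the second caches the node so the bound id is stable.

    // Hypothetical fragment; trigger and input are stand-in nodes.
    // Churns: a new Map node (new id) every time the bind fires.
    b1 := incr.Bind(incr.Root(), trigger, func(bs *incr.BindScope, _ string) incr.Incr[int] {
        return incr.Map(bs, input, func(v int) int { return v * 2 })
    })

    // Stable: construct once, cache, and return the same node thereafter.
    var cached incr.Incr[int]
    b2 := incr.Bind(incr.Root(), trigger, func(bs *incr.BindScope, _ string) incr.Incr[int] {
        if cached == nil {
            cached = incr.Map(bs, input, func(v int) int { return v * 2 })
        }
        return incr.WithinBindScope(bs, cached)
    })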

wcharczuk commented 7 months ago

Actually, I think I have an idea: if the scope is considered necessary because the parent bind is necessary, we can skip a bunch of observation / unobservation.
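A hedged sketch of that short-circuit as I read it (names and signature are illustrative, not the actual patch, and it reuses the sketch types from above): if the scope's parent bind is already necessary, the observation walk for nodes created in that scope can be skipped, since reachability is already guaranteed.

    // Illustrative only; shaped like maybeAddObservedNode but not the real code.
    func maybeAddObservedNodeSketch(g *graph, n *gnode, scopeAlreadyNecessary bool) {
        if scopeAlreadyNecessary {
            // The parent bind's observation already covers this scope's
            // nodes; skip the transitive parent walk and map updates.
            return
        }
        // ... otherwise fall through to the full observe traversal.
    }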

wcharczuk commented 7 months ago

Baseline calculation of months of runway for t= 0 to 50 took 76.012541ms
Calculating months of runway for t= 0 to 50 took 113.893208ms when prior_count(observed nodes) >5000
Graph node count=35908, observer count=5050
Calculating months of runway for t= 0 to 50 took 120.99375ms when prior_count(observed nodes) >10000
Graph node count=62232, observer count=10050
Calculating months of runway for t= 0 to 50 took 125.844709ms when prior_count(observed nodes) >15000
Graph node count=88556, observer count=15050
Calculating months of runway for t= 0 to 50 took 143.752542ms when prior_count(observed nodes) >20000
Graph node count=114880, observer count=20050
Calculating months of runway for t= 0 to 50 took 162.505917ms when prior_count(observed nodes) >25000
Graph node count=141204, observer count=25050
Calculating months of runway for t= 0 to 50 took 188.521542ms when prior_count(observed nodes) >30000
Graph node count=167528, observer count=30050
Calculating months of runway for t= 0 to 50 took 204.272917ms when prior_count(observed nodes) >35000
Graph node count=193852, observer count=35050
Calculating months of runway for t= 0 to 50 took 230.052584ms when prior_count(observed nodes) >40000
Graph node count=220176, observer count=40050
Calculating months of runway for t= 0 to 50 took 275.18025ms when prior_count(observed nodes) >45000
Graph node count=246500, observer count=45050
Calculating months of runway for t= 0 to 50 took 291.620917ms when prior_count(observed nodes) >50000
Graph node count=272824, observer count=50050

It's not great, obviously, but it's an improvement.

wcharczuk commented 7 months ago

If you pull latest main, there is quite a bit that's changed and I'd expect a few weird scenarios to pop up, but it should be better.

EDIT: There are a lot of weird scenarios, it'll take some time to untangle.

wcharczuk commented 7 months ago

I've reverted to a version that solves the specific situation above but basically nukes parallel stabilization throughput. I think this is fine for now; I'll revisit in a few weeks.