mvcisback / py-aiger

py-aiger: A python library for manipulating sequential and combinatorial circuits encoded using `and` & `inverter` gates (AIGs).
MIT License

Very slow AIG walking #135

Open masinag opened 3 weeks ago

masinag commented 3 weeks ago

Hi, I am using py-aiger to perform some AIG manipulation. In particular, I am trying to convert AIGs to PySMT formulas. I found an example of AIG walking in https://github.com/mvcisback/py-aiger/blob/f50171549781e0910cd5aec25dd8582a9d219b84/aiger/writer.py#L41

I have tried to adapt this to my scenario, but I have noticed that it gets very slow on medium-to-large instances. And I mean very slow: hours instead of seconds. E.g. https://github.com/yogevshalmon/allsat-circuits/blob/b57c2d6cba244460008dc6400beef2604a720c24/benchmarks/random_aig/large_cir_or/bench1/bench1.aag

It seems to me that the bottleneck is somewhere in the aiger.common.dfs function, likely in the operations on sets of nodes. I suspect this is due to the computation of hashes for the nodes (generated by @attr.frozen), which traverses the whole subgraph each time, for each node.
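The structural-hash cost described above can be sketched with a stdlib dataclass standing in for an attrs-generated node (the generated `__hash__`/`__eq__` are structural in the same way): two separately built, structurally identical chains compare and hash equal, which is only possible if every call walks the whole structure below the node.

```python
from dataclasses import dataclass
from typing import Optional

# Stand-in for an @attr.frozen node: frozen dataclasses generate a
# structural __hash__/__eq__ from the fields, much like attrs does.
@dataclass(frozen=True)
class Node:
    child: Optional["Node"]
    lit: int

def chain(depth: int) -> Node:
    node = Node(None, 0)
    for i in range(1, depth + 1):
        node = Node(node, i)
    return node

a, b = chain(100), chain(100)
# Distinct objects, yet equal hash and value: each hash()/== call
# re-traverses the entire chain, since nothing is cached.
assert a is not b and hash(a) == hash(b) and a == b
```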

I attach code to replicate the issue.

import aiger
import funcy as fn

def gates(circ):
    gg = []
    count = 0

    class NodeAlg:
        def __init__(self, lit: int):
            self.lit = lit

        @fn.memoize
        def __and__(self, other):
            nonlocal count
            nonlocal gg
            count += 1
            new = NodeAlg(count << 1)
            right, left = sorted([self.lit, other.lit])
            gg.append((new.lit, left, right))
            return new

        @fn.memoize
        def __invert__(self):
            return NodeAlg(self.lit ^ 1)

    def lift(obj) -> NodeAlg:
        if isinstance(obj, bool):
            return NodeAlg(int(obj))
        elif isinstance(obj, NodeAlg):
            return obj
        raise NotImplementedError

    start = 1
    inputs = {k: NodeAlg(i << 1) for i, k in enumerate(sorted(circ.inputs), start)}
    count += len(inputs)

    omap, _ = circ(inputs=inputs, lift=lift)

    return gg

def main():
    circ = aiger.to_aig("bench1.aag")
    gg = gates(circ)

    print(len(gg))

if __name__ == '__main__':
    main()
mvcisback commented 3 weeks ago

Hi @masinag ,

Thanks for reaching out. I can take a look sometime in the coming weeks.

I suspect you're right about the hash issue; it's been a bit of a wart for a while and one of the reasons the lazy API was initially developed -- although that won't help here.

I recommend profiling it with a tool like py-spy to get a flame graph; that should confirm where the time is going.

https://github.com/benfred/py-spy

If you have a chance to run py-spy, let me know what you find (feel free to attach the output SVG).
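For reference, recording a flame graph with py-spy looks roughly like this (`repro.py` is a hypothetical filename standing in for wherever the reproduction script above is saved):

```shell
# Install the profiler, then record a flame graph of the repro script.
pip install py-spy
py-spy record -o profile.svg -- python repro.py

# Or attach to an already-running process by PID.
py-spy record -o profile.svg --pid 12345
```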

Supposing it is the hashing in common.dfs we can look at two solutions:

  1. accelerating hashing in general.
  2. re-writing common.dfs to avoid the hashing.

Option 1

For option 1, I would have thought this was solved by cache_hash.

https://github.com/mvcisback/py-aiger/blob/f50171549781e0910cd5aec25dd8582a9d219b84/aiger/aig.py#L53

Perhaps we're getting a lot of hash collisions and being killed by equality checks? Either way it's strange; worst case, we'll need to manually introduce smarter hashing and caching.
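For context, `cache_hash` amounts to computing the structural hash once, at construction time. A hand-rolled sketch of the idea (hypothetical `CachedNode`, not py-aiger's actual node class):

```python
class CachedNode:
    """Binary node whose structural hash is computed once, up front."""
    __slots__ = ("left", "right", "_hash")

    def __init__(self, left, right):
        self.left, self.right = left, right
        # Children already cached their own hashes, so this is O(1)
        # per node rather than O(subgraph) per hash() call.
        self._hash = hash((left, right))

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        return self is other or (
            isinstance(other, CachedNode)
            and self._hash == other._hash
            and (self.left, self.right) == (other.left, other.right)
        )

# Structurally equal, distinct objects still hash (and compare) equal.
a = CachedNode(CachedNode(1, 2), 3)
b = CachedNode(CachedNode(1, 2), 3)
assert hash(a) == hash(b) and a == b
```

Note that even with cached hashes, `__eq__` can still traverse the structure whenever two distinct-but-equal-hashing keys meet, which is consistent with the collision suspicion above.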

Option 2

I think this is the easiest to code, but not a very satisfying solution. Essentially, we could switch to checking whether that exact node (by identity) has already been emitted. This could be done roughly as follows.

def dfs(circ):
    """Generates nodes via depth first traversal in pre-order."""
    emitted = set()
    stack = list(circ.cones | circ.latch_cones)

    while stack:
        node = stack.pop()

        if id(node) in emitted:
            continue

        remaining = [c for c in node.children if id(c) not in emitted]

        if len(remaining) == 0:
            yield node
            emitted.add(id(node))  # Track emitted nodes by identity, not by structural hash.
            continue

        stack.append(node)  # Add to emit after remaining children.
        stack.extend(remaining)
mvcisback commented 3 weeks ago

@masinag looking at your code again, it may actually be that NodeAlg doesn't cache its hashes. Could you try again with that?

masinag commented 3 weeks ago

> @masinag looking at your code again, it may actually be that NodeAlg doesn't cache its hashes. Could you try again with that?

I don't think I understand what you mean. I don't see where I am hashing NodeAlg objects

masinag commented 3 weeks ago

About the proposed options, I can try to profile the execution with py-spy.

Option 2 seems an easy fix, but if the issue is really the hashing, speeding it up could improve performance in many other contexts. So it could be worth looking deeper into that.

mvcisback commented 3 weeks ago

>> @masinag looking at your code again, it may actually be that NodeAlg doesn't cache its hashes. Could you try again with that?

> I don't think I understand what you mean. I don't see where I am hashing NodeAlg objects

Err, actually ignore what I said.

masinag commented 1 week ago

Hi, I profiled the above code using py-spy. I stopped it after 3 hours of execution. This is the flame graph.

[flame graph image: profile]

It looks like the problem is the equality check between nodes, since most of the time is taken by the __eq__ function generated by attrs.

mvcisback commented 1 week ago

Thanks @masinag ! Could you share the SVG as well since it's interactive?

That said, I suppose this makes sense. It's not the hash that's the problem; it's the equality check that happens afterwards to verify it wasn't a hash collision.

I will need to think about how to speed that up. Given that

https://github.com/mvcisback/py-aiger/blob/f50171549781e0910cd5aec25dd8582a9d219b84/aiger/common.py#L192

seems to be the bottleneck, it may be good to implement option 2 anyway.

Alternatively, I could make a breaking change and make equality the same as an id check. It's been a while since I really thought about what implications that would have, but it seems like a reasonable option.

mvcisback commented 1 week ago

@masinag if I could ask for a favor: if you have the bandwidth, could you regenerate the above flame graph with the change suggested in:

https://github.com/mvcisback/py-aiger/issues/135#issuecomment-2455686829

I suspect we'll see a huge speed improvement, but unfortunately I'm not able to run it myself for a while.

masinag commented 1 week ago

> Thanks @masinag ! Could you share the SVG as well since it's interactive?

Sure, you can find it at https://github.com/user-attachments/assets/b11cc6ba-1966-4108-b447-e6d3997d51ba

masinag commented 1 week ago

> @masinag if I could ask for a favor, if you have the bandwidth, could you regenerate the above flame graph with the change suggested in:

I've stopped it after 30 min.

[flame graph image: profile2]

(interactive version https://github.com/user-attachments/assets/e000cc4a-aaab-4274-9193-3db3fdc4dc17)

It looks like there is another call to __eq__ when nodes are used as keys for the mem dictionary. https://github.com/mvcisback/py-aiger/blob/f50171549781e0910cd5aec25dd8582a9d219b84/aiger/aig.py#L201
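That matches how Python dicts work: after a hash-bucket match, the stored key and the lookup key are compared with `__eq__` to rule out a collision, so structural equality gets exercised on every `mem[gate]` access. A small probe (hypothetical `Probe` class, just for illustration) makes this visible, and shows why keying by `id(gate)` sidesteps it:

```python
class Probe:
    """Node stand-in that counts how often __eq__ is invoked."""
    eq_calls = 0

    def __init__(self, lit):
        self.lit = lit

    def __hash__(self):
        return hash(self.lit)  # cheap, cached-style hash

    def __eq__(self, other):
        Probe.eq_calls += 1
        return isinstance(other, Probe) and self.lit == other.lit

a, b = Probe(1), Probe(1)  # structurally equal, distinct objects
mem = {a: "value"}
assert mem[b] == "value"   # hash matches, then __eq__ confirms the key
assert Probe.eq_calls >= 1

# Keying by id() never calls __hash__/__eq__ on the node itself.
mem_by_id = {id(a): "value"}
calls_before = Probe.eq_calls
assert mem_by_id[id(a)] == "value"
assert Probe.eq_calls == calls_before
```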

masinag commented 6 days ago

The code is much faster (milliseconds instead of hours!) if the AIG class https://github.com/mvcisback/py-aiger/blob/f50171549781e0910cd5aec25dd8582a9d219b84/aiger/aig.py#L169 is modified as follows:

@attr.frozen(repr=False)
class AIG:
    ...
    def __call__(self, inputs, latches=None, *, lift=None):
        """Evaluate AIG on inputs (and latches).
        If `latches` is `None` initial latch value is used.

        `lift` is an optional argument used to interpret constants
        (False, True) in some other Boolean algebra over (&, ~).

        - See py-aiger-bdd and py-aiger-cnf for examples.
        """
        if latches is None:
            latches = dict()

        if lift is None:
            lift = fn.identity
            and_, neg = op.and_, op.not_
        else:
            and_, neg = op.__and__, op.__invert__

        latchins = fn.merge(dict(self.latch2init), latches)
        # Remove latch inputs not used by self.
        latchins = fn.project(latchins, self.latches)

        latch_map = dict(self.latch_map)
        boundary = set(self.node_map.values()) | set(latch_map.values())

        store, prev, mem = {}, set(), {}

        for node_batch in self.__iter_nodes__():
            prev = set(mem.keys()) - prev
            mem = fn.project(mem, prev)  # Forget about unnecessary gates.

            for gate in node_batch:
                if isinstance(gate, Inverter):
                    mem[id(gate)] = neg(mem[id(gate.input)])
                elif isinstance(gate, AndGate):
                    mem[id(gate)] = and_(mem[id(gate.left)], mem[id(gate.right)])
                elif isinstance(gate, Input):
                    mem[id(gate)] = lift(inputs[gate.name])
                elif isinstance(gate, LatchIn):
                    mem[id(gate)] = lift(latchins[gate.name])
                elif isinstance(gate, ConstFalse):
                    mem[id(gate)] = lift(False)

                if gate in boundary:
                    store[id(gate)] = mem[id(gate)]  # Store for eventual output.

        outs = {out: store[id(gate)] for out, gate in self.node_map.items()}
        louts = {out: store[id(gate)] for out, gate in latch_map.items()}
        return outs, louts

Notice that I used id(gate) instead of gate as the dictionary key.

[flame graph image: profile4] (interactive version https://github.com/user-attachments/assets/2ec73a97-4541-4a7a-b82d-7e6374c9e838)

mvcisback commented 6 days ago

Amazing! This does suggest to me that the right solution might be to make a breaking change and have __eq__ work via id. This would have the same effect (since after a hash match, eq is called), but would apply throughout the codebase.

I'll think about it over the weekend to make sure there aren't any gotchas to applying it more widely.

masinag commented 6 days ago

Great! Btw from attrs' doc:

If you want hashing and equality by object identity: use @define(eq=False)
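For reference, `eq=False` leaves Python's default identity-based `__eq__`/`__hash__` in place, which is the same behavior a plain class gets. A minimal sketch of those semantics (using a plain class so no attrs install is needed):

```python
class PlainNode:
    """Plain class: default __eq__/__hash__ are by object identity."""
    def __init__(self, lit):
        self.lit = lit

a, b = PlainNode(1), PlainNode(1)
assert a == a            # identity equality
assert a != b            # structurally identical, still unequal
assert len({a, b}) == 2  # each object is its own distinct dict/set key
```

This is also why the `id(gate)` keying above is fast: identity comparisons never traverse the circuit.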