Qiskit / qiskit

Qiskit is an open-source SDK for working with quantum computers at the level of extended quantum circuits, operators, and primitives.
https://www.ibm.com/quantum/qiskit
Apache License 2.0

Faster hashing for Pauli operators. #13355

Open MarcDrudis opened 2 days ago

MarcDrudis commented 2 days ago

What should we add?

Currently the hash of a Pauli operator is computed as hash(self.to_label()). That means computing the label of the Pauli operator, which is expensive. Building the hash directly from self.x, self.z and self.phase seems more reasonable. I did a quick test where I computed the hash with int.from_bytes(np.packbits([self.x, self.z]+[self.phase])) and it was about 50 times faster.

jakelishman commented 2 days ago

Pauli instances are mutable, so strictly they shouldn't define __hash__ at all (though it would be totally reasonable to provide a "hash key" function, which, to be fair, is what to_label is). I think you might have made a slight copy/paste error in your example, because it tries to bit-pack something that isn't a flat array.
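To illustrate the "hash key" pattern for a mutable class: setting __hash__ to None makes hash() raise TypeError, and an explicit method returns an immutable snapshot that can be used as a dict key. This is a minimal sketch; ToyPauli and hash_key are hypothetical names standing in for Qiskit's Pauli, not Qiskit API.

```python
import numpy as np


class ToyPauli:
    """Hypothetical stand-in for a mutable Pauli: x/z bit arrays plus a phase."""

    # Mutable objects should not be hashable: hash(obj) now raises TypeError,
    # so an instance can't go stale after being used as a dict key.
    __hash__ = None

    def __init__(self, x, z, phase=0):
        self.x = np.asarray(x, dtype=bool)
        self.z = np.asarray(z, dtype=bool)
        self.phase = phase % 4

    def __eq__(self, other):
        if not isinstance(other, ToyPauli):
            return NotImplemented
        return (
            np.array_equal(self.x, other.x)
            and np.array_equal(self.z, other.z)
            and self.phase == other.phase
        )

    def hash_key(self):
        # Immutable snapshot of the current state, safe to store in a dict.
        return (self.x.tobytes(), self.z.tobytes(), self.phase)


p = ToyPauli([1, 0], [1, 1])
counts = {p.hash_key(): 1}
p.z[1] = False  # mutating p changes its key; the stored key is unaffected
```

The caller decides when to take the snapshot, so later mutation can't silently corrupt the dict.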

Fwiw, I feel like the bigger story here is how slow the label computation is. It really shouldn't be expensive to compute the Pauli label, even from Python space. I threw this together quickly just now, and it's roughly 33x faster than the current implementation of Pauli.to_label for large-ish Paulis:

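import numpy as np

# ASCII codes indexed by 2*z + x: I = 0, X = 1, Z = 2, Y = 3
CHARS = np.array([ord("I"), ord("X"), ord("Z"), ord("Y")], dtype=np.uint8)

def label(pauli):
    index = pauli.z.astype(np.uint8) << 1  # 2 where the Z bit is set
    index += pauli.x                       # +1 where the X bit is set
    # Reverse so that qubit 0 is the rightmost character, as in Qiskit labels
    ascii_label = CHARS[index[::-1]].tobytes()
    # The stored phase uses the ZX convention; each Y shifts it by one unit
    phase_label = ("", "-i", "-", "i")[(pauli._phase.item() - ascii_label.count(b"Y")) % 4]
    return phase_label + ascii_label.decode("ascii")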
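The Y count enters the phase because, as I understand Qiskit's internal convention, the stored exponent q in _phase describes (-i)^q Z^z X^x, and a Y written in that form carries an extra factor: Y = -iZX. A quick numpy check of that identity, with label_phase as a hypothetical helper for the phase lookup:

```python
import numpy as np

X = np.array([[0, 1], [1, 0]], dtype=complex)
Y = np.array([[0, -1j], [1j, 0]], dtype=complex)
Z = np.array([[1, 0], [0, -1]], dtype=complex)

# In the ZX form (-i)^q Z^z X^x, Y has z = x = 1 and q = 1: Y = (-i) Z X.
assert np.allclose(Y, -1j * Z @ X)


def label_phase(stored_phase, num_y):
    """Hypothetical helper: label prefix from the ZX-convention phase exponent."""
    # Each Y accounts for one unit of the stored exponent.
    return ("", "-i", "-", "i")[(stored_phase - num_y) % 4]
```

For example, a plain Y is stored with exponent 1, and label_phase(1, 1) gives the empty prefix.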

(I couldn't entirely compare to your function because of the ragged array thing.)

Would switching the to_label implementation over to something like this be enough for you? It would have the benefit of speeding up to_label itself, without the potential pain of changing how the hash keys work. That probably shouldn't matter, but it could have odd effects on dictionary performance: the hash of an int in Python is essentially the int itself, and a dict picks a slot using only the low few bits of the hash (how many depends on the dict's capacity). That makes me worry that most large Paulis in a dict will have identities in the last few places, will therefore share those low bits, and so will make dict lookups unnecessarily costly.
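A sketch of that worry, assuming a hypothetical integer key (symplectic_key, not Qiskit API) that puts bit i = x_i and bit n + i = z_i, with qubit 0 as the lowest bit. CPython chooses a key's first probe slot from hash & (capacity - 1), so Paulis that are identity on the low qubits all start in the same slot:

```python
def symplectic_key(label):
    """Hypothetical int key: bit i is x_i, bit n + i is z_i (qubit 0 = rightmost char)."""
    n = len(label)
    key = 0
    for i, char in enumerate(reversed(label)):
        if char in "XY":
            key |= 1 << i
        if char in "ZY":
            key |= 1 << (n + i)
    return key


labels = ["XIIIIIII", "ZIIIIIII", "YXIIIIII", "IZYIIIII", "XXZIIIII"]
capacity = 8  # a small table: only the low 3 bits choose the first slot
slots = [hash(symplectic_key(lab)) & (capacity - 1) for lab in labels]
print(slots)  # [0, 0, 0, 0, 0]
```

The keys are all distinct, but every one of them begins probing at slot 0.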

MarcDrudis commented 1 day ago
  1. What I actually tried in my code was int.from_bytes(np.packbits([self.x, self.z])), which worked just fine (I think numpy takes care of concatenating the bits on its own).
  2. I don't know whether your point on dictionary performance is for or against using .to_label(). I'll assume it is against (if not, just ignore this point). I may not fully understand how the hash function works in Python, but I thought the hash of an integer is the integer itself. In that case, the binary encoding of x and z (plus a couple of bits for the phase) really is the most efficient encoding of the Pauli operator, and converting it to a string at all seems like an unnecessary step. For instance, for Pauli("-IXYZ") we have:
    print(hash('IXYZ')) # -2097615359577420896
    int.from_bytes(np.packbits([[False,True,True,False]+[False,False,True,True]+[True,False]])) #25472

    The numbers I get for the second encoding are consistently smaller.

  3. I hadn't thought about it, but changing the hash of Paulis now might break dictionary lookups for user-stored data.
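The worked number in point 2 can be reproduced as follows. np.packbits flattens its input when axis is left as None (the default) and zero-pads the final byte, which is why the nested [x, z] list in point 1 needs no explicit concatenation. The bits are the ones from the comment (x, then z, then two phase bits, read off the label left to right); byteorder is passed explicitly because int.from_bytes only gained a default in Python 3.11.

```python
import numpy as np

x_bits = [False, True, True, False]  # bits from the comment, for -IXYZ
z_bits = [False, False, True, True]
phase_bits = [True, False]           # phase 2, i.e. the leading "-"

# 10 bits -> 2 bytes, with the last byte zero-padded: 0b01100011, 0b10000000
packed = np.packbits(x_bits + z_bits + phase_bits)
key = int.from_bytes(packed.tobytes(), "big")
print(packed, key)  # key == 25472, matching the number above

# A 2-D [x, z] input is flattened too, giving the same byte as x_bits + z_bits.
nested = np.packbits([x_bits, z_bits])
```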

That said, if you think to_label is better I think a 30x speed improvement would already be great. Thank you for your time :)

Cryoris commented 1 day ago

A faster to_label would also speed up the Pauli evolutions 🙂

jakelishman commented 1 day ago
  1. Ah right, fair enough. For hashing that's generally fine, although it does mean that four distinct Paulis (the ones that differ only in phase) all hash to the same value, which I imagine would be fine.
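Such collisions cost speed, not correctness. In the sketch below, a hypothetical wrapper whose hash ignores the phase but whose equality does not still keeps all four phases as separate dict entries, because the dict falls back on == after a hash match:

```python
class PhaselessHash:
    """Hypothetical key: hash ignores the phase, equality compares it."""

    def __init__(self, label, phase):
        self.label = label
        self.phase = phase % 4

    def __hash__(self):
        return hash(self.label)  # all four phases of one label collide

    def __eq__(self, other):
        if not isinstance(other, PhaselessHash):
            return NotImplemented
        return (self.label, self.phase) == (other.label, other.phase)


table = {PhaselessHash("XYZ", phase): phase for phase in range(4)}
print(len(table))  # 4: distinct entries despite one shared hash value
```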

For point 2: I'm just being very conservative about hashing with integers here; in the end, I suspect that the collision resolution in dict lookups would make it fine. The "size" of the hashed integer doesn't matter a huge amount, though for large Paulis you'll find the string hash is significantly smaller than the packbits one, because string hashes are returned as a 64-bit integer, whereas the packbits integer is unbounded.
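A quick check of the sizes involved, assuming a 64-bit CPython build: the packed integer key grows with the number of qubits, but hash() reduces any int modulo sys.hash_info.modulus, and a str hash always fits in a signed sys.hash_info.width-bit integer.

```python
import sys

modulus = sys.hash_info.modulus  # 2**61 - 1 on 64-bit builds
width = sys.hash_info.width      # 64 on 64-bit builds

big_key = (1 << 2000) - 1            # e.g. 1000 qubits' worth of x and z bits, all set
print(big_key.bit_length())          # 2000: the key itself is unbounded
print(0 <= hash(big_key) < modulus)  # True: its hash is reduced mod the modulus

label_hash = hash("Y" * 1000)
print(-(1 << (width - 1)) <= label_hash < (1 << (width - 1)))  # True
```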

To explain a bit more: you're absolutely right that the hash of an integer is (nearly) the integer itself in Python (it changes a bit for negative numbers and once the numbers get beyond sys.hash_info.modulus). Strings, on the other hand, have a hash that's mixed with the value of PYTHONHASHSEED (by default randomised each time the Python interpreter starts) to avoid DoS attacks via crafted inputs (important for web-server-y applications, not a concern for research code most of the time). Either way, a hash map doesn't use the entire hash value: you first take the hash and cut it down so that every hash maps to one of a small number of slots. The number of slots depends on how many entries your dict currently holds, and Python does the cutting down simply by taking the low few bits of the hash value. If your hashes don't have a particular correlation in their low bits (and Python is betting, from experience, that most of the time they don't), then that's really, really fast.
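The integer-hash details and the "low few bits" step can be checked directly. The first_slot helper is a simplified, hypothetical model of how CPython picks the first probe position (hash & (table_size - 1)); on a collision, the real dict perturbs later probes using higher bits of the hash.

```python
import sys

M = sys.hash_info.modulus

print(hash(12345))           # 12345: small non-negative ints hash to themselves
print(hash(-1))              # -2: -1 is reserved as an error value in CPython
print(hash(M), hash(M + 7))  # 0 7: values are reduced modulo M


def first_slot(key, table_size=8):
    """Hypothetical model of CPython's first probe: keep only the low bits."""
    return hash(key) & (table_size - 1)


print([first_slot(k) for k in (3, 11, 19, 1 << 40)])  # [3, 3, 3, 0]
```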

If we construct some dictionaries with integers that are engineered to share the majority of their bits (but are still all technically unique integers), we can force the hashmap's collision resolution to work in overdrive:

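# IPython/Jupyter only: %timeit is an IPython magic
num = (2**20) * 2 // 3 - 2  # number of keys in each dictionary
size = 200                  # keys are centred on 2**size
shifts = list(range(129))   # consecutive keys differ by 2**shift

lookup = []
for shift in shifts:
    # num evenly spaced keys around 2**size, each a multiple of 2**shift
    keys = list(
        range(
            (1 << size) - ((num // 2) << shift),
            (1 << size) + ((num // 2) << shift),
            1 << shift,
        )
    )
    hashmap = dict.fromkeys(keys)
    # baseline: iterate over the keys and do nothing
    loop = %timeit -o -n 10 for key in keys: pass
    # iterate and look up every key in the dict
    contains = %timeit -o -n 10 for key in keys: key in hashmap
    # lookup cost with the loop overhead removed
    lookup.append(contains.best - loop.best)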

(Beware, that'll take a while to run.) It constructs dictionaries with integer keys that are evenly spaced around 2 ** 200, with consecutive keys differing by 1, 2, 4, 8, etc., up to 2 ** 128, so as the shift gets larger, the keys have more of their low bits set to 0. Then I'm just timing how long it takes to look up every key in the dictionary, correcting for the time taken to loop through the list and do nothing.
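For running outside IPython, here is a scaled-down sketch of the same experiment using the standard timeit module. The smaller num, the coarser shift range and the use of min over timeit.repeat are my own choices to keep the runtime short, so absolute numbers (and possibly details of the curve) will differ from the original run.

```python
import timeit


def lookup_costs(num=20_000, size=200, shifts=range(0, 129, 8), repeat=3):
    """Per-shift time to look up every key, minus the bare loop time."""
    results = []
    for shift in shifts:
        half = (num // 2) << shift
        # Keys evenly spaced around 2**size, each a multiple of 2**shift.
        keys = list(range((1 << size) - half, (1 << size) + half, 1 << shift))
        hashmap = dict.fromkeys(keys)
        # Baseline: iterate over the keys and do nothing.
        loop = min(timeit.repeat("for key in keys: pass",
                                 globals={"keys": keys}, number=1, repeat=repeat))
        # Iterate and look up every key in the dict.
        contains = min(timeit.repeat("for key in keys: key in hashmap",
                                     globals={"keys": keys, "hashmap": hashmap},
                                     number=1, repeat=repeat))
        results.append((shift, contains - loop))
    return results


for shift, cost in lookup_costs():
    print(f"shift={shift:3d}  lookup cost={cost:.5f} s")
```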

[Image: graph of the measured lookup time against shift]

You can see in the output that different numbers of trailing zeros (zero low bits) cause wildly different behaviour in the dict lookup, at various points in the structured hash keys, even though all the values returned by __hash__ are distinct. On my system (and almost certainly yours), the hash modulus is the Mersenne prime $2^{61} - 1$, and you can see the resulting period of 61 zeros in the graph.
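The 61-zero period follows from the modulus itself: with M = 2^61 - 1, we have 2^61 ≡ 1 (mod M), so multiplying a key by 2^61 leaves its hash unchanged, and keys spaced 2^61 apart get hashes spaced 1 apart. A check, assuming a 64-bit build where sys.hash_info.modulus is 2^61 - 1:

```python
import sys

M = sys.hash_info.modulus
assert M == 2**61 - 1, "this sketch assumes a 64-bit CPython build"

# 2**61 is congruent to 1 mod M, so shifting left by 61 bits keeps the hash.
print(all(hash(1 << (s + 61)) == hash(1 << s) for s in range(61)))  # True

# The centre of the benchmark keys, 2**200, hashes to 2**(200 % 61) == 2**17.
print(hash(1 << 200) == 1 << 17)  # True

# Keys 2**200 + k * 2**61 therefore hash to consecutive values 2**17 + k.
print([hash((1 << 200) + k * (1 << 61)) - (1 << 17) for k in range(-2, 3)])  # [-2, -1, 0, 1, 2]
```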

I won't pretend to be able to explain everything about the graph, and I've already spent way too long investigating haha, but the point I was trying to make is this: using integers with structure in them as hash keys can be dangerous for performance. The keys of a large dict of large Paulis are likely to have structure, which makes me nervous about using a symplectic, bit-based key for hashing them in Python, since there's only limited mixing-in of the higher bits. The difference between the biggest peaks and the biggest troughs in my graph is about a 20x performance hit.

Python's hashing algorithm for strings is designed so that a small change in a string has a cascading effect on the hash, so for Paulis, which in practice are likely to have low Hamming distances between them (most Paulis have low weight), strings just feel slightly more reliable to me.
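A rough way to see the difference, using the same hypothetical bit layout as a symplectic integer key (x bits low, z bits high): Paulis that differ on one qubit have integer keys differing in at most two bits, while the hashes of their label strings typically differ in about half of their 64 bits. Because string hashing is randomised per interpreter run, only the average is meaningful; a 64-bit build is assumed.

```python
import random

MASK64 = (1 << 64) - 1  # view str hashes as unsigned 64-bit values


def symplectic_key(label):
    """Hypothetical int key: bit i is x_i, bit n + i is z_i (qubit 0 = rightmost char)."""
    n = len(label)
    key = 0
    for i, char in enumerate(reversed(label)):
        if char in "XY":
            key |= 1 << i
        if char in "ZY":
            key |= 1 << (n + i)
    return key


def popcount(value):
    return bin(value).count("1")


rng = random.Random(1234)
int_diffs, str_diffs = [], []
for _ in range(200):
    chars = [rng.choice("IXYZ") for _ in range(50)]
    q = rng.randrange(50)
    other = chars.copy()
    other[q] = rng.choice([c for c in "IXYZ" if c != chars[q]])  # change one qubit
    a, b = "".join(chars), "".join(other)
    int_diffs.append(popcount(symplectic_key(a) ^ symplectic_key(b)))
    str_diffs.append(popcount((hash(a) ^ hash(b)) & MASK64))

print(max(int_diffs), sum(str_diffs) / len(str_diffs))  # at most 2, and roughly 32
```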