Open j2kun opened 3 months ago
I read through this one and it's very straightforward and well written.
They decompose each required shift into its constituent bits and shift by the corresponding powers of 2 (e.g., a shift of 5 is 101 in binary, so it requires a shift by 1 and a shift by 4). Up to an exception mentioned later, they always apply the rotations in bit order (shift by 1, then by 2, then by 4, etc.). This can cause conflicts if two intermediate shifted values collide in the same slot, adding together two values that shouldn't be added together. To avoid this, they construct a graph coloring problem by adding an edge wherever there would be such a conflict, color the graph with the minimum number of colors (using a standard solver), and give each color its own copy of the full set of rotations. E.g., three colors means three sets of all log(N) rotations. The different sets can then run in parallel. Some minor simplifications are included, like skipping a rotation when all assigned values have that bit zero.
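To make the conflict graph concrete, here is a minimal Python sketch. It is not the paper's or the linked repo's code: the names `shift_path`, `conflict_graph`, and `greedy_coloring` are hypothetical, the convention that `perm[src]` is the target slot of the value starting in slot `src` is an assumption, and a simple greedy coloring stands in for the minimal coloring the paper gets from a solver (so it may use more colors than necessary).

```python
from itertools import combinations


def shift_path(src, dst, n, order):
    """Slots occupied by one value after each rotation in `order`."""
    shift = (dst - src) % n
    pos = src
    path = []
    for amount in order:
        if shift & amount:  # this bit of the shift is set, so the value moves
            pos = (pos + amount) % n
        path.append(pos)
    return path


def conflict_graph(perm, order):
    """Edges between source indices whose values share a slot after some round.

    Assumption: perm[src] is the target slot of the value starting at src.
    """
    n = len(perm)
    paths = [shift_path(src, dst, n, order) for src, dst in enumerate(perm)]
    edges = set()
    for a, b in combinations(range(n), 2):
        if any(p == q for p, q in zip(paths[a], paths[b])):
            edges.add((a, b))
    return edges


def greedy_coloring(n, edges):
    """Greedy stand-in for the minimal coloring; not guaranteed optimal."""
    adjacent = {v: set() for v in range(n)}
    for a, b in edges:
        adjacent[a].add(b)
        adjacent[b].add(a)
    color = {}
    for v in range(n):
        used = {color[u] for u in adjacent[v] if u in color}
        color[v] = next(c for c in range(n) if c not in used)
    return color


perm = [1, 3, 2, 0]
edges = conflict_graph(perm, order=[1, 2])
coloring = greedy_coloring(len(perm), edges)
```

In this example the values starting in slots 0 and 1 both sit in slot 1 after the rotation by 1, so they get different colors and would go into separate rotation sets.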
The exception mentioned above is that they try changing the order in which the rotations are applied (e.g., maybe do rot(4), then rot(16), then rot(2), then rot(8), etc.). They did something like constructing the permutation network 10 times, once for each of 10 random orderings of the power-of-two rotations, and taking the one that produces the graph with the smallest chromatic number.
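A rough sketch of that search over rotation orderings follows. The function names, the `trials` and `seed` parameters, and the `perm[src]` = target-slot convention are all assumptions for illustration. Two deliberate deviations from the description above: the greedy color count is only an upper bound on the chromatic number, and the natural bit order is included as a baseline candidate.

```python
import random


def colors_needed(perm, order):
    """Greedy upper bound on the number of rotation sets for this ordering."""
    n = len(perm)
    # Record the slot each value occupies after every rotation in `order`.
    paths = []
    for src, dst in enumerate(perm):
        shift, pos, path = (dst - src) % n, src, []
        for amount in order:
            if shift & amount:
                pos = (pos + amount) % n
            path.append(pos)
        paths.append(path)
    # Greedily color: two values conflict if they share a slot after any round.
    colors = []
    for v in range(n):
        used = {
            colors[u]
            for u in range(v)
            if any(p == q for p, q in zip(paths[u], paths[v]))
        }
        colors.append(next(c for c in range(n) if c not in used))
    return max(colors) + 1 if colors else 0


def best_order(perm, trials=10, seed=0):
    """Try `trials` random orderings and keep the one needing fewest colors."""
    n = len(perm)
    natural = [1 << i for i in range(max(n - 1, 0).bit_length())]
    rng = random.Random(seed)
    best = (colors_needed(perm, natural), natural)  # baseline: bit order
    for _ in range(trials):
        candidate = natural[:]
        rng.shuffle(candidate)
        count = colors_needed(perm, candidate)
        if count < best[0]:
            best = (count, candidate)
    return best


count, order = best_order([1, 3, 2, 0])
```

Swapping `colors_needed` for an exact chromatic-number computation would bring this closer to what the paper describes, at the cost of running a solver per ordering.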
This is still a logarithmic-size construction, like Benes networks, but they claim empirically it does much better for small permutations than HElib's log-size construction.
They have a C++ implementation here: https://github.com/jellevos/perm_map_circuits
I also think it's not necessarily better than the baseline approach in all cases, so we'd want to do both and pick the best for each input program.
I added an implementation of the basic graph coloring approach in https://github.com/j2kun/rotate-solver/blob/main/vos_vos_erkin.py (though I only added one simple test, so it may still be buggy). I computed the graph to partition the rotation indices, and the rest of the needed code is quite simple: convert a list of index groups to a set of iterated mask+shift+add operations and a final addition.
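For illustration, the remaining conversion could look roughly like the plaintext simulation below. This is a sketch under assumptions, not the code in the linked file: `rotate` and `apply_shift_network` are hypothetical names, Python lists stand in for ciphertext slots, `perm[src]` is taken to be the target slot of the value in slot `src`, and each group is assumed to be conflict-free (one color class).

```python
def rotate(vec, amount):
    """Cyclic rotation: the value in slot i moves to slot (i + amount) % n."""
    n = len(vec)
    return [vec[(i - amount) % n] for i in range(n)]


def apply_shift_network(values, perm, groups, order):
    """Apply `perm` using one mask+shift+add chain per index group."""
    n = len(values)
    total = [0] * n
    for group in groups:
        # Track the current slot of each source index in this group.
        current = {src: src for src in group}
        # Initial mask: keep only the values that belong to this group.
        vec = [values[i] if i in current else 0 for i in range(n)]
        for amount in order:
            moving = {
                pos
                for src, pos in current.items()
                if ((perm[src] - src) % n) & amount
            }
            if not moving:
                continue  # every value in the group has this bit zero
            mask = [1 if i in moving else 0 for i in range(n)]
            moved = rotate([v * m for v, m in zip(vec, mask)], amount)
            stayed = [v * (1 - m) for v, m in zip(vec, mask)]
            vec = [a + b for a, b in zip(moved, stayed)]
            current = {
                src: (pos + amount) % n if pos in moving else pos
                for src, pos in current.items()
            }
        # Final addition across the per-group results.
        total = [a + b for a, b in zip(total, vec)]
    return total


values = [10, 20, 30, 40]
perm = [1, 3, 2, 0]
result = apply_shift_network(values, perm, groups=[[0, 2, 3], [1]], order=[1, 2])
# result == [40, 10, 30, 20]
```

In an FHE setting each masked multiply, rotation, and addition here would become the corresponding ciphertext operation, and the per-group chains are independent of one another.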
I also wrote a short blog post explaining the algorithm https://www.jeremykun.com/2024/09/02/shift-networks/
Efficient Circuits for Permuting and Mapping Packed Values Across Leveled Homomorphic Ciphertexts. Jelle Vos, Daniël Vos, and Zekeriya Erkin. 2022. http://dx.doi.org/10.1007/978-3-031-17140-6_20
This would be an improvement over the baseline in https://github.com/google/heir/issues/914