google-deepmind / alphatensor

Apache License 2.0
2.69k stars 238 forks source link

Equations in text form for your convenience #3

Closed 99991 closed 2 years ago

99991 commented 2 years ago

Here is some code to generate the equations in text form, so that no one has to transcribe the characters from the JPGs in the paper. The indices are slightly different from those in the paper (I might have missed a transpose somewhere), but the results come out correct, so it is probably fine. Note that the script prints over 7 MB of data, so it is a good idea to redirect the output to a file before looking at it: `python3 print_equations.py > equations.txt`

The individual equations could be simplified a bit more (e.g. by using SymPy), but I wanted to keep the dependencies to a minimum.

Make sure to run the code in the same directory as the files factorizations_f2.npz and factorizations_r.npz.

print_equations.py

import numpy as np
from ast import literal_eval as make_tuple

# The *.npz files contain a dict with keys like "(2,3,4)" and values containing
# a list of matrices U, V and W. For example, for the 2-by-2 times 2-by-2 case,
# we have the following matrices:
#
# U =
# [[ 0  1  1  0  1  1  0]
#  [ 0  0 -1  1  0  0  0]
#  [ 1  1  1  0  1  0  0]
#  [-1 -1 -1  0  0  0  1]]
#
# V =
# [[0 0 0 0 1 1 0]
#  [1 1 0 0 1 0 1]
#  [0 1 1 1 1 0 0]
#  [0 1 1 0 1 0 1]]
#
# W =
# [[ 0  0  0  1  0  1  0]
#  [ 0 -1  0  0  1 -1 -1]
#  [-1  1 -1 -1  0  0  0]
#  [ 1  0  0  0  0  0  1]]
#
# Each column of U is multiplied with the (row-major) vectorized matrix A.
# Likewise, each column of V is multiplied with the vectorized matrix B.
# The resulting vectors are multiplied pointwise and their product is
# multiplied with W, which forms the entries of the product matrix
# C = A times B. Note that the rows of W enumerate the entries of C in
# column-major order: row j * m + i of W yields C[i, j].
# Also see the function `multiply` below.


def load_factorizations(filename):
    """Load a factorization file into a plain dict and close the file again.

    Keys are strings like "(2,3,4)"; each value holds the matrices U, V, W.
    Note that allow_pickle=True allows arbitrary code execution, so only load
    trusted files. A JSON file would have been a better format choice since
    nothing here is stored in NumPy format anyway.
    """
    # The NpzFile keeps its file handle open until closed. dict() reads every
    # member eagerly, so the archive can be closed right afterwards.
    with np.load(filename, allow_pickle=True) as data:
        return dict(data)


def multiply(A, B, U, V, W):
    """Multiply A (m-by-k) with B (k-by-n) using the factorization U, V, W.

    The pointwise product below needs only U.shape[1] multiplications.
    Returns the m-by-n product matrix.
    """
    m = A.shape[0]
    n = B.shape[1]
    # Linear combinations of the entries of A and of B, multiplied pointwise.
    products = (U.T @ A.ravel()) * (V.T @ B.ravel())
    # The rows of W list C column-major, hence reshape to (n, m) and transpose.
    return (W @ products).reshape(n, m).T


def make_code(variables, factors, *, parenthesize=True):
    """Return source text for the linear combination sum(factor * variable).

    Generates code like "(a11 + a21 - a22)": zero factors are skipped and
    factors of +1/-1 are written as bare signs. The result is wrapped in
    parentheses when it has more than one term and `parenthesize` is true.
    An all-zero combination yields "0" so the generated code stays valid.
    """
    parts = []
    for variable, factor in zip(variables, factors):
        if factor == 1:
            prefix = " + "
        elif factor == -1:
            prefix = " - "
        elif factor < 0:
            prefix = f" {factor} * "
        elif factor > 0:
            prefix = f" + {factor} * "
        else:
            continue
        parts.append(prefix + variable)

    if not parts:
        return "0"

    code = "".join(parts).lstrip(" +")
    if parenthesize and len(parts) > 1:
        code = f"({code})"
    return code


def make_variables(var, m, n, max_dim=None):
    """Return names for the entries of an m-by-n matrix in row-major order.

    Generates names like a11, a12, a21, a22, or a_1_1, a_1_2, ... once a
    dimension exceeds 9, so that e.g. a_1_11 is not confused with a_11_1.
    Pass `max_dim` (e.g. the largest of all matrix dimensions involved) to
    make the choice of separator consistent across several matrices.
    """
    largest = max(m, n) if max_dim is None else max(m, n, max_dim)
    separator = "_" if largest > 9 else ""
    return [
        f"{var}{separator}{i + 1}{separator}{j + 1}"
        for i in range(m)
        for j in range(n)
    ]


def make_equations(U, V, W, m, k, n):
    """Return straight-line Python code computing C = A @ B via U, V and W.

    The code expects NumPy arrays A (m-by-k) and B (k-by-n) and the module
    `np` in scope, and leaves the m-by-n result in C. The variable cij holds
    entry C[i, j] (1-based indices).
    """
    max_dim = max(m, k, n)
    A_variables = make_variables("a", m, k, max_dim)
    B_variables = make_variables("b", k, n, max_dim)
    C_variables = make_variables("c", m, n, max_dim)
    h_variables = [f"h{i + 1}" for i in range(U.shape[1])]

    lines = [
        ", ".join(A_variables) + " = A.ravel()",
        ", ".join(B_variables) + " = B.ravel()",
    ]

    # One product per column of U and V.
    for h, u, v in zip(h_variables, U.T, V.T):
        sa = make_code(A_variables, u)
        sb = make_code(B_variables, v)
        lines.append(f"{h} = {sa} * {sb}")

    # The rows of W list C column-major; reorder them to row-major so that
    # the name cij is assigned the row that really computes C[i, j].
    row_major = np.arange(m * n).reshape(n, m).T.ravel()
    for c, w in zip(C_variables, W[row_major]):
        lines.append(f"{c} = " + make_code(h_variables, w, parenthesize=False))

    lines.append(
        "C = np.array([" + ", ".join(C_variables) + f"]).reshape({m}, {n})"
    )
    return "\n".join(lines)


def _check(condition, message):
    """Raise AssertionError with `message` unless `condition` holds.

    Unlike a bare assert statement, this check is not stripped by python -O.
    """
    if not condition:
        raise AssertionError(message)


def _check_product(C, A, B, mod, what):
    """Check that C equals A @ B, comparing modulo `mod` when it is given."""
    expected = A @ B
    if mod is not None:
        C, expected = C % mod, expected % mod
    _check(np.allclose(C, expected), f"{what} does not reproduce A @ B")


def main():
    """Verify every factorization and print it as straight-line Python code."""
    np.random.seed(0)

    # There are two factorizations, one for useful numbers and one for mod 2 math.
    for filename, mod in [
        ("factorizations_r.npz", None),
        ("factorizations_f2.npz", 2),
    ]:
        factorizations = load_factorizations(filename)

        # Test each factorization
        for key, UVW in factorizations.items():
            U, V, W = map(np.array, UVW)
            m, k, n = make_tuple(key)

            print(f"\nMultiply {m}-by-{k} matrix A with {k}-by-{n} matrix B")
            if mod is not None:
                print(f"using mod {mod} arithmetic")
            print()

            # Check that shapes are correct
            rank = U.shape[1]
            _check(
                U.shape == (m * k, rank)
                and V.shape == (k * n, rank)
                and W.shape == (m * n, rank),
                f"{key}: unexpected shapes {U.shape}, {V.shape}, {W.shape}",
            )

            # Generate two random matrices for testing
            A = np.random.randint(10, size=(m, k))
            B = np.random.randint(10, size=(k, n))
            _check_product(multiply(A, B, U, V, W), A, B, mod, f"{key}: multiply()")

            code = make_equations(U, V, W, m, k, n)
            print(code)

            # Verify the printed code in a fresh namespace, so that the check
            # cannot pick up a C left over from multiply() above. The code is
            # built only from the numbers of the (trusted) factorization file.
            namespace = {"np": np, "A": A, "B": B}
            exec(code, namespace)
            _check_product(namespace["C"], A, B, mod, f"{key}: generated code")


if __name__ == "__main__":
    main()

For example, the generated code for general 2-by-2 times 2-by-2 matrix multiplication is:


# Generated code for a 2-by-2 A times a 2-by-2 B using 7 multiplications
# (h1 ... h7). A and B are NumPy arrays; np must be imported.
# NOTE: the names cij are transposed relative to the entries of A @ B
# (e.g. c12 holds entry (2, 1) = a21*b11 + a22*b21); the final
# .reshape(2, 2).T puts every value back in its correct position.
a11, a12, a21, a22 = A.ravel()
b11, b12, b21, b22 = B.ravel()
# Products of linear combinations of entries of A and of B.
h1 = (a21 - a22) * b12
h2 = (a11 + a21 - a22) * (b12 + b21 + b22)
h3 = (a11 - a12 + a21 - a22) * (b21 + b22)
h4 = a12 * b21
h5 = (a11 + a21) * (b11 + b12 + b21 + b22)
h6 = a11 * b11
h7 = a22 * (b12 + b22)
# Entries of the product as signed sums of the products above.
c11 = h4 + h6
c12 = - h2 + h5 - h6 - h7
c21 = - h1 + h2 - h3 - h4
c22 = h1 + h7
C = np.array([c11, c12, c21, c22]).reshape(2, 2).T

For 4-by-4 times 4-by-4 matrix multiplication over $\mathbb{Z}_2$, i.e. when doing mod 2 arithmetic (I missed that on first reading), the code is:

# Generated code for a 4-by-4 A times a 4-by-4 B using 47 multiplications
# (h1 ... h47). This factorization is for arithmetic mod 2: the result is
# only checked as C % 2 == (A @ B) % 2.
# NOTE: as generated, the name cij holds entry (j, i) of A @ B (e.g. c12
# holds entry (2, 1)); the final .reshape(4, 4).T restores the layout.
a11, a12, a13, a14, a21, a22, a23, a24, a31, a32, a33, a34, a41, a42, a43, a44 = A.ravel()
b11, b12, b13, b14, b21, b22, b23, b24, b31, b32, b33, b34, b41, b42, b43, b44 = B.ravel()
# Products of linear combinations of entries of A and of B.
h1 = a13 * b31
h2 = (a13 + a22 + a23) * (b21 + b24 + b34)
h3 = (a13 + a21 + a23) * (b11 + b13 + b33)
h4 = (a13 + a23) * (b11 + b13 + b21 + b24 + b31 + b33 + b34)
h5 = a11 * b11
h6 = (a11 + a31) * (b11 + b12 + b14 + b21 + b24 + b31 + b32)
h7 = (a11 + a31 + a33) * (b12 + b31 + b32)
h8 = (a11 + a12 + a13 + a22 + a23 + a31 + a32) * (b21 + b24)
h9 = (a12 + a41 + a42) * (b11 + b13 + b23)
h10 = (a12 + a42 + a43) * (b22 + b31 + b32)
h11 = (a12 + a42) * (b11 + b13 + b21 + b22 + b23 + b31 + b32)
h12 = (a11 + a12 + a13 + a21 + a23 + a41 + a42) * (b11 + b13)
h13 = (a11 + a12 + a13 + a31 + a33 + a42 + a43) * (b31 + b32)
h14 = a41 * (b12 + b13 + b23 + b41 + b42)
h15 = (a14 + a41 + a44) * (b12 + b41 + b42)
h16 = (a14 + a44) * (b12 + b33 + b41 + b42 + b43)
h17 = (a11 + a31 + a32) * (b14 + b21 + b24)
h18 = (a14 + a32 + a34 + a41 + a44) * (b41 + b42)
h19 = (a14 + a32 + a34) * (b22 + b41 + b42)
h20 = (a14 + a34) * (b22 + b34 + b41 + b42 + b44)
h21 = a22 * (b23 + b24 + b34 + b41 + b43)
h22 = (a14 + a22 + a24) * (b23 + b41 + b43)
h23 = (a14 + a43 + a44) * (b33 + b41 + b43)
h24 = (a14 + a21 + a23 + a43 + a44) * b33
h25 = (a14 + a22 + a34 + a43) * (b22 + b34 + b41 + b43)
h26 = a33 * (b12 + b32 + b34 + b41 + b44)
h27 = (a14 + a24) * (b14 + b23 + b41 + b43 + b44)
h28 = (a14 + a21 + a24) * (b14 + b41 + b44)
h29 = (a14 + a32 + a34 + a42 + a43) * b22
h30 = (a14 + a22 + a24 + a43 + a44) * (b41 + b43)
h31 = a14 * b41
h32 = (a14 + a33 + a34) * (b34 + b41 + b44)
h33 = (a21 + a31 + a41) * (b12 + b13 + b14)
h34 = (a14 + a22 + a24 + a41 + a42) * b23
h35 = (a24 + a34 + a44) * (b42 + b43 + b44)
h36 = (a14 + a22 + a23 + a33 + a34) * b34
h37 = (a23 + a33 + a43) * (b32 + b33 + b34)
h38 = (a22 + a32 + a42) * (b22 + b23 + b24)
h39 = a12 * b21
h40 = (a14 + a21 + a24 + a33 + a34) * (b41 + b44)
h41 = a43 * (b22 + b32 + b33 + b41 + b43)
h42 = a21 * (b13 + b14 + b33 + b41 + b44)
h43 = (a14 + a21 + a24 + a31 + a32) * b14
h44 = (a14 + a24 + a32 + a41) * (b14 + b23 + b41 + b42)
h45 = a32 * (b14 + b22 + b24 + b41 + b42)
h46 = (a14 + a21 + a33 + a44) * (b12 + b33 + b41 + b44)
h47 = (a14 + a31 + a33 + a41 + a44) * b12
# Entries of the product (mod 2) as sums of the products above.
c11 = h1 + h5 + h31 + h39
c12 = h1 + h2 + h3 + h4 + h21 + h22 + h27 + h28 + h31 + h42
c13 = h5 + h6 + h7 + h17 + h19 + h20 + h26 + h31 + h32 + h45
c14 = h9 + h10 + h11 + h14 + h15 + h16 + h23 + h31 + h39 + h41
c21 = h1 + h7 + h10 + h13 + h15 + h18 + h19 + h29 + h31 + h47
c22 = h16 + h20 + h23 + h24 + h25 + h26 + h30 + h32 + h35 + h36 + h37 + h40 + h41 + h46
c23 = h15 + h18 + h19 + h20 + h26 + h31 + h32 + h47
c24 = h15 + h16 + h18 + h19 + h23 + h29 + h31 + h41
c31 = h3 + h5 + h9 + h12 + h22 + h23 + h24 + h30 + h31 + h34
c32 = h22 + h23 + h24 + h27 + h28 + h30 + h31 + h42
c33 = h14 + h15 + h16 + h18 + h27 + h28 + h33 + h35 + h40 + h42 + h43 + h44 + h46 + h47
c34 = h14 + h15 + h16 + h22 + h23 + h30 + h31 + h34
c41 = h2 + h8 + h17 + h28 + h31 + h32 + h36 + h39 + h40 + h43
c42 = h21 + h22 + h27 + h28 + h31 + h32 + h36 + h40
c43 = h19 + h20 + h28 + h31 + h32 + h40 + h43 + h45
c44 = h18 + h19 + h20 + h21 + h22 + h25 + h27 + h29 + h30 + h34 + h35 + h38 + h44 + h45
C = np.array([c11, c12, c13, c14, c21, c22, c23, c24, c31, c32, c33, c34, c41, c42, c43, c44]).reshape(4, 4).T

For MATLAB users: note that NumPy's `.ravel()` flattens row-wise (C order) by default, not column-wise as MATLAB's `A(:)` does. The same goes for `reshape`.