HarryR / ethsnarks

A toolkit for viable zk-SNARKS on Ethereum, Web, Mobile and Desktop
GNU Lesser General Public License v3.0

Implement Poseidon permutation #124

Closed: HarryR closed this 4 years ago

HarryR commented 5 years ago

Paper: https://eprint.iacr.org/2019/458.pdf

Implementations:

This is more EVM-friendly than Rescue, which requires two exponentiations per round.
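To make the cost difference concrete, here is a rough sketch that counts the field multiplications needed by left-to-right square-and-multiply. Poseidon's S-box is x^5. Rescue additionally applies the inverse S-box x^(1/5), which is computed as x^d with d = 5^-1 mod (p - 1). The helper `square_and_multiply_cost` is hypothetical. The 254-bit prime is the BN254 (alt_bn128) scalar field order, used here as an assumption about the target field on Ethereum.

```python
def square_and_multiply_cost(exponent):
    """Return (squarings, multiplications) that left-to-right square-and-multiply,
    starting from x itself, needs to compute x**exponent."""
    bits = bin(exponent)[2:]
    return len(bits) - 1, bits.count("1") - 1


ALPHA = 5

# Toy prime used in this thread; gcd(5, p - 1) == 1, so x -> x^5 is invertible
P_TOY = 2423
D_TOY = pow(ALPHA, -1, P_TOY - 1)  # inverse S-box exponent: x^(1/5) == x^D_TOY
assert all(pow(pow(x, ALPHA, P_TOY), D_TOY, P_TOY) == x for x in range(P_TOY))

# BN254 scalar field, assumed here as the SNARK field used on Ethereum
P_BN254 = 21888242871839275222246405745257275088548364400416034343698204186575808495617
D_BN254 = pow(ALPHA, -1, P_BN254 - 1)

print("x^5           :", square_and_multiply_cost(ALPHA))   # (2, 1)
print("x^(1/5), toy  :", square_and_multiply_cost(D_TOY))   # (9, 5)
print("x^(1/5), BN254:", square_and_multiply_cost(D_BN254))
```

For the 254-bit field the inverse exponent is itself roughly field-sized, so each inverse S-box costs at least 251 squarings plus one multiplication per extra set bit, against three operations for x^5.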

HarryR commented 5 years ago

The following code shows that Lagrange interpolation is possible using bad/insecure parameters.

However, inverse Lagrange interpolation isn't possible, even with bad parameters.

This means that Lagrange interpolation over a finite field isn't a viable attack, because you can't use it to solve(y = f(x), x) for a fixed y value.

from ethsnarks.shamirspoly import lagrange, inverse_lagrange
from ethsnarks.poseidon import poseidon_params, poseidon
from ethsnarks.field import FQ
from math import log2, floor
p = 2423
#ShittyParams = poseidon_params(p, 3, 4, 4, b'poseidon', 5, security_target=floor(log2(p)))
# Deliberately small, insecure parameters
ShittyParams = poseidon_params(p, 3, 2, 1, b'poseidon', 5, security_target=floor(log2(p)))

fixed = 123
x = 600
y = poseidon([fixed, x], params=ShittyParams)

# Sample poseidon([fixed, _]) at 1..499; x = 600 is not one of the samples
points = [(_, poseidon([fixed, _], params=ShittyParams)) for _ in range(1, 500)]
points_fq = [(FQ(a,p), FQ(b,p)) for a, b in points]
# Swapped (output, input) pairs, for attempting the inverse interpolation
inv_points_fq = [(b,a) for a, b in points_fq]

# Interpolate through the samples and evaluate at the unseen x
z = int(lagrange(points_fq, FQ(x, p)))
v = poseidon([fixed, x], params=ShittyParams)
print(z, v)
assert v == z
# This proves that Lagrange interpolation *is possible*
# albeit with shitty parameters
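A self-contained toy version of the same experiment helps show why the two directions differ. It does not use ethsnarks. The map f(x) = (x + C)^5 mod 2423, the constant C = 17, and the `lagrange_eval` helper are illustrative assumptions, and `lagrange_eval` is not `ethsnarks.shamirspoly.lagrange`. A polynomial of degree d is fixed by d + 1 samples. The forward map has degree 5, but its inverse is y^D - C with D = 5^-1 mod (p - 1) = 969.

```python
# Toy stand-in for a hash round (not ethsnarks): f(x) = (x + C)^E mod P
P = 2423  # same toy prime as above
C = 17    # hypothetical round constant, chosen only for illustration
E = 5     # S-box exponent; gcd(5, P - 1) == 1, so f is a permutation of GF(P)


def f(x):
    return pow(x + C, E, P)


def lagrange_eval(points, x, p):
    """Evaluate at x the unique polynomial of degree < len(points) through `points` (mod p)."""
    total = 0
    for i, (xi, yi) in enumerate(points):
        num, den = 1, 1
        for j, (xj, _) in enumerate(points):
            if i != j:
                num = num * (x - xj) % p
                den = den * (xi - xj) % p
        total = (total + yi * num * pow(den, -1, p)) % p
    return total


# Forward: f has degree E, so E + 1 samples pin it down and predict an unseen x
forward_points = [(i, f(i)) for i in range(1, E + 2)]
assert lagrange_eval(forward_points, 600, P) == f(600)

# Inverse: f^-1(y) = y^D - C with D = E^-1 mod (P - 1), a polynomial of degree D,
# so interpolation only becomes exact with D + 1 samples of (f(x), x)
D = pow(E, -1, P - 1)
inverse_points = [(f(i), i) for i in range(1, D + 2)]
assert lagrange_eval(inverse_points, f(2000), P) == 2000

print("forward samples:", E + 1, "inverse samples:", D + 1)  # 6 and 970
```

Even for this single S-box, the inverse needs 970 of the field's 2423 points, while the forward direction needs 6. For a multi-round permutation over a ~254-bit field, the inverse's degree is typically far too large for that many samples to be collected.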
HarryR commented 5 years ago

Next, I demonstrate that a specific answer can be found in a reasonable amount of time when there is one free variable, but with two free variables a solution is very unlikely to be found:

This attacks the first round of the Poseidon permutation: you want the resulting state (after mixing) to equal a known state, which is equivalent to finding a collision:
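A plain-Python sketch of the one-round setup that the Sage script explores may make the starting point clearer. The round constant C0 = 1234 and the Cauchy parameters are hypothetical fixed values, whereas the Sage code draws them at random. The same constant is reused in the round, as `poseidon_full_round` does. Recovering a single missing input by exhaustive search over p candidates mirrors the `solve_mod` step that finds `128`.

```python
# Plain-Python sketch (not Sage) of one Poseidon-style full round over GF(P)
P = 2423
T = 3
E = 5
C0 = 1234                      # hypothetical round constant, reused as in poseidon_full_round
XS, YS = [1, 2, 3], [4, 5, 6]  # hypothetical Cauchy parameters, all distinct mod P

# Cauchy matrix M[i][j] = 1 / (x_i - y_j) mod P; invertible because all parameters differ
M = [[pow((XS[i] - YS[j]) % P, -1, P) for j in range(T)] for i in range(T)]


def full_round(state):
    """Add the round constant, apply the x^5 S-box to every element, then mix with M."""
    after_sbox = [pow((s + C0) % P, E, P) for s in state]
    return [sum(M[i][j] * after_sbox[j] for j in range(T)) % P for i in range(T)]


def recover_last_input(known, target):
    """Exhaustively search the single unknown x with full_round(known + [x]) == target."""
    return [x for x in range(P) if full_round(known + [x]) == target]


target = full_round([663, 352, 128])
print(recover_last_input([663, 352], target))  # [128]
```

With two unknown inputs, the same search has p^2 candidates (about 5.9 million for p = 2423) instead of p.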

t = 3
R_F = 4
R_P = 4
e = 5
p = 2423

F = GF(p)
from random import randint
constants = [randint(1, p-1) for _ in range(R_F + R_P)]
cmatrix = [randint(1, p-1) for _ in range(t*2)]
# Cauchy matrix for the mix layer: M[i][j] = 1 / (c_i - c_{t+j}) over GF(p)
M = [[int(1/(F(cmatrix[i]) - F(cmatrix[t+j]))) for j in range(t)] for i in range(t)]

x_0 = randint(1, p-1)
x_1 = var('x_1', domain='positive')
x_2 = var('x_2', domain='positive')

state_0 = [x_0, x_1, x_2]
ark_0 = [_ + constants[0] for _ in state_0]
sbox_0 = [_^5 for _ in ark_0]
mix_0 = [ sum([M[i][j] * _ for j, _ in enumerate(sbox_0)]) for i in range(len(M)) ]

def poseidon_full_round(state):
    after_ark = [_ + constants[0] for _ in state]
    after_sbox = [_^5 for _ in after_ark]
    return [ sum([M[i][j] * _ for j, _ in enumerate(after_sbox)]) for i in range(len(M)) ]

# Example...
#
# sage: [_ % p for _ in poseidon_full_round([663, 352, 128])]
# [2, 2417, 988]
#

# Then prove that, given 663 and 352, we can find `128`
hack_x = var('hack_x')
k = poseidon_full_round([663, 352, hack_x])
solve_mod([k[0] == 2, k[1] == 2417, k[2] == 988], p)
# And it works ;)
# sage: solve_mod([k[0] == 2, k[1] == 2417, k[2] == 988], p)
# [(128,)]

# Given a fixed variable, which isn't 663, we want to find two other inputs which result in the same state
hack_y = var('hack_y', domain='positive')
hack_z = var('hack_z', domain='positive')
k = poseidon_full_round([221, hack_y, hack_z])
solve_mod([k[0] == 2, k[1] == 2417, k[2] == 988], p)
# Which takes a long time, and a huge amount of memory,
# even for a small value of `p`

# We can break this problem down into a simpler, related one:
# attacking only the matrix transform of the mix layer.
# For the two free variables we can work backwards through the sbox
# (take the e'th root, then subtract the round constant), so we can
# treat their post-sbox values as the unknowns, which simplifies the expression.
# We then apply the sbox only to our fixed variable, and feed it into the mix
fixed_term = int((F(221) + constants[0]) ^ 5)
hack_y = var('hack_y', domain='positive')
hack_z = var('hack_z', domain='positive')
after_sbox = [fixed_term, hack_y, hack_z]
k = [ sum([M[i][j] * _ for j, _ in enumerate(after_sbox)]) for i in range(len(M)) ]
# The resulting polynomial is much simpler, see:
# sage: k
#   [401*hack_y + 1685*hack_z + 3034582,
#    1903*hack_y + 2092*hack_z + 2939286,
#    867*hack_y + 687*hack_z + 524128]
# We can then try and solve this, after the manual optimisation
# Which is a system of 3 degree 1 polynomials
solve_mod([k[0] == 2, k[1] == 2417, k[2] == 988], p)
# Again... this takes lots of time and memory

# So, we can try another approach, using the matrix solve_left and solve_right methods
desired_result = vector(F, [2, 2417, 988])
matrix_M = matrix(F, M)
solved_input = matrix_M.solve_right(desired_result)
assert matrix_M * solved_input == desired_result
# However, this solves for the whole input vector, and doesn't let us
# fix one element to a specific value, which we need to cause a collision

# If we're allowed to use `0`, restricting ourselves to one free variable
after_sbox = [fixed_term, 0, hack_z]
k = [ sum([M[i][j] * _ for j, _ in enumerate(after_sbox)]) for i in range(len(M)) ]
solve_mod([k[0] == 2, k[1] == 2417, k[2] == 988], p)
# This shows that, in this case, no solution exists with only 1 free variable
# (3 equations, 1 unknown)

# Is it possible to exploit some property of the matrix?
cmatrix_var = [var('c_%d' % (_,)) for _ in range(t*2)]
M_var = [[1/(cmatrix_var[i] - cmatrix_var[t+j]) for j in range(t)] for i in range(t)]
after_sbox = [fixed_term, hack_y, hack_z]
k = [ sum([M_var[i][j] * _ for j, _ in enumerate(after_sbox)]) for i in range(len(M_var)) ]
#sage: k
#[hack_y/(c_0 - c_4) + hack_z/(c_0 - c_5) + 1489/(c_0 - c_3),
# hack_y/(c_1 - c_4) + hack_z/(c_1 - c_5) + 1489/(c_1 - c_3),
# hack_y/(c_2 - c_4) + hack_z/(c_2 - c_5) + 1489/(c_2 - c_3)]
# However, we're unable to solve this equation using `solve_mod`
# fail: solve_mod([k[0] == 2, k[1] == 2417, k[2] == 988], p)

after_sbox = [((663 + constants[0]) ^ 5) % p,
              ((352 + constants[0]) ^ 5) % p,
              ((123 + constants[0]) ^ 5) % p]
actual_result = [ sum([M[i][j] * _ for j, _ in enumerate(after_sbox)]) % p for i in range(len(M)) ]

# we can substitute any one element with a variable
# and it will solve it
after_sbox[1] = hack_x
k = [ sum([M[i][j] * _ for j, _ in enumerate(after_sbox)]) for i in range(len(M)) ]
solve_mod([k[_] == actual_result[_] for _ in range(len(actual_result))], p)

# But can we modify one of the other inputs, and still solve it with one free variable?
after_sbox = [((663 + constants[0]) ^ 5) % p,
              ((352 + constants[0]) ^ 5) % p,
              ((123 + constants[0]) ^ 5) % p]
after_sbox[1] = hack_x

for derp in range(100):
    after_sbox[0] = derp
    k = [ sum([M[i][j] * _ for j, _ in enumerate(after_sbox)]) for i in range(len(M)) ]
    print(solve_mod([k[_] == actual_result[_] for _ in range(len(actual_result))], p))

# What about a brute-force approach? Just to see if there are any solutions
after_sbox = [((663 + constants[0]) ^ 5) % p,
              ((352 + constants[0]) ^ 5) % p,
              ((123 + constants[0]) ^ 5) % p]
for a in range(p):
    after_sbox[0] = a
    for b in range(p):
        after_sbox[1] = b
        k = [ sum([M[i][j] * _ for j, _ in enumerate(after_sbox)]) % p for i in range(len(M)) ]
        if k == actual_result:
            print(a, b)
# This shows there's only one solution with 2 free inputs and state width t = 3

# But, if we say t=4, and choose a smaller prime (for example)
p = 109
t = 4
F = GF(p)
constants = [randint(1, p-1) for _ in range(R_F + R_P)]
cmatrix = [randint(1, p-1) for _ in range(t*2)]
M = [[int(1/(F(cmatrix[i]) - F(cmatrix[t+j]))) for j in range(t)] for i in range(t)]

after_sbox = [((randint(1,p-1) + constants[0]) ^ 5) % p,
              ((randint(1,p-1) + constants[0]) ^ 5) % p,
              ((randint(1,p-1) + constants[0]) ^ 5) % p,
              ((randint(1,p-1) + constants[0]) ^ 5) % p]
actual_result = [ sum([M[i][j] * _ for j, _ in enumerate(after_sbox)]) % p for i in range(len(M)) ]
print(after_sbox)
for a in range(p):
    after_sbox[0] = a
    for b in range(p):
        after_sbox[1] = b
        for c in range(p):
            after_sbox[2] = c
            k = [ sum([M[i][j] * _ for j, _ in enumerate(after_sbox)]) % p for i in range(len(M)) ]
            if k == actual_result:
                print(a, b, c)
# Then there's still only one result...
# This shows, at least in our case, that every unique (t-1)-length combination
# of inputs produces a unique t-length output, so no collision was found
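Every step of a round is invertible. Adding C0 is invertible, the x^5 S-box is invertible because gcd(5, p - 1) = 1, and the Cauchy mix is invertible when its parameters are distinct. A full round is therefore a bijection, which is consistent with the brute-force searches above only ever finding one solution. The following plain-Python sketch undoes a round directly rather than searching for a preimage. C0 and the Cauchy parameters are hypothetical, and the hand-written Gauss-Jordan `solve_linear` stands in for Sage's `solve_right`.

```python
# Plain-Python sketch: undo one full round instead of searching for a preimage
P = 109                              # the smaller prime from the t = 4 experiment
T = 4
E = 5
C0 = 42                              # hypothetical round constant
XS, YS = [1, 2, 3, 4], [5, 6, 7, 8]  # hypothetical Cauchy parameters, all distinct mod P
M = [[pow((XS[i] - YS[j]) % P, -1, P) for j in range(T)] for i in range(T)]
D = pow(E, -1, P - 1)                # inverse S-box exponent: (x^E)^D == x for all x


def full_round(state):
    after_sbox = [pow((s + C0) % P, E, P) for s in state]
    return [sum(M[i][j] * after_sbox[j] for j in range(T)) % P for i in range(T)]


def solve_linear(A, b, p):
    """Solve A * x == b (mod p) by Gauss-Jordan elimination; A must be invertible."""
    n = len(A)
    rows = [[v % p for v in A[i]] + [b[i] % p] for i in range(n)]
    for col in range(n):
        pivot = next(r for r in range(col, n) if rows[r][col] != 0)
        rows[col], rows[pivot] = rows[pivot], rows[col]
        inv = pow(rows[col][col], -1, p)
        rows[col] = [v * inv % p for v in rows[col]]
        for r in range(n):
            if r != col and rows[r][col] != 0:
                factor = rows[r][col]
                rows[r] = [(v - factor * w) % p for v, w in zip(rows[r], rows[col])]
    return [rows[i][n] for i in range(n)]


def invert_round(output):
    after_sbox = solve_linear(M, output, P)         # undo the mix layer
    after_ark = [pow(v, D, P) for v in after_sbox]  # undo the S-box (e'th root)
    return [(v - C0) % P for v in after_ark]        # undo the round constant


print(invert_round(full_round([1, 2, 3, 4])))  # [1, 2, 3, 4]
```

Because the round is a bijection on the whole t-element state, no two distinct states can produce the same output after one round.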