LatticeX-Foundation / Rosetta

A Privacy-Preserving Framework Based on TensorFlow
GNU Lesser General Public License v3.0
568 stars 110 forks source link

三维矩阵相乘Rosetta不支持吗 #119

Closed xiaoshui240 closed 1 year ago

xiaoshui240 commented 2 years ago

我尝试运行三维tensor相乘,在tensorflow中,只需要第一个维度相同,得到结果是后面两个维度相乘。但Rosetta似乎不支持? 报错: ValueError: Shape must be rank 2 but is rank 3 for 'RttMatmul' (op: 'RttMatmul') with input shapes: [2,2,3], [2,3,2]. 代码如下:

#!/usr/bin/env python3

# Import rosetta package
import latticex.rosetta as rtt
import tensorflow as tf

# You can activate a backend protocol, here use SecureNN
rtt.activate("SecureNN")

# Get private data from Alice (input x), Bob (input y)
x = tf.Variable(rtt.private_input(0, [[[ 1, 2, 3],[ 4, 5, 6]],[[ 7, 8, 9],[10, 11, 12]]]))
y = tf.Variable(rtt.private_input(1, [[[13, 14],[15, 16],[17, 18]],[[19, 20],[21, 22],[23, 24]]]))

# Define matmul operation
res = tf.matmul(x, y)

# Start execution
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    res = sess.run(res)

    # Get the result of Rosetta matmul
    print('matmul:', sess.run(rtt.SecureReveal(res)))

rtt.deactivate()
YuanFengChen commented 2 years ago

我尝试运行三维tensor相乘,在tensorflow中,只需要第一个维度相同,得到结果是后面两个维度相乘。但Rosetta似乎不支持? 报错: ValueError: Shape must be rank 2 but is rank 3 for 'RttMatmul' (op: 'RttMatmul') with input shapes: [2,2,3], [2,3,2]. 代码如下:

#!/usr/bin/env python3

# Import rosetta package
import latticex.rosetta as rtt
import tensorflow as tf

# You can activate a backend protocol, here use SecureNN
rtt.activate("SecureNN")

# Get private data from Alice (input x), Bob (input y)
x = tf.Variable(rtt.private_input(0, [[[ 1, 2, 3],[ 4, 5, 6]],[[ 7, 8, 9],[10, 11, 12]]]))
y = tf.Variable(rtt.private_input(1, [[[13, 14],[15, 16],[17, 18]],[[19, 20],[21, 22],[23, 24]]]))

# Define matmul operation
res = tf.matmul(x, y)

# Start execution
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    res = sess.run(res)

    # Get the result of Rosetta matmul
    print('matmul:', sess.run(rtt.SecureReveal(res)))

rtt.deactivate()

嗯,目前暂只支持二维,未来再开放多维功能。

xiaoshui240 commented 2 years ago

我尝试运行三维tensor相乘,在tensorflow中,只需要第一个维度相同,得到结果是后面两个维度相乘。但Rosetta似乎不支持? 报错: ValueError: Shape must be rank 2 but is rank 3 for 'RttMatmul' (op: 'RttMatmul') with input shapes: [2,2,3], [2,3,2]. 代码如下:

#!/usr/bin/env python3

# Import rosetta package
import latticex.rosetta as rtt
import tensorflow as tf

# You can activate a backend protocol, here use SecureNN
rtt.activate("SecureNN")

# Get private data from Alice (input x), Bob (input y)
x = tf.Variable(rtt.private_input(0, [[[ 1, 2, 3],[ 4, 5, 6]],[[ 7, 8, 9],[10, 11, 12]]]))
y = tf.Variable(rtt.private_input(1, [[[13, 14],[15, 16],[17, 18]],[[19, 20],[21, 22],[23, 24]]]))

# Define matmul operation
res = tf.matmul(x, y)

# Start execution
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    res = sess.run(res)

    # Get the result of Rosetta matmul
    print('matmul:', sess.run(rtt.SecureReveal(res)))

rtt.deactivate()

嗯,目前暂只支持二维,未来再开放多维功能。

好吧,但要实现多维的话,我可以做哪些方面的努力吗

joyoFeng commented 1 year ago

我尝试运行三维tensor相乘,在tensorflow中,只需要第一个维度相同,得到结果是后面两个维度相乘。但Rosetta似乎不支持? 报错: ValueError: Shape must be rank 2 but is rank 3 for 'RttMatmul' (op: 'RttMatmul') with input shapes: [2,2,3], [2,3,2]. 代码如下:

#!/usr/bin/env python3

# Import rosetta package
import latticex.rosetta as rtt
import tensorflow as tf

# You can activate a backend protocol, here use SecureNN
rtt.activate("SecureNN")

# Get private data from Alice (input x), Bob (input y)
x = tf.Variable(rtt.private_input(0, [[[ 1, 2, 3],[ 4, 5, 6]],[[ 7, 8, 9],[10, 11, 12]]]))
y = tf.Variable(rtt.private_input(1, [[[13, 14],[15, 16],[17, 18]],[[19, 20],[21, 22],[23, 24]]]))

# Define matmul operation
res = tf.matmul(x, y)

# Start execution
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    res = sess.run(res)

    # Get the result of Rosetta matmul
    print('matmul:', sess.run(rtt.SecureReveal(res)))

rtt.deactivate()

嗯,目前暂只支持二维,未来再开放多维功能。

好吧,但要实现多维的话,我可以做哪些方面的努力吗

主要是需要处理tensor shape,转换为了二维矩阵的方式处理。未来不久将考虑添加这些支持更新