secretflow / spu

SPU (Secure Processing Unit) aims to be a provable, measurable secure computation device, which provides computation ability while keeping your private data protected.
https://www.secretflow.org.cn/docs/spu/en/
Apache License 2.0

Implement basic principal component analysis (PCA) functionality with SPU #213

Candicepan closed this issue 1 year ago.

Candicepan commented 1 year ago

This issue is a Phase 1 task of the SecretFlow Open Source Contribution Plan (SF OSCP). Community developers are welcome to contribute!

Task description

Detailed requirements

Required skills

Instructions

hacker-jerry commented 1 year ago

Give it to me.

hacker-jerry commented 1 year ago

Hello, I implemented a prototype PCA class with JAX.

from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np

import spu.spu_pb2 as spu_pb2
import spu.utils.simulation as spsim

class PCA(NamedTuple):
    components: jax.Array
    means: jax.Array
    explained_variance: jax.Array

def transform(state, x):
    x = x - state.means
    return jnp.dot(x, jnp.transpose(state.components))

def recover(state, x):
    return jnp.dot(x, state.components) + state.means

def fit(x, n_components, solver="full", rng=None):
    if solver == "full":
        return _fit_full(x, n_components)
    elif solver == "randomized":
        if rng is None:
            rng = jax.random.PRNGKey(n_components)
        return _fit_randomized(x, n_components, rng)
    else:
        raise ValueError("solver parameter is not correct")

# @partial(jax.jit, static_argnums=(1,))
def fit_and_transform(x, n_components=2):
    state = fit(x, n_components)
    return transform(state, x)

# @partial(jax.jit, static_argnames=["n_components"])
def _fit_full(x, n_components):
    n_samples, n_features = x.shape

    # Subtract the mean of the input data
    means = x.mean(axis=0, keepdims=True)
    x = x - means

    # Factorize the data matrix with singular value decomposition.
    U, S, Vt = jax.scipy.linalg.svd(x, full_matrices=False)

    # Compute the explained variance

    explained_variance = (S[:n_components] ** 2) / (n_samples - 1)

    # Return the transformation matrix
    A = Vt[:n_components]
    return PCA(components=A, means=means, explained_variance=explained_variance)

def _fit_randomized(x, n_components, rng, n_iter=5):
    """Randomized PCA based on Halko et al [https://doi.org/10.48550/arXiv.1007.5510]."""
    n_samples, n_features = x.shape
    means = jnp.mean(x, axis=0, keepdims=True)
    x = x - means

    # Draw an (n_features, size) Gaussian test matrix; size must be a plain
    # Python int so that it can be used as a static shape under tracing
    size = min(2 * n_components, n_features)
    Q = jax.random.normal(rng, shape=(n_features, size))

    def step_fn(q, _):
        q, _ = jax.scipy.linalg.lu(x @ q, permute_l=True)
        q, _ = jax.scipy.linalg.lu(x.T @ q, permute_l=True)
        return q, None

    Q, _ = jax.lax.scan(step_fn, init=Q, xs=None, length=n_iter)
    Q, _ = jax.scipy.linalg.qr(x @ Q, mode="economic")
    B = Q.T @ x

    _, S, Vt = jax.scipy.linalg.svd(B, full_matrices=False)

    explained_variance = (S[:n_components] ** 2) / (n_samples - 1)
    A = Vt[:n_components]
    return PCA(components=A, means=means, explained_variance=explained_variance)

The algorithm can be called directly:

X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
state = fit(X, n_components=2)
X_pca = transform(state, X)
X_recovered = recover(state, X_pca)

It can also be simulated with ppd:

import spu.utils.distributed as ppd
import numpy as np

# Initialize the distributed environment.
ppd.init(ppd.SAMPLE_NODES_DEF, ppd.SAMPLE_DEVICES_DEF)

def make_x():
    X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    return X

def fit_(X, n_components):
    return fit(X, n_components=n_components)

def transform_(state, X):
    return transform(state, X)

def get_variance(state):
    return state.explained_variance

x = ppd.device("P1")(make_x)()
pca_ = ppd.device("P1")(fit_)(x, n_components=2)
trans_x = ppd.device("SPU")(transform_)(pca_, x)
var = ppd.device("SPU")(get_variance)(pca_)

However, when I simulate it with spsim, I get an error:

sim = spsim.Simulator.simple(
    3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
)

result = spsim.sim_jax(sim, fit_and_transform, static_argnums=(1,))(X, 2)

File /opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/utils/frontend.py:41, in _jax_compilation(fn, static_argnums, args, kwargs)
     37 @cached(cache=LRUCache(maxsize=128), key=_jax_compilation_key)
     38 def _jax_compilation(fn: Callable, static_argnums, args: List, kwargs: Dict):
     39     import jax
---> 41     cfn, output = jax.xla_computation(
     42         fn, return_shape=True, static_argnums=static_argnums, backend="interpreter"
     43     )(*args, **kwargs)
     44     return cfn.as_serialized_hlo_module_proto(), output

    [... skipping hidden 21 frame]

File /opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/jax/_src/interpreters/mlir.py:1155, in jaxpr_subcomp(ctx, jaxpr, tokens, consts, dim_var_values, *args)
   1153   rule = xla_fallback_lowering(eqn.primitive)
   1154 else:
-> 1155   raise NotImplementedError(
   1156       f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
   1157       f"found for platform {ctx.platform}")
   1159 eqn_ctx = ctx.replace(name_stack=source_info.name_stack)
   1160 effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))

NotImplementedError: MLIR translation rule for primitive 'eigh' not found for platform interpreter

How should I fix this?

hacker-jerry commented 1 year ago

By the way, it also compiles with jax.jit, but it still fails under spsim 😭

rivertalk commented 1 year ago

Hi @hacker-jerry, this looks like it happens because SPU does not support every JAX operator (eigh, for example). @anakinxc, could you take a look?

anakinxc commented 1 year ago

Hi @hacker-jerry

Thanks for providing the reproduction code. We'll look into it.

Thanks

deadlywing commented 1 year ago

Could you share the package versions you are using locally (mainly spu and jax)?

hacker-jerry commented 1 year ago

spu 0.3.3b0, jax 0.4.8, jaxlib 0.4.7

deadlywing commented 1 year ago

Thanks. I noticed that you use ppd and spsim inconsistently: with ppd, the fit method is computed in plaintext (on P1), but with spsim, fit runs in ciphertext.

hacker-jerry commented 1 year ago

Thanks. What causes the spsim error here, and how should I fix it?

deadlywing commented 1 year ago

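First, as I understand it, the fit method needs to be able to run in ciphertext, so the root cause is that SPU does not yet support the svd operator. You may need to implement the SVD algorithm yourself:

  1. I suggest implementing the full solver first.
  2. For the randomized solver, note that SPU's support for random numbers differs from plaintext JAX, so I suggest generating the initial random matrix Q outside the function.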
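Point 2 can be sketched in plain JAX. In this version the caller creates the Gaussian test matrix in plaintext, outside the code meant for SPU, and passes it in as an ordinary argument. This is an illustration under assumptions, not the thread's final implementation. The name `fit_randomized_with_q` is hypothetical, QR re-orthonormalization replaces the LU steps of the prototype, and the `jnp.linalg.qr`/`jnp.linalg.svd` calls would still need SPU-compatible replacements.

```python
import jax
import jax.numpy as jnp


def fit_randomized_with_q(x, q, n_components, n_iter=5):
    """Randomized PCA (Halko et al.) with the Gaussian test matrix `q`
    supplied by the caller instead of being generated inside.

    x: (n_samples, n_features); q: (n_features, k) with k >= n_components.
    Returns (components, means, explained_variance).
    """
    n_samples = x.shape[0]
    means = jnp.mean(x, axis=0, keepdims=True)
    xc = x - means

    # Subspace (power) iterations, re-orthonormalized with QR at each step.
    for _ in range(n_iter):  # n_iter is a Python int: unrolled under tracing
        q, _ = jnp.linalg.qr(xc.T @ (xc @ q))

    # Orthonormal basis of the sampled range, then a small SVD of the projection.
    y, _ = jnp.linalg.qr(xc @ q)
    b = y.T @ xc
    _, s, vt = jnp.linalg.svd(b, full_matrices=False)

    explained_variance = (s[:n_components] ** 2) / (n_samples - 1)
    return vt[:n_components], means, explained_variance


# The random matrix is generated in plaintext and passed in like any other input.
n_samples, n_features, n_components = 20, 5, 2
key_a, key_b, key_q = jax.random.split(jax.random.PRNGKey(0), 3)
# Rank-2 data plus an offset, so the top-2 spectrum is captured exactly.
x = jax.random.normal(key_a, (n_samples, 2)) @ jax.random.normal(key_b, (2, n_features)) + 3.0
k = min(2 * n_components, n_features)  # a plain Python int, usable as a shape
q0 = jax.random.normal(key_q, (n_features, k))

components, means, var = fit_randomized_with_q(x, q0, n_components)
```

`k` must stay a Python `int`. If it were computed with `jnp.minimum`, the result would be an array, which cannot serve as a static shape under tracing.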
hacker-jerry commented 1 year ago

OK, thanks!

hacker-jerry commented 1 year ago

Hello, I implemented the eigh operator with the Jacobi method. Here is the refactored code:

import jax
import jax.numpy as jnp
from jax import jit
from functools import partial

class PCA:
    def __init__(self, n_components=None, tol=1e-8, max_iters=100):
        self.n_components = n_components
        self.tol = tol
        self.max_iters = max_iters
        self.components_ = None
        self.explained_variance_ = None
        self.mean_ = None

    def fit_transform(self, X):
        self.mean_ = jnp.mean(X, axis=0)
        X_centered = X - self.mean_
        cov_matrix = jnp.cov(X_centered, rowvar=False)
        eigenvalues, eigenvectors = jacobi_eigh(cov_matrix, self.tol, self.max_iters)

        idx = jnp.argsort(eigenvalues)[::-1]
        eigenvalues = eigenvalues[idx]
        eigenvectors = eigenvectors[:, idx]

        if self.n_components is None:
            self.n_components = X.shape[1]

        self.components_ = eigenvectors[:, :self.n_components]
        self.explained_variance_ = eigenvalues[:self.n_components]

        X_transformed = jnp.dot(X_centered, self.components_)

        return X_transformed, self.explained_variance_

def jacobi_eigh(A, tol, max_iters):
    n = A.shape[0]
    Q = jnp.eye(n)

    def body_fn(i, vals):
        A, Q = vals
        p, q = jnp.unravel_index(jnp.argmax(jnp.abs(A - jnp.diag(jnp.diag(A)))), A.shape)
        phi = 0.5 * jnp.arctan(2 * A[p, q] / (A[q, q] - A[p, p]))
        rotation = jnp.eye(n)
        rotation = rotation.at[[p, q], [p, q]].set(jnp.cos(phi))
        rotation = rotation.at[q, p].set(jnp.sin(phi))
        rotation = rotation.at[p, q].set(-jnp.sin(phi))
        A_prime = rotation.T @ A @ rotation
        Q_prime = Q @ rotation

        A = jax.lax.cond(jnp.abs(A[p, q]) < tol, lambda _: A, lambda _: A_prime, None)
        Q = jax.lax.cond(jnp.abs(A[p, q]) < tol, lambda _: Q, lambda _: Q_prime, None)
        return A, Q

    A, Q = jax.lax.fori_loop(0, max_iters, body_fn, (A, Q))

    return jnp.diag(A), Q

This function can be compiled with jit and then called:

pca = PCA(n_components=2)

pca_fit_transform = jit(pca.fit_transform, static_argnums=1)

# Prepare some data
X = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# Fit and transform with the compiled fit_transform function
X_transformed, explained_variance = pca_fit_transform(X)

print(explained_variance)

print(X_transformed)

However, simulating it with spsim fails again:

import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
import spu

sim_aby = spsim.Simulator.simple(
    3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
)

def fit_transform(X, n_components=None):
    pca_fit_transform = jit(PCA(n_components=n_components).fit_transform, static_argnums=1)
    X_transformed, explained_variance = pca_fit_transform(X)
    return X_transformed, explained_variance 

result = spsim.sim_jax(sim_aby, fit_transform,  static_argnums=(1,))(X,2)

The error message is as follows:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[71], line 1
----> 1 result = spsim.sim_jax(sim_aby, fit_transform,  static_argnums=(1,))(X,2)

File /opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/utils/simulation.py:152, in sim_jax.<locals>.wrapper(*args, **kwargs)
    149 def outputNameGen(out_flat):
    150     return [f'out{idx}' for idx in range(len(out_flat))]
--> 152 executable, output = spu_fe.compile(
    153     spu_fe.Kind.JAX,
    154     fun,
    155     args,
    156     kwargs,
    157     in_names,
    158     in_vis,
    159     outputNameGen,
    160     static_argnums=static_argnums,
    161 )
    163 wrapper.pphlo = executable.code.decode("utf-8")
    165 out_flat = sim(executable, *args_flat)

File /opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/utils/frontend.py:177, in compile(kind, fn, args, kwargs, input_names, input_vis, outputNameGen, static_argnums)
    175     ir_type = "mhlo"
    176     name = repr(fn)
--> 177 mlir = spu_api.compile(ir_text, ir_type, input_vis)
    178 executable = spu_pb2.ExecutableProto(
    179     name=name,
    180     input_names=input_names,
    181     output_names=output_names,
    182     code=mlir,
    183 )
    184 return executable, output

File /opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/api.py:153, in compile(ir_text, ir_type, vis)
    150 from google.protobuf.json_format import MessageToJson
    152 # todo: rename spu_pb2.XlaMeta to IrMeta?
--> 153 return _spu_compilation(
    154     ir_text, ir_type, MessageToJson(spu_pb2.XlaMeta(inputs=vis))
    155 )

File /opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/cachetools/__init__.py:737, in cached.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    735 except KeyError:
    736     pass  # key not found
--> 737 v = func(*args, **kwargs)
    738 try:
    739     cache[k] = v

File /opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/api.py:136, in _spu_compilation(ir_text, ir_type, json_meta)
    133 @cached(cache=LRUCache(maxsize=128))
    134 def _spu_compilation(ir_text: str, ir_type: str, json_meta: str):
    135     pp_dir = os.getenv('SPU_IR_DUMP_DIR')
--> 136     return libspu.compile(ir_text, ir_type, json_meta, pp_dir or "")

RuntimeError: what: 
    [libspu/compiler/front_end/fe.cc:64] Run front end pipeline failed
stacktrace: 
#0 spu::compiler::FE::doit()+0x178ad5b00
#1 spu::compiler::compile()+0x178ac3ee4
#2 pybind11::cpp_function::initialize<>()::{lambda()#1}::__invoke()+0x178aa778c
#3 pybind11::cpp_function::dispatcher()+0x178a94ac4
#4 cfunction_call_varargs+0x104f243e0
#5 _PyObject_MakeTpCall+0x104f23af0
#6 call_function+0x105010158
#7 _PyEval_EvalFrameDefault+0x10500c83c
#8 function_code_fastcall+0x104f247b4
#9 PyVectorcall_Call+0x104f23fd8
#10 _PyEval_EvalFrameDefault+0x10500caf0
#11 _PyEval_EvalCodeWithName+0x1050057fc
#12 _PyFunction_Vectorcall+0x104f24918
#13 call_function+0x1050100c0
#14 _PyEval_EvalFrameDefault+0x10500c8b8
#15 function_code_fastcall+0x104f247b4

What is causing this?

hacker-jerry commented 1 year ago

@anakinxc

deadlywing commented 1 year ago

Hello, it fails mainly because your eigh implementation uses trigonometric functions, which SPU does not implement yet, so an error is raised.

Plus, your eigh implementation also seems to be wrong. I ran your eigh:

def test_eigh():
    X = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    X_centered = X - jnp.mean(X, axis=0)
    cov_matrix = jnp.cov(X_centered, rowvar=False)

    eigenvalues, eigenvectors = jacobi_eigh(cov_matrix, 1e-8, 1000)
    # print(eigenvalues)
    # print(eigenvectors)
    print(cov_matrix @ eigenvectors)
    print(eigenvalues * eigenvectors)

    print()

    eigenvalues, eigenvectors = jnp.linalg.eigh(cov_matrix)
    # print(eigenvalues)
    # print(eigenvectors)
    print(cov_matrix @ eigenvectors)
    print(eigenvalues * eigenvectors)

Output (the first two matrices come from jacobi_eigh, the last two from the reference eigh):
[[2.2865321e+02 5.4445304e-06 2.5582359e+02]
 [2.2865321e+02 5.4445304e-06 2.5582359e+02]
 [2.2865321e+02 5.4445304e-06 2.5582359e+02]]
[[1.4377324e+02 1.2175019e-21 2.0135764e+02]
 [2.4747901e-06 1.2621775e-29 3.4660027e-06]
 [1.6085759e+02 9.3170446e-22 2.2528442e+02]]

[[ 4.1723251e-07 -2.9802322e-07  1.5588457e+01]
 [ 4.1723251e-07 -2.9802322e-07  1.5588457e+01]
 [ 4.1723251e-07 -2.9802322e-07  1.5588457e+01]]
[[-4.0569081e-08 -1.6165853e-06  1.5588456e+01]
 [-1.1870292e-07  1.1621918e-06  1.5588454e+01]
 [ 1.5927199e-07  4.5439356e-07  1.5588456e+01]]
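Both issues above could be addressed with a cyclic Jacobi variant. What follows is a plain-JAX sketch that has not been verified under spsim, and `jacobi_eigh_cyclic` is a hypothetical name. It visits the (p, q) pairs in a fixed order instead of searching with `argmax`. It computes each rotation with `sqrt` and division only (the symmetric Schur formulas from Golub & Van Loan) rather than with `arctan`/`sin`/`cos`. It also uses a rotation whose sign actually zeroes `a[p, q]`. With the rotation matrix built in `jacobi_eigh` above, the angle `0.5 * arctan(2 * A[p, q] / (A[q, q] - A[p, p]))` appears to have the wrong sign: for `A = [[2, 1], [1, 1]]`, the rotated `A[p, q]` comes out near 0.89 rather than 0. In addition, the second `lax.cond` there tests the already-updated `A`.

```python
import jax
import jax.numpy as jnp


def jacobi_eigh_cyclic(a, n_sweeps=10, eps=1e-12):
    """Cyclic Jacobi eigendecomposition of a symmetric matrix.

    Returns (w, v) with a @ v ~= v * w; eigenvalues are not sorted.
    """
    a = jnp.asarray(a, jnp.float32)
    n = a.shape[0]
    # Fixed visiting order of the off-diagonal pairs (p < q): no argmax needed.
    pairs = [(p, q) for p in range(n) for q in range(p + 1, n)]

    def rotate(p, q, a, v):
        apq, app, aqq = a[p, q], a[p, p], a[q, q]
        small = jnp.abs(apq) < eps
        # tau = (a_qq - a_pp) / (2 a_pq), with the denominator guarded.
        tau = (aqq - app) / jnp.where(small, 1.0, 2.0 * apq)
        sign = jnp.where(tau >= 0, 1.0, -1.0)
        # Smaller root of t^2 + 2 tau t - 1 = 0, i.e. t = tan(theta), |theta| <= pi/4.
        t = sign / (jnp.abs(tau) + jnp.sqrt(1.0 + tau * tau))
        c = jnp.where(small, 1.0, 1.0 / jnp.sqrt(1.0 + t * t))
        s = jnp.where(small, 0.0, t * c)
        # J is the identity except J[p,p] = J[q,q] = c, J[p,q] = s, J[q,p] = -s,
        # which makes (J.T @ a @ J)[p, q] == 0.
        j = jnp.eye(n, dtype=a.dtype)
        j = j.at[p, p].set(c).at[q, q].set(c).at[p, q].set(s).at[q, p].set(-s)
        return j.T @ a @ j, v @ j

    def sweep(_, state):
        a, v = state
        for p, q in pairs:  # static loop, unrolled at trace time
            a, v = rotate(p, q, a, v)
        return a, v

    a, v = jax.lax.fori_loop(0, n_sweeps, sweep, (a, jnp.eye(n, dtype=a.dtype)))
    return jnp.diag(a), v


cov = jnp.cov(jnp.array([[1.0, 2.0, 0.5], [4.0, 5.0, 6.0], [7.0, 8.0, 9.5], [2.0, 1.0, 3.0]]), rowvar=False)
w, v = jacobi_eigh_cyclic(cov)
```

All indices are Python constants, so nothing in the loop depends on secret-valued positions. The cost is that every sweep touches all n(n-1)/2 pairs.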
deadlywing commented 1 year ago

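Jacobi requires computing rotation transforms, which SPU cannot support at the moment. Here are two ideas you could try:

  1. PCA computes the eigendecomposition of the covariance matrix, and the covariance matrix is positive semi-definite, so the Cholesky decomposition can be used to compute the eigendecomposition.
  2. More generally, you can consider the QR decomposition, which does not require positive semi-definiteness.

Just for reference.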
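Point 2 can be sketched as the classic unshifted QR algorithm. Here the QR factorization itself is written by hand with modified Gram-Schmidt, so no built-in decomposition is needed. This is a plain-JAX sketch that has not been checked under spsim, and the function names are hypothetical. It assumes the input is symmetric and full rank, with distinct eigenvalue magnitudes. Convergence is only linear.

```python
import jax
import jax.numpy as jnp


def qr_gram_schmidt(a):
    """Reduced QR factorization a = q @ r via modified Gram-Schmidt.

    Assumes a is (m, n) with m >= n and full column rank; a zero column
    would cause a division by zero.
    """
    m, n = a.shape
    q = jnp.zeros((m, n), a.dtype)
    r = jnp.zeros((n, n), a.dtype)
    for j in range(n):  # static loop over columns
        v = a[:, j]
        for i in range(j):
            rij = jnp.dot(q[:, i], v)  # project against the updated v (MGS)
            r = r.at[i, j].set(rij)
            v = v - rij * q[:, i]
        norm = jnp.sqrt(jnp.dot(v, v))
        r = r.at[j, j].set(norm)
        q = q.at[:, j].set(v / norm)
    return q, r


def eigh_qr_iteration(a, n_iters=100):
    """Unshifted QR algorithm for a symmetric matrix.

    A_{k+1} = R_k @ Q_k = Q_k.T @ A_k @ Q_k stays similar to A and tends to a
    diagonal matrix; the accumulated product of the Q_k gives the eigenvectors.
    """
    a = jnp.asarray(a, jnp.float32)
    n = a.shape[0]

    def body(_, state):
        ak, v = state
        q, r = qr_gram_schmidt(ak)
        return r @ q, v @ q

    ak, v = jax.lax.fori_loop(0, n_iters, body, (a, jnp.eye(n, dtype=a.dtype)))
    return jnp.diag(ak), v


a = jnp.array([[4.0, 1.0, 0.5], [1.0, 3.0, 0.2], [0.5, 0.2, 1.0]])
w, v = eigh_qr_iteration(a)
```

A positive semi-definite covariance matrix with zero eigenvalues, such as the 3x3 example in this thread, would make Gram-Schmidt divide by zero. One possible workaround is to shift the matrix by a small multiple of the identity, which leaves the eigenvectors unchanged.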

hacker-jerry commented 1 year ago

Thanks for the suggestions. Based on them, I reimplemented it as follows:

import jax
import jax.numpy as jnp
from jax import jit, random

class PCA:
    def __init__(self, n_components):
        self.n_components = n_components
        self.mean = None
        self.components = None
        self.variances = None

    def fit(self, X):
        self.mean = jnp.mean(X, axis=0)
        X = X - self.mean

        cov_matrix = jnp.cov(X, rowvar=False)

        L = jnp.linalg.cholesky(cov_matrix)

        q, r = jnp.linalg.qr(L)

        eigvals = jnp.diag(r)

        idx = jnp.argsort(eigvals)[::-1][:self.n_components]

        self.components = q[:, idx]

        self.variances = eigvals[idx]

    def transform(self, X):
        X = X - self.mean
        return jnp.dot(X, self.components)

def fit_and_transform(X, n_components):
    pca = PCA(n_components)
    pca.fit(X)
    return pca.transform(X)

X = random.randint(random.PRNGKey(0), (10,3), 0, 10)

fit_and_transform_jit = jit(fit_and_transform, static_argnums=1)

X_transformed = fit_and_transform_jit(X, 2)

print(X_transformed)

Does this meet the requirements? This time the code passed the spsim simulation.

deadlywing commented 1 year ago

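Sorry, this is probably an spsim bug. In fact, neither cholesky nor qr should actually be able to execute. When execution reaches those two functions, the Python process seems to be terminated outright. We should fix this bug later. (So I'm curious: did running spsim really give you the PCA-transformed matrix?)

So you will also need to implement the Cholesky or QR decomposition yourself.

Finally, when you submit the PR, please add comments marking the earlier implementations that could not run because SPU does not support the operators they need. Once we add those operators, we can revisit those implementations.

Thanks!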
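A manual Cholesky factorization, as suggested above, could look like the following plain-JAX sketch. It uses the Cholesky-Banachiewicz order with static loops and needs only dot products, `sqrt` and division. It is an illustration that has not been verified under spsim, and `cholesky_manual` is a hypothetical name. It assumes a symmetric positive-definite input. A singular covariance matrix makes a diagonal term zero, so the next division blows up; a small ridge such as `cov + 1e-6 * jnp.eye(n)` may be needed.

```python
import jax.numpy as jnp


def cholesky_manual(a):
    """Cholesky-Banachiewicz factorization A = L @ L.T of a symmetric
    positive-definite matrix, using only dot products, sqrt and division.

    The loops run over static Python ranges, so every index is a constant
    at trace time (no data-dependent indexing).
    """
    a = jnp.asarray(a, jnp.float32)
    n = a.shape[0]
    l = jnp.zeros_like(a)
    for j in range(n):
        # Diagonal entry: L[j, j] = sqrt(A[j, j] - sum_{k<j} L[j, k]^2).
        d = a[j, j] - jnp.dot(l[j, :j], l[j, :j])
        ljj = jnp.sqrt(d)
        l = l.at[j, j].set(ljj)
        # Entries below the diagonal in column j.
        for i in range(j + 1, n):
            lij = (a[i, j] - jnp.dot(l[i, :j], l[j, :j])) / ljj
            l = l.at[i, j].set(lij)
    return l


a = jnp.array([[4.0, 1.0, 0.5], [1.0, 3.0, 0.2], [0.5, 0.2, 1.0]])
l = cholesky_manual(a)
```

A Cholesky factor alone is not an eigendecomposition. It would still have to be combined with an eigenvalue iteration to obtain principal components.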

hacker-jerry commented 1 year ago

It really did run successfully.

You can test it yourself:

import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
import spu

sim_aby = spsim.Simulator.simple(
    3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
)

result = spsim.sim_jax(sim_aby, fit_and_transform,  static_argnums=(1,))(X,2)
deadlywing commented 1 year ago

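Thanks. When I run it locally, it hangs indefinitely, so I need to check why. Also, could you run the following code and see whether it raises a division-by-zero error?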

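import jax.numpy as jnp
import numpy as np
import spu.spu_pb2 as spu_pb2
import spu.utils.simulation as spsim

def test_run_eigh():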
    X = jnp.array(np.random.rand(6, 3))
    cov_matrix = jnp.cov(X, rowvar=False)

    sim_aby = spsim.Simulator.simple(
        3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
    )

    print(cov_matrix)
    print(jnp.linalg.det(cov_matrix))
    print(jnp.linalg.cholesky(cov_matrix))
    print(spsim.sim_jax(sim_aby, jnp.linalg.cholesky)(cov_matrix))
    print(1 / 0)
hacker-jerry commented 1 year ago

The versions I am running locally are:

secretflow                    0.8.2b1
sf-heu                        0.4.3b3
spu                           0.3.2b12
jax                           0.4.8
jaxlib                        0.4.7

Running the code above does raise the division-by-zero error.

deadlywing commented 1 year ago

When I run it in Jupyter, I also get an error...

hacker-jerry commented 1 year ago

Could you try downgrading spu?

deadlywing commented 1 year ago

OK, I'll try. BTW, are you using Linux?

hacker-jerry commented 1 year ago

I'm using an M1 Mac.

deadlywing commented 1 year ago

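I tested it locally, and it seems to be a jax version issue: 0.4.8 is fine, and the error only appears with 0.4.13, which is what I have locally... @anakinxc will need to look into this. For now, you can go ahead and implement PCA directly with these two APIs.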

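Once it is implemented, I suggest that you:

  1. Check that the result satisfies the definition of an eigendecomposition
  2. Check that the fitted variances and other outputs match sklearn's PCA

Thanks!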
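The two checks can be written as small plaintext helpers. This is a NumPy sketch, and the function names are hypothetical. The first helper tests the definition A V = V diag(w) together with orthonormality of V. The second computes the reference explained variance as the squared singular values of the centered data divided by (n_samples - 1). That is the convention documented for scikit-learn's `PCA.explained_variance_`, so in a test the helper can be swapped for `sklearn.decomposition.PCA(n_components).fit(X).explained_variance_`.

```python
import numpy as np


def check_eigendecomposition(a, w, v, atol=1e-4):
    """Return True if A @ V == V * w (column j of V pairs with w[j]) and the
    columns of V are orthonormal, both within `atol`."""
    a, w, v = np.asarray(a), np.asarray(w), np.asarray(v)
    residual = np.max(np.abs(a @ v - v * w))
    orthogonality = np.max(np.abs(v.T @ v - np.eye(v.shape[1])))
    return bool(residual < atol and orthogonality < atol)


def reference_explained_variance(x, n_components):
    """Plaintext reference: squared singular values of the centered data
    divided by (n_samples - 1), largest first."""
    x = np.asarray(x, dtype=np.float64)
    xc = x - x.mean(axis=0)
    s = np.linalg.svd(xc, compute_uv=False)  # returned in descending order
    return s[:n_components] ** 2 / (x.shape[0] - 1)
```

Principal directions are only defined up to sign. Compare components after aligning their signs, or through the absolute values of their dot products.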

tarantula-leo commented 1 year ago

Hi, I'd like to ask: are the Cholesky and QR decompositions already supported in SPU, or is this a bug in the spsim simulation?

anakinxc commented 1 year ago

Hi, please upgrade spu to the latest version, or downgrade jax and jaxlib to 0.4.12.

hacker-jerry commented 1 year ago

I upgraded jax and jaxlib to 0.4.12, and it runs without problems.

hacker-jerry commented 1 year ago

But spu is still at the original version.

deadlywing commented 1 year ago

So it seems the latest jax version is not compatible yet. Please keep developing with your current versions for now.

hacker-jerry commented 1 year ago

OK, I tested it, and the results are about the same as sklearn's PCA. If I want to submit a PR, which files do I need to modify?

deadlywing commented 1 year ago

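Thanks for the quick response. You can use this k-means PR as a reference: https://github.com/secretflow/spu/pull/235. Typically there is one file for the algorithm logic (JAX only), one file for the spsim simulation tests, and one file for the emulation tests.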

BTW: in the spsim simulation test file, please also include a comparison with the plaintext sklearn results (this can go in a separate unittest).

Thanks!
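A test file along these lines could hold both the spsim run and the plaintext sklearn comparison as separate unittests. This is only a sketch, and it does not reproduce the layout of the k-means PR. `pca_power_iteration` is a hypothetical stand-in for the JAX-only algorithm file. It uses power iteration with deflation because that needs only matmul, `sqrt` and division. The tolerance for the SPU run is an assumption, and tests whose dependencies are missing are skipped. The spsim calls mirror the ones used earlier in this thread.

```python
import unittest

import jax
import jax.numpy as jnp
import numpy as np

# Optional dependencies: the tests that need them are skipped when missing.
try:
    import spu.spu_pb2 as spu_pb2
    import spu.utils.simulation as spsim
except ImportError:
    spsim = None
try:
    from sklearn.decomposition import PCA as SkPCA
except ImportError:
    SkPCA = None


def pca_power_iteration(x, n_components, n_iters=50):
    """Hypothetical stand-in for the algorithm file: principal directions by
    power iteration with deflation (matmul, sqrt and division only)."""
    x = jnp.asarray(x, jnp.float32)
    xc = x - jnp.mean(x, axis=0)
    cov = xc.T @ xc / (x.shape[0] - 1)
    comps, variances = [], []
    for _ in range(n_components):  # n_components must be static
        def step(_, v, cov=cov):
            w = cov @ v
            return w / jnp.sqrt(jnp.dot(w, w))

        v = jax.lax.fori_loop(0, n_iters, step, jnp.ones(cov.shape[0], cov.dtype))
        lam = v @ cov @ v  # Rayleigh quotient: variance along v
        cov = cov - lam * jnp.outer(v, v)  # deflate before the next component
        comps.append(v)
        variances.append(lam)
    return jnp.stack(comps), jnp.stack(variances)


def make_data():
    rng = np.random.default_rng(0)
    return rng.normal(size=(100, 3)) * np.array([3.0, 1.5, 0.5])


def numpy_reference(x, n_components):
    return np.linalg.eigvalsh(np.cov(x, rowvar=False))[::-1][:n_components]


class PCATest(unittest.TestCase):
    def test_plaintext_matches_numpy(self):
        x = make_data()
        _, var = pca_power_iteration(x, 2)
        np.testing.assert_allclose(var, numpy_reference(x, 2), rtol=1e-3)

    @unittest.skipIf(SkPCA is None, "scikit-learn not installed")
    def test_plaintext_matches_sklearn(self):
        x = make_data()
        _, var = pca_power_iteration(x, 2)
        sk = SkPCA(n_components=2).fit(x)
        np.testing.assert_allclose(var, sk.explained_variance_, rtol=1e-3)

    @unittest.skipIf(spsim is None, "spu not installed")
    def test_spsim_matches_numpy(self):
        sim = spsim.Simulator.simple(
            3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
        )
        x = make_data()
        _, var = spsim.sim_jax(sim, pca_power_iteration, static_argnums=(1,))(x, 2)
        # Looser tolerance for fixed-point arithmetic (an assumed value).
        np.testing.assert_allclose(var, numpy_reference(x, 2), rtol=5e-2)
```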

hacker-jerry commented 1 year ago

Already solved this issue, @Candicepan.

240