Closed Candicepan closed 1 year ago
hacker-jerry give it to me.
您好,我使用jax实现了一个 pca 的类原型。
from functools import partial
from typing import NamedTuple
import jax
import jax.numpy as jnp
import numpy as np
import unittest
import json
import jax.numpy as jnp
import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2 #
class PCA(NamedTuple):
    """Fitted PCA state produced by ``fit`` and consumed by ``transform``/``recover``.

    As a NamedTuple it is a plain JAX pytree, so it can be passed in and out of
    traced functions and device placements.
    """

    components: jax.Array  # principal axes, one per row
    means: jax.Array  # per-feature means subtracted before projection
    explained_variance: jax.Array  # variance captured by each kept component
def transform(state, x):
    """Project ``x`` onto the principal axes stored in ``state``.

    ``x`` is first centred with ``state.means``; the result has one column per
    kept component.
    """
    centered = x - state.means
    return jnp.dot(centered, jnp.transpose(state.components))
def recover(state, x):
    """Map projected coordinates ``x`` back into the original feature space.

    Multiplies by the component matrix and re-adds ``state.means``; information
    in discarded components is not restored.
    """
    reconstructed = jnp.dot(x, state.components)
    return reconstructed + state.means
def fit(x, n_components, solver="full", rng=None):
    """Fit a PCA model to ``x`` (one sample per row).

    Args:
        x: data matrix.
        n_components: number of principal components to keep.
        solver: ``"full"`` (exact SVD) or ``"randomized"`` (Halko et al.).
        rng: PRNG key for the randomized solver; when omitted,
            ``jax.random.PRNGKey(n_components)`` is used. Ignored by ``"full"``.

    Returns:
        A ``PCA`` state.

    Raises:
        ValueError: if ``solver`` is not one of the names above.
    """
    if solver == "randomized":
        key = jax.random.PRNGKey(n_components) if rng is None else rng
        return _fit_randomized(x, n_components, key)
    if solver != "full":
        raise ValueError("solver parameter is not correct")
    return _fit_full(x, n_components)
# Can be jitted with n_components static: @partial(jax.jit, static_argnums=(1,))
def fit_and_transform(x, n_components=2):
    """Fit PCA on ``x`` with the default solver and return ``x`` in component space."""
    return transform(fit(x, n_components), x)
# Can be jitted with: @partial(jax.jit, static_argnames=["n_components"])
def _fit_full(x, n_components):
    """Exact PCA through a thin SVD of the mean-centred data.

    Args:
        x: 2-D data matrix, one sample per row.
        n_components: number of leading right-singular vectors to keep.

    Returns:
        A ``PCA`` whose variances are ``S**2 / (n_samples - 1)``.

    NOTE: per the discussion in this thread, SPU could not lower this SVD
    (the failing primitive reported was ``eigh``).
    """
    n_samples, _ = x.shape
    means = x.mean(axis=0, keepdims=True)
    centered = x - means
    _, singular_values, vt = jax.scipy.linalg.svd(centered, full_matrices=False)
    return PCA(
        components=vt[:n_components],
        means=means,
        explained_variance=singular_values[:n_components] ** 2 / (n_samples - 1),
    )
def _fit_randomized(x, n_components, rng, n_iter=5):
    """Randomized PCA based on Halko et al [https://doi.org/10.48550/arXiv.1007.5510].

    Args:
        x: 2-D data matrix, one sample per row.
        n_components: number of components to keep (must be a Python int).
        rng: PRNG key used to draw the random sketch.
        n_iter: number of power iterations.

    Returns:
        A ``PCA`` state with the leading ``n_components`` axes.
    """
    n_samples, n_features = x.shape
    means = jnp.mean(x, axis=0, keepdims=True)
    x = x - means
    # Sketch width: 2x oversampling, capped at n_features. Built-in ``min``
    # keeps this a static Python int; ``jnp.minimum`` returned an array that
    # becomes a tracer under jit and cannot be used as a shape.
    size = min(2 * n_components, n_features)
    Q = jax.random.normal(rng, shape=(n_features, size))

    def step_fn(q, _):
        # Power iteration, renormalised with the permuted-L factor of an LU
        # decomposition to keep the columns well conditioned.
        q, _ = jax.scipy.linalg.lu(x @ q, permute_l=True)
        q, _ = jax.scipy.linalg.lu(x.T @ q, permute_l=True)
        return q, None

    Q, _ = jax.lax.scan(step_fn, init=Q, xs=None, length=n_iter)
    # Orthonormal basis for the range of x, then an SVD of the small projection.
    Q, _ = jax.scipy.linalg.qr(x @ Q, mode="economic")
    B = Q.T @ x
    _, S, Vt = jax.scipy.linalg.svd(B, full_matrices=False)
    explained_variance = (S[:n_components] ** 2) / (n_samples - 1)
    A = Vt[:n_components]
    return PCA(components=A, means=means, explained_variance=explained_variance)
算法可以通过
# Plaintext usage: fit a 2-component PCA, project X, then map it back.
X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
state = fit(X ,n_components=2)
# One column per kept component.
X_pca = transform(state, X)
# Reconstruction in the original feature space (lossy beyond n_components).
X_recovered = recover(state, X_pca)
直接调用。也可以通过 ppd 的方式进行模拟
import spu.utils.distributed as ppd
import numpy as np
# Initialize the distributed environment using ppd's sample node/device layout.
ppd.init(ppd.SAMPLE_NODES_DEF, ppd.SAMPLE_DEVICES_DEF)
def make_x():
    """Return the 3x3 sample matrix used throughout this example."""
    return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
def fit_(X, n_components):
    """Device-placeable wrapper around ``fit`` with the default solver."""
    fitted = fit(X, n_components=n_components)
    return fitted
def transform_(state, X):
    """Device-placeable wrapper around ``transform``."""
    projected = transform(state, X)
    return projected
def get_variance(state):
    """Return the explained variance stored on a fitted PCA state."""
    variance = state.explained_variance
    return variance
# Data originates on party P1.
x = ppd.device("P1")(make_x)()
# NOTE: fit runs on P1 in plaintext here, not under SPU (unlike the spsim run below).
pca_ = ppd.device("P1")(fit_)(x,n_components=2)
# Projection and the variance lookup run on the SPU device.
trans_x = ppd.device("SPU")(transform_)(pca_, x)
var = ppd.device("SPU")(get_variance)(pca_)
但是,我在使用spsim 进行模拟的时候,发生报错
# 3-party ABY3 simulator with FieldType.FM64.
sim = spsim.Simulator.simple(
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
)
# Here fit runs under SPU as well; its SVD has no SPU lowering, which produced
# the "MLIR translation rule for primitive 'eigh'" error shown below.
result = spsim.sim_jax(sim, fit_and_transform, static_argnums=(1,))(X, 2)
File /opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/utils/frontend.py:41, in _jax_compilation(fn, static_argnums, args, kwargs)
37 @cached(cache=LRUCache(maxsize=128), key=_jax_compilation_key)
38 def _jax_compilation(fn: Callable, static_argnums, args: List, kwargs: Dict):
39 import jax
---> 41 cfn, output = jax.xla_computation(
42 fn, return_shape=True, static_argnums=static_argnums, backend="interpreter"
43 )(*args, **kwargs)
44 return cfn.as_serialized_hlo_module_proto(), output
[... skipping hidden 21 frame]
File /opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/jax/_src/interpreters/mlir.py:1155, in jaxpr_subcomp(ctx, jaxpr, tokens, consts, dim_var_values, *args)
1153 rule = xla_fallback_lowering(eqn.primitive)
1154 else:
-> 1155 raise NotImplementedError(
1156 f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
1157 f"found for platform {ctx.platform}")
1159 eqn_ctx = ctx.replace(name_stack=source_info.name_stack)
1160 effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
NotImplementedError: MLIR translation rule for primitive 'eigh' not found for platform interpreter
请问应该如何修改?
By the way, 用jax.jit也可以编译通过,但是还是无法使用spsim😭
NotImplementedError: MLIR translation rule for primitive 'eigh' not found for platform interpreter
Hi @hacker-jerry,这个看上去是因为 SPU并未支持JAX所有的算子(比如 eigh
),请 @anakinxc 帮忙看一眼
Hi @hacker-jerry
感谢提供复现代码,我们研究一下
Thanks
By the way, 用jax.jit也可以编译通过,但是还是无法使用spsim😭
能麻烦提供一下本地使用的package版本么?(主要是spu和jax)
spu 0.3.3b0 jax 0.4.8 jaxlib 0.4.7
spu 0.3.3b0 jax 0.4.8 jaxlib 0.4.7
Thanks,我发现你使用ppd和spsim的方式不太一致: 在ppd中 fit 方法是明文计算的 但是在spsim中,fit部分是秘文下进行的
谢谢,请问spsim这里的报错是什么原因呢?应该怎样修改?
谢谢,请问spsim这里的报错是什么原因呢?应该怎样修改?
首先,我理解fit方法应该是需要能在密态下执行的,所以本质上的原因应该是SPU暂时没支持svd算子。 可能需要您自己实现一下svd算法;
好的,谢谢!
您好,我使用jacobi的方法实现了eigh算子,重构后的代码如下:
import jax
import jax.numpy as jnp
from jax import jit
from functools import partial
class PCA:
    """PCA via a Jacobi eigendecomposition of the covariance matrix.

    Args:
        n_components: number of leading components to keep; ``None`` keeps all.
        tol: off-diagonal threshold passed to ``jacobi_eigh``.
        max_iters: number of Jacobi rotations passed to ``jacobi_eigh``.
    """

    def __init__(self, n_components=None, tol=1e-8, max_iters=100):
        self.n_components = n_components
        self.tol = tol
        self.max_iters = max_iters
        # Fitted attributes, populated by fit_transform.
        self.components_ = None
        self.explained_variance_ = None
        self.mean_ = None

    def fit_transform(self, X):
        """Fit on ``X`` (one sample per row) and return its projection.

        Returns:
            ``(X_transformed, explained_variance)``.
        """
        self.mean_ = jnp.mean(X, axis=0)
        X_centered = X - self.mean_
        cov_matrix = jnp.cov(X_centered, rowvar=False)
        eigenvalues, eigenvectors = jacobi_eigh(cov_matrix, self.tol, self.max_iters)
        # Largest variance first.
        idx = jnp.argsort(eigenvalues)[::-1]
        eigenvalues = eigenvalues[idx]
        eigenvectors = eigenvectors[:, idx]
        # Resolve ``None`` locally: overwriting self.n_components would make a
        # later fit on wider data silently keep this call's width.
        n_components = X.shape[1] if self.n_components is None else self.n_components
        self.n_components_ = n_components
        self.components_ = eigenvectors[:, :n_components]
        self.explained_variance_ = eigenvalues[:n_components]
        X_transformed = jnp.dot(X_centered, self.components_)
        return X_transformed, self.explained_variance_
def jacobi_eigh(A, tol, max_iters):
    """Eigendecomposition of a symmetric matrix by classical Jacobi rotations.

    Args:
        A: symmetric square matrix.
        tol: a pivot whose magnitude is below this is treated as converged.
        max_iters: number of rotation steps (fixed, for static loop bounds).

    Returns:
        ``(eigenvalues, Q)``, where ``A @ Q ~= Q * eigenvalues``, so the
        columns of ``Q`` are eigenvectors. The eigenvalues are unsorted.

    NOTE: uses cos/sin/arctan2; per the discussion, SPU has no trig support, so
    this runs under plain JAX but not under spsim.
    """
    n = A.shape[0]
    Q = jnp.eye(n)

    def body_fn(_, vals):
        A, Q = vals
        # Pivot on the largest off-diagonal element. For n == 1 this picks
        # (0, 0) with value 0, which counts as converged below.
        off_diag = A - jnp.diag(jnp.diag(A))
        p, q = jnp.unravel_index(jnp.argmax(jnp.abs(off_diag)), A.shape)
        a_pq = off_diag[p, q]
        # Decide convergence once, from the pre-rotation A, and use it for
        # both A and Q. Re-testing after A was updated skipped the Q update.
        converged = jnp.abs(a_pq) < tol
        # For the rotation layout below, A'[p, q] = 0 requires
        # tan(2*phi) = 2*a_pq / (a_pp - a_qq). The old formula had the sign
        # flipped. arctan2 also handles a_pp == a_qq without dividing by zero.
        phi = 0.5 * jnp.arctan2(2 * a_pq, A[p, p] - A[q, q])
        c, s = jnp.cos(phi), jnp.sin(phi)
        rotation = jnp.eye(n)
        rotation = rotation.at[p, p].set(c)
        rotation = rotation.at[q, q].set(c)
        rotation = rotation.at[q, p].set(s)
        rotation = rotation.at[p, q].set(-s)
        A_new = jnp.where(converged, A, rotation.T @ A @ rotation)
        Q_new = jnp.where(converged, Q, Q @ rotation)
        return A_new, Q_new

    A, Q = jax.lax.fori_loop(0, max_iters, body_fn, (A, Q))
    return jnp.diag(A), Q
该函数可以通过jit编译后调用
pca = PCA(n_components=2)
# The bound method fit_transform takes only X, so there is nothing to mark
# static; static_argnums=1 pointed at a nonexistent argument, which JAX warns
# about or rejects depending on version.
# NOTE(review): jitting a method that assigns to self stores traced values on
# `pca`; read results from the return value, not from pca.components_ etc.
pca_fit_transform = jit(pca.fit_transform)
# Prepare some data
X = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# Fit and transform with the compiled fit_transform
X_transformed, explained_variance = pca_fit_transform(X)
print(explained_variance)
print(X_transformed)
但是使用 spsim 模拟时,再次发生报错
import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
import spu
sim_aby = spsim.Simulator.simple(
    3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
)


def fit_transform(X, n_components=None):
    """Fit a PCA on X and return ``(X_transformed, explained_variance)``.

    spsim.sim_jax already traces and compiles this whole function, so no inner
    jit is needed. The previous inner ``jit(..., static_argnums=1)`` also
    referenced an argument the bound method does not have.
    """
    return PCA(n_components=n_components).fit_transform(X)


# n_components (argument index 1) must be static: it is used for slicing.
result = spsim.sim_jax(sim_aby, fit_transform, static_argnums=(1,))(X, 2)
报错信息如下:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[71], line 1
----> 1 result = spsim.sim_jax(sim_aby, fit_transform, static_argnums=(1,))(X,2)
File [/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/utils/simulation.py:152](https://file+.vscode-resource.vscode-cdn.net/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/utils/simulation.py:152), in sim_jax..wrapper(*args, **kwargs)
149 def outputNameGen(out_flat):
150 return [f'out{idx}' for idx in range(len(out_flat))]
--> 152 executable, output = spu_fe.compile(
153 spu_fe.Kind.JAX,
154 fun,
155 args,
156 kwargs,
157 in_names,
158 in_vis,
159 outputNameGen,
160 static_argnums=static_argnums,
161 )
163 wrapper.pphlo = executable.code.decode("utf-8")
165 out_flat = sim(executable, *args_flat)
File [/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/utils/frontend.py:177](https://file+.vscode-resource.vscode-cdn.net/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/utils/frontend.py:177), in compile(kind, fn, args, kwargs, input_names, input_vis, outputNameGen, static_argnums)
175 ir_type = "mhlo"
176 name = repr(fn)
--> 177 mlir = spu_api.compile(ir_text, ir_type, input_vis)
178 executable = spu_pb2.ExecutableProto(
179 name=name,
180 input_names=input_names,
181 output_names=output_names,
182 code=mlir,
183 )
184 return executable, output
File [/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/api.py:153](https://file+.vscode-resource.vscode-cdn.net/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/api.py:153), in compile(ir_text, ir_type, vis)
150 from google.protobuf.json_format import MessageToJson
152 # todo: rename spu_pb2.XlaMeta to IrMeta?
--> 153 return _spu_compilation(
154 ir_text, ir_type, MessageToJson(spu_pb2.XlaMeta(inputs=vis))
155 )
File [/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/cachetools/__init__.py:737](https://file+.vscode-resource.vscode-cdn.net/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/cachetools/__init__.py:737), in cached..decorator..wrapper(*args, **kwargs)
735 except KeyError:
736 pass # key not found
--> 737 v = func(*args, **kwargs)
738 try:
739 cache[k] = v
File [/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/api.py:136](https://file+.vscode-resource.vscode-cdn.net/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/api.py:136), in _spu_compilation(ir_text, ir_type, json_meta)
133 @cached(cache=LRUCache(maxsize=128))
134 def _spu_compilation(ir_text: str, ir_type: str, json_meta: str):
135 pp_dir = os.getenv('SPU_IR_DUMP_DIR')
--> 136 return libspu.compile(ir_text, ir_type, json_meta, pp_dir or "")
RuntimeError: what:
[libspu/compiler/front_end/fe.cc:64] Run front end pipeline failed
stacktrace:
#0 spu::compiler::FE::doit()+0x178ad5b00
#1 spu::compiler::compile()+0x178ac3ee4
#2 pybind11::cpp_function::initialize<>()::{lambda()#1}::__invoke()+0x178aa778c
#3 pybind11::cpp_function::dispatcher()+0x178a94ac4
#4 cfunction_call_varargs+0x104f243e0
#5 _PyObject_MakeTpCall+0x104f23af0
#6 call_function+0x105010158
#7 _PyEval_EvalFrameDefault+0x10500c83c
#8 function_code_fastcall+0x104f247b4
#9 PyVectorcall_Call+0x104f23fd8
#10 _PyEval_EvalFrameDefault+0x10500caf0
#11 _PyEval_EvalCodeWithName+0x1050057fc
#12 _PyFunction_Vectorcall+0x104f24918
#13 call_function+0x1050100c0
#14 _PyEval_EvalFrameDefault+0x10500c8b8
#15 function_code_fastcall+0x104f247b4
请问是什么原因?
@anakinxc
您好,我使用jacobi的方法实现了eigh算子,重构后的代码如下:
import jax
import jax.numpy as jnp
from jax import jit
from functools import partial


class PCA:
    """PCA via a Jacobi eigendecomposition of the covariance matrix.

    Args:
        n_components: number of leading components to keep; ``None`` keeps all.
        tol: off-diagonal threshold passed to ``jacobi_eigh``.
        max_iters: number of Jacobi rotations passed to ``jacobi_eigh``.
    """

    def __init__(self, n_components=None, tol=1e-8, max_iters=100):
        self.n_components = n_components
        self.tol = tol
        self.max_iters = max_iters
        # Fitted attributes, populated by fit_transform.
        self.components_ = None
        self.explained_variance_ = None
        self.mean_ = None

    def fit_transform(self, X):
        """Fit on ``X`` (one sample per row); return ``(X_transformed, explained_variance)``."""
        self.mean_ = jnp.mean(X, axis=0)
        X_centered = X - self.mean_
        cov_matrix = jnp.cov(X_centered, rowvar=False)
        eigenvalues, eigenvectors = jacobi_eigh(cov_matrix, self.tol, self.max_iters)
        idx = jnp.argsort(eigenvalues)[::-1]
        eigenvalues = eigenvalues[idx]
        eigenvectors = eigenvectors[:, idx]
        # Resolve None locally instead of overwriting the configured value.
        n_components = X.shape[1] if self.n_components is None else self.n_components
        self.n_components_ = n_components
        self.components_ = eigenvectors[:, :n_components]
        self.explained_variance_ = eigenvalues[:n_components]
        X_transformed = jnp.dot(X_centered, self.components_)
        return X_transformed, self.explained_variance_


def jacobi_eigh(A, tol, max_iters):
    """Eigendecomposition of a symmetric matrix by classical Jacobi rotations.

    Returns ``(eigenvalues, Q)`` with ``A @ Q ~= Q * eigenvalues`` (unsorted).
    NOTE: uses trig functions, which SPU does not support per the discussion.
    """
    n = A.shape[0]
    Q = jnp.eye(n)

    def body_fn(_, vals):
        A, Q = vals
        off_diag = A - jnp.diag(jnp.diag(A))
        p, q = jnp.unravel_index(jnp.argmax(jnp.abs(off_diag)), A.shape)
        a_pq = off_diag[p, q]
        # One convergence decision, from the pre-rotation A, for both A and Q.
        converged = jnp.abs(a_pq) < tol
        # Angle that zeroes A'[p, q] for this rotation layout (sign fixed).
        phi = 0.5 * jnp.arctan2(2 * a_pq, A[p, p] - A[q, q])
        c, s = jnp.cos(phi), jnp.sin(phi)
        rotation = jnp.eye(n)
        rotation = rotation.at[p, p].set(c)
        rotation = rotation.at[q, q].set(c)
        rotation = rotation.at[q, p].set(s)
        rotation = rotation.at[p, q].set(-s)
        A_new = jnp.where(converged, A, rotation.T @ A @ rotation)
        Q_new = jnp.where(converged, Q, Q @ rotation)
        return A_new, Q_new

    A, Q = jax.lax.fori_loop(0, max_iters, body_fn, (A, Q))
    return jnp.diag(A), Q
该函数可以通过jit编译后调用
pca = PCA(n_components=2)
# The bound method fit_transform takes only X; static_argnums=1 referred to a
# nonexistent argument, so it is dropped.
pca_fit_transform = jit(pca.fit_transform)
# Prepare some data
X = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# Fit and transform with the compiled fit_transform
X_transformed, explained_variance = pca_fit_transform(X)
print(explained_variance)
print(X_transformed)
但是使用 spsim 模拟时,再次发生报错
import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
import spu

sim_aby = spsim.Simulator.simple(
    3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
)


def fit_transform(X, n_components=None):
    """Fit a PCA on X and return ``(X_transformed, explained_variance)``.

    sim_jax compiles this whole function, so the inner jit (whose
    static_argnums=1 pointed at a nonexistent argument) is unnecessary.
    """
    return PCA(n_components=n_components).fit_transform(X)


result = spsim.sim_jax(sim_aby, fit_transform, static_argnums=(1,))(X, 2)
报错信息如下:
--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) Cell In[71], line 1 ----> 1 result = spsim.sim_jax(sim_aby, fit_transform, static_argnums=(1,))(X,2) File [/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/utils/simulation.py:152](https://file+.vscode-resource.vscode-cdn.net/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/utils/simulation.py:152), in sim_jax..wrapper(*args, **kwargs) 149 def outputNameGen(out_flat): 150 return [f'out{idx}' for idx in range(len(out_flat))] --> 152 executable, output = spu_fe.compile( 153 spu_fe.Kind.JAX, 154 fun, 155 args, 156 kwargs, 157 in_names, 158 in_vis, 159 outputNameGen, 160 static_argnums=static_argnums, 161 ) 163 wrapper.pphlo = executable.code.decode("utf-8") 165 out_flat = sim(executable, *args_flat) File [/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/utils/frontend.py:177](https://file+.vscode-resource.vscode-cdn.net/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/utils/frontend.py:177), in compile(kind, fn, args, kwargs, input_names, input_vis, outputNameGen, static_argnums) 175 ir_type = "mhlo" 176 name = repr(fn) --> 177 mlir = spu_api.compile(ir_text, ir_type, input_vis) 178 executable = spu_pb2.ExecutableProto( 179 name=name, 180 input_names=input_names, 181 output_names=output_names, 182 code=mlir, 183 ) 184 return executable, output File [/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/api.py:153](https://file+.vscode-resource.vscode-cdn.net/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/api.py:153), in compile(ir_text, ir_type, vis) 150 from google.protobuf.json_format import MessageToJson 152 # todo: rename spu_pb2.XlaMeta to IrMeta? 
--> 153 return _spu_compilation( 154 ir_text, ir_type, MessageToJson(spu_pb2.XlaMeta(inputs=vis)) 155 ) File [/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/cachetools/__init__.py:737](https://file+.vscode-resource.vscode-cdn.net/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/cachetools/__init__.py:737), in cached..decorator..wrapper(*args, **kwargs) 735 except KeyError: 736 pass # key not found --> 737 v = func(*args, **kwargs) 738 try: 739 cache[k] = v File [/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/api.py:136](https://file+.vscode-resource.vscode-cdn.net/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/api.py:136), in _spu_compilation(ir_text, ir_type, json_meta) 133 @cached(cache=LRUCache(maxsize=128)) 134 def _spu_compilation(ir_text: str, ir_type: str, json_meta: str): 135 pp_dir = os.getenv('SPU_IR_DUMP_DIR') --> 136 return libspu.compile(ir_text, ir_type, json_meta, pp_dir or "") RuntimeError: what: [libspu/compiler/front_end/fe.cc:64] Run front end pipeline failed stacktrace: #0 spu::compiler::FE::doit()+0x178ad5b00 #1 spu::compiler::compile()+0x178ac3ee4 #2 pybind11::cpp_function::initialize<>()::{lambda()#1}::__invoke()+0x178aa778c #3 pybind11::cpp_function::dispatcher()+0x178a94ac4 #4 cfunction_call_varargs+0x104f243e0 #5 _PyObject_MakeTpCall+0x104f23af0 #6 call_function+0x105010158 #7 _PyEval_EvalFrameDefault+0x10500c83c #8 function_code_fastcall+0x104f247b4 #9 PyVectorcall_Call+0x104f23fd8 #10 _PyEval_EvalFrameDefault+0x10500caf0 #11 _PyEval_EvalCodeWithName+0x1050057fc #12 _PyFunction_Vectorcall+0x104f24918 #13 call_function+0x1050100c0 #14 _PyEval_EvalFrameDefault+0x10500c8b8 #15 function_code_fastcall+0x104f247b4
请问是什么原因?
hello,不能跑的原因主要是eigh的实现里用到了三角函数,spu当前没有实现,所以报错了;
PLUS,你eigh的实现应该也有问题,我运行了你的eigh
def test_eigh():
    """Compare jacobi_eigh against eigh on the sample 3x3 data.

    For each solver it prints ``cov @ V`` and ``w * V`` (each eigenvector
    column scaled by its eigenvalue); the two agree iff ``A v = lambda v``
    holds column-wise.
    """
    data = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    cov = jnp.cov(data - jnp.mean(data, axis=0), rowvar=False)

    w_jacobi, v_jacobi = jacobi_eigh(cov, 1e-8, 1000)
    print(cov @ v_jacobi)
    print(w_jacobi * v_jacobi)
    print()

    w_ref, v_ref = eigh(cov)
    print(cov @ v_ref)
    print(w_ref * v_ref)
[[2.2865321e+02 5.4445304e-06 2.5582359e+02]
[2.2865321e+02 5.4445304e-06 2.5582359e+02]
[2.2865321e+02 5.4445304e-06 2.5582359e+02]]
[[1.4377324e+02 1.2175019e-21 2.0135764e+02]
[2.4747901e-06 1.2621775e-29 3.4660027e-06]
[1.6085759e+02 9.3170446e-22 2.2528442e+02]]
[[ 4.1723251e-07 -2.9802322e-07 1.5588457e+01]
[ 4.1723251e-07 -2.9802322e-07 1.5588457e+01]
[ 4.1723251e-07 -2.9802322e-07 1.5588457e+01]]
[[-4.0569081e-08 -1.6165853e-06 1.5588456e+01]
[-1.1870292e-07 1.1621918e-06 1.5588454e+01]
[ 1.5927199e-07 4.5439356e-07 1.5588456e+01]]
jacobi需要计算旋转变换,当前spu无法支持;我这边提供两种思路,你可以尝试一下:
仅供参考~
jacobi需要计算旋转变换,当前spu无法支持;我这边提供两种思路,你可以尝试一下:
- PCA对cov矩阵计算特征值变换,而cov是半正定的,可以用cholesky分解计算特征变换
- 一般地,可以考虑QR分解,不需要半正定的条件
仅供参考~
谢谢您的建议,基于此,我重新实现了一下,代码如下:
import jax
import jax.numpy as jnp
from jax import random
class PCA:
    """PCA via QR-algorithm eigendecomposition of the covariance matrix.

    Args:
        n_components: number of leading principal components to keep.
        n_iter: number of unshifted QR iterations used to diagonalize the
            covariance (optional, default 100). Convergence slows when adjacent
            eigenvalues are close, so raise it for such data.
    """

    def __init__(self, n_components, n_iter=100):
        self.n_components = n_components
        self.n_iter = n_iter
        # Fitted attributes, populated by fit.
        self.mean = None
        self.components = None
        self.variances = None

    def fit(self, X):
        """Estimate the mean, principal axes and their variances from X (one sample per row).

        The previous version took a Cholesky factor L of the covariance and a
        single QR of L. That is not an eigendecomposition: diag(R) are not the
        covariance's eigenvalues and Q's columns are not its eigenvectors.
        Cholesky also returns NaN for a singular covariance. This version runs
        the QR algorithm instead: A <- R @ Q, V <- V @ Q. A tends to a diagonal
        of eigenvalues, and V accumulates the eigenvectors.
        """
        self.mean = jnp.mean(X, axis=0)
        X = X - self.mean
        # atleast_2d: jnp.cov returns a 0-d array for single-feature input.
        cov_matrix = jnp.atleast_2d(jnp.cov(X, rowvar=False))
        n_features = cov_matrix.shape[0]

        def qr_step(_, carry):
            a, v = carry
            q, r = jnp.linalg.qr(a)
            return r @ q, v @ q

        diag_form, eigvecs = jax.lax.fori_loop(
            0,
            self.n_iter,
            qr_step,
            (cov_matrix, jnp.eye(n_features, dtype=cov_matrix.dtype)),
        )
        eigvals = jnp.diag(diag_form)
        idx = jnp.argsort(eigvals)[::-1][: self.n_components]
        self.components = eigvecs[:, idx]
        self.variances = eigvals[idx]

    def transform(self, X):
        """Centre X with the fitted mean and project onto the fitted components."""
        X = X - self.mean
        return jnp.dot(X, self.components)
def fit_and_transform(X, n_components):
    """Fit a fresh PCA with ``n_components`` on X and return X projected onto it."""
    model = PCA(n_components)
    model.fit(X)
    projected = model.transform(X)
    return projected
# Example data: 10 samples x 3 features, random integers in [0, 10).
X = random.randint(random.PRNGKey(0), (10, 3), 0, 10)
# n_components (positional index 1) drives slicing, so it must be static.
# jax.jit is used because this snippet never imports a bare `jit`.
fit_and_transform_jit = jax.jit(fit_and_transform, static_argnums=1)
X_transformed = fit_and_transform_jit(X, 2)
print(X_transformed)
您看是否符合要求?这次的代码通过了 spsim 模拟。
jacobi需要计算旋转变换,当前spu无法支持;我这边提供两种思路,你可以尝试一下:
- PCA对cov矩阵计算特征值变换,而cov是半正定的,可以用cholesky分解计算特征变换
- 一般地,可以考虑QR分解,不需要半正定的条件
仅供参考~
谢谢您的建议,基于此,我重新实现了一下,代码如下:
import jax
import jax.numpy as jnp
from jax import random


class PCA:
    """PCA via QR-algorithm eigendecomposition of the covariance matrix.

    Args:
        n_components: number of leading principal components to keep.
        n_iter: number of unshifted QR iterations (optional, default 100).
    """

    def __init__(self, n_components, n_iter=100):
        self.n_components = n_components
        self.n_iter = n_iter
        self.mean = None
        self.components = None
        self.variances = None

    def fit(self, X):
        """Estimate mean, principal axes and variances from X (one sample per row).

        Replaces Cholesky + a single QR, which is not an eigendecomposition and
        yields NaN for singular covariances, with QR iteration
        (A <- R @ Q, V <- V @ Q).
        """
        self.mean = jnp.mean(X, axis=0)
        X = X - self.mean
        cov_matrix = jnp.atleast_2d(jnp.cov(X, rowvar=False))
        n_features = cov_matrix.shape[0]

        def qr_step(_, carry):
            a, v = carry
            q, r = jnp.linalg.qr(a)
            return r @ q, v @ q

        diag_form, eigvecs = jax.lax.fori_loop(
            0,
            self.n_iter,
            qr_step,
            (cov_matrix, jnp.eye(n_features, dtype=cov_matrix.dtype)),
        )
        eigvals = jnp.diag(diag_form)
        idx = jnp.argsort(eigvals)[::-1][: self.n_components]
        self.components = eigvecs[:, idx]
        self.variances = eigvals[idx]

    def transform(self, X):
        """Centre X with the fitted mean and project onto the fitted components."""
        X = X - self.mean
        return jnp.dot(X, self.components)


def fit_and_transform(X, n_components):
    """Fit a fresh PCA on X and return X projected onto its components."""
    pca = PCA(n_components)
    pca.fit(X)
    return pca.transform(X)


X = random.randint(random.PRNGKey(0), (10, 3), 0, 10)
# jax.jit: the snippet never imported a bare `jit`.
fit_and_transform_jit = jax.jit(fit_and_transform, static_argnums=1)
X_transformed = fit_and_transform_jit(X, 2)
print(X_transformed)
您看是否符合要求?这次的代码通过了 spsim 模拟。
Sorry, 这应该是spsim的bug,实际上无论cholesky还是qr应该都无法真实的执行,执行到那两个函数的时候似乎python进程会被直接关闭,我们后续应该会修复这个bug。(所以我很好奇,您运行spsim真的能得到PCA transform后的矩阵么?)
所以,你也需要自己手动实现cholesky分解或qr分解。
最后,麻烦您后面提交pr的时候,用注释的方式标记一下之前的实现中,因为spu不支持算子而无法运行的实现方式,后续我们增加这些算子以后可以重新考察这些实现~
感谢!
的确是运行成功了,
您也可以测试一下,
import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
import spu
# 3-party ABY3 simulator with FieldType.FM64.
sim_aby = spsim.Simulator.simple(
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
)
# n_components (argument index 1) is static so slicing by it stays concrete.
result = spsim.sim_jax(sim_aby, fit_and_transform, static_argnums=(1,))(X,2)
的确是运行成功了, 您也可以测试一下,
import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
import spu

# 3-party ABY3 simulator with FieldType.FM64.
sim_aby = spsim.Simulator.simple(
    3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
)
# n_components (argument index 1) is static.
result = spsim.sim_jax(sim_aby, fit_and_transform, static_argnums=(1,))(X, 2)
Thanks, 我本地运行会一直卡住,我需要check一下原因. 另外,麻烦您运行一下下面这个代码,看是否会raise除0错误.
def test_run_eigh():
    """Probe whether spsim actually returns from cholesky.

    Prints the covariance, its determinant, the plaintext Cholesky factor, and
    the spsim result, then evaluates ``1 / 0``. Per the discussion, seeing the
    ZeroDivisionError shows the simulation returned instead of killing the
    process.
    """
    cov = jnp.cov(jnp.array(np.random.rand(6, 3)), rowvar=False)
    simulator = spsim.Simulator.simple(
        3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
    )
    print(cov)
    print(jnp.linalg.det(cov))
    print(jnp.linalg.cholesky(cov))
    print(spsim.sim_jax(simulator, jnp.linalg.cholesky)(cov))
    print(1 / 0)
我本地运行的版本是
secretflow 0.8.2b1
sf-heu 0.4.3b3
spu 0.3.2b12
jax 0.4.8
jaxlib 0.4.7
上述代码运行结果有除 0 报错
我在jupyter上运行的话,也会报错...
我在jupyter上运行的话,也会报错...
您看看spu降一下级试一下?
我在jupyter上运行的话,也会报错...
您看看spu降一下级试一下?
好的,我试试; btw,您是用linux系统不?
我是用的m1 mac
我是用的m1 mac
我本地测试了一下,应该是jax版本问题,0.4.8是ok的,我本地是0.4.13才会报错...这个问题得 @anakinxc 看看 您可以先直接使用这两个api去实现pca吧~
建议您实现以后可以:
感谢!
我是用的m1 mac
我本地测试了一下,应该是jax版本问题,0.4.8是ok的,我本地是0.4.13才会报错...这个问题得 @anakinxc 看看 您可以先直接使用这两个api去实现pca吧~
建议您实现以后可以:
- 检查是否满足特征值分解的定义
- 检查fit后的方差等与sklearn的pca是否一致
感谢!
你好 想问下cholesky分解和qr分解是已经在SPU中支持了还是使用spsim模拟的bug?
我是用的m1 mac
hi,麻烦升级 spu 到最新的或者把 jax 和 jaxlib 降级到 0.4.12
我是用的m1 mac
hi,麻烦升级 spu 到最新的或者把 jax 和 jaxlib 降级到 0.4.12
我把jax 和 jaxlib 升级到 0.4.12了,运行没有问题
我是用的m1 mac
hi,麻烦升级 spu 到最新的或者把 jax 和 jaxlib 降级到 0.4.12
我把jax 和 jaxlib 升级到 0.4.12了,运行没有问题
但是spu还是原来版本
那看来就是jax最新版本不太适配了,,那您就先用现在的版本先开发吧
ok,我测试了一下,和sklearn的 PCA 效果是差不多的。请问如果要提交pr,需要在哪几个文件中进行修改?
ok,我测试了一下,和sklearn的 PCA 效果是差不多的。请问如果要提交pr,需要在哪几个文件中进行修改?
感谢快速响应,可以先参考一下这个kmeans的PR; https://github.com/secretflow/spu/pull/235 一般是一个文件用于实现算法逻辑(jax only),一个文件用于spsim模拟测试,一个文件做emulation测试;
BTW:麻烦请在spsim模拟测试的那个文件中同时提交一下和明文sklearn的结果对比(可以写在不同的unittest里)
Thanks!
Already solved this issue @Candicepan .
任务介绍
详细要求
能力要求
操作说明