secretflow / spu

SPU (Secure Processing Unit) aims to be a provable, measurable secure computation device, which provides computation ability while keeping your private data protected.
https://www.secretflow.org.cn/docs/spu/en/
Apache License 2.0
226 stars 100 forks source link

使用 SPU 实现逻辑回归算法基础功能 #211

Closed Candicepan closed 1 year ago

Candicepan commented 1 year ago

此 ISSUE 为 隐语开源共建计划(SecretFlow Open Source Contribution Plan,简称 SF OSCP)第一期任务 ISSUE,欢迎社区开发者参与共建~

任务介绍

详细要求

能力要求

操作说明

imwangyt commented 1 year ago

imwangyt Give it to me

Candicepan commented 1 year ago

imwangyt Give it to me

经沟通,该任务已经回收,目前为待认领状态哈,欢迎小伙伴们继续认领~

tarantula-leo commented 1 year ago

tarantula-leo Give it to me

tarantula-leo commented 1 year ago
import numpy as np
import jax.numpy as jnp
import jax
from enum import Enum

from sklearn.datasets import load_breast_cancer
from sklearn.preprocessing import MinMaxScaler
import pandas as pd

X, y = load_breast_cancer(return_X_y=True, as_frame=True)
cols = X.columns
# Rescale every feature into [-2, 2]; keep the result as a DataFrame.
scalar = MinMaxScaler(feature_range=(-2, 2))
X = pd.DataFrame(scalar.fit_transform(X), columns=cols)

def sigmoid_sr(x):
    """Square-root sigmoid approximation: 0.5 * x / sqrt(1 + x**2) + 0.5.

    Avoids ``exp``; uses only arithmetic and one square root. The result
    lies in (0, 1) and equals 0.5 at x == 0.
    """
    scaled = x / jnp.sqrt(1 + jnp.power(x, 2))
    return 0.5 * scaled + 0.5

class RegType(Enum):
    """Kind of linear model: plain linear regression or logistic regression."""

    Linear = 'linear'
    Logistic = 'logistic'

class Penalty(Enum):
    """Regularisation term that may be added to the SGD gradient."""

    NONE = 'None'  # no regularisation
    L1 = 'l1'  # not supported
    L2 = 'l2'

class SGDClassifier:
    """Linear or logistic regression trained with mini-batch SGD (jax.numpy).

    The model is written with ``jax.numpy`` so it can run in plaintext or
    under SPU. Early stopping (``eps`` not None) is implemented with
    ``jax.lax.while_loop`` whose condition is computed from the weights; SPU
    rejects loops with a secret condition ("While with secret condition is
    not supported"), so under SPU pass ``eps=None`` to train for exactly
    ``epochs`` epochs instead.
    """

    def __init__(
        self,
        epochs: int,
        learning_rate: float,
        batch_size: int,
        reg_type: str,
        penalty: str,
        l2_norm: float,
        eps: "float | None" = None,
    ):
        """Validate and store the hyper-parameters.

        Args:
            epochs: maximum number of passes over the training data (> 0).
            learning_rate: SGD step size (> 0).
            batch_size: rows per mini-batch (> 0); capped at the sample count.
            reg_type: a ``RegType`` value, 'linear' or 'logistic'.
            penalty: a ``Penalty`` value; 'l1' is rejected for now.
            l2_norm: L2 strength; must be > 0 when ``penalty == 'l2'``.
            eps: when not None, stop once the L2 norm of one epoch's weight
                change is <= eps (plaintext only). None trains all epochs.
        """
        # parameter check.
        assert epochs > 0, "epochs should >0"
        assert learning_rate > 0, "learning_rate should >0"
        assert batch_size > 0, "batch_size should >0"
        assert penalty != 'l1', "not support L1 penalty for now"
        # `penalty` is a plain str, which never equals an Enum member, so
        # compare against the member's value.
        if penalty == Penalty.L2.value:
            assert l2_norm > 0, "l2_norm should >0 if use L2 penalty"
        assert reg_type in [
            e.value for e in RegType
        ], f"reg_type should in {[e.value for e in RegType]}, but got {reg_type}"
        assert penalty in [
            e.value for e in Penalty
        ], f"penalty should in {[e.value for e in Penalty]}, but got {penalty}"

        self._epochs = epochs
        self._learning_rate = learning_rate
        self._batch_size = batch_size
        self._l2_norm = l2_norm
        self._penalty = Penalty(penalty)
        # TODO: reg_type should not be here.
        self._reg_type = RegType(reg_type)

        # Placeholder until fit() stores the trained (num_feat + 1, 1) weights.
        self._weights = jnp.zeros(())
        self._eps = eps

    def _update_weights(
        self,
        x,  # array-like
        y,  # array-like
        w,  # array-like
        total_batch: int,
        batch_size: int,
    ) -> np.ndarray:
        """Run one epoch of mini-batch gradient descent and return new weights.

        Args:
            x: features, shape (n_samples, num_feat).
            y: targets, shape (n_samples, 1).
            w: weights, shape (num_feat + 1,) or (num_feat + 1, 1); the last
                entry is the bias.
            total_batch: number of full mini-batches to process.
            batch_size: rows per mini-batch.

        Returns:
            Updated weights of shape (num_feat + 1, 1).
        """
        assert x.shape[0] >= total_batch * batch_size, "total batch is too large"
        num_feat = x.shape[1]
        assert w.shape[0] == num_feat + 1, "w shape is mismatch to x"
        assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array"
        w = w.reshape((w.shape[0], 1))

        for idx in range(total_batch):
            begin = idx * batch_size
            end = begin + batch_size
            # Append a column of ones so the last weight acts as the bias.
            x_slice = jnp.concatenate(
                (x[begin:end, :], jnp.ones((batch_size, 1))), axis=1
            )
            y_slice = y[begin:end, :]

            pred = jnp.matmul(x_slice, w)
            if self._reg_type == RegType.Logistic:
                pred = sigmoid_sr(pred)

            err = pred - y_slice
            grad = jnp.matmul(jnp.transpose(x_slice), err)

            if self._penalty == Penalty.L2:
                # Penalise the feature weights only; the bias term is zeroed.
                w_with_zero_bias = jnp.concatenate(
                    (w[:-1], jnp.zeros((1, 1))), axis=0
                )
                grad = grad + w_with_zero_bias * self._l2_norm

            step = (self._learning_rate * grad) / batch_size
            w = w - step

        return w

    def fit(self, x, y):
        """Fit the model with mini-batch Stochastic Gradient Descent.

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Training data.

        y : array-like of shape (n_samples,) or (n_samples, 1)
            Target values.

        Returns
        -------
        self : object
            Returns an instance of self.
        """
        assert len(x.shape) == 2, f"expect x to be 2 dimension array, got {x.shape}"
        # Batches are sliced as y[begin:end, :], so work with a column vector.
        y = y.reshape((-1, 1))

        num_sample, num_feat = x.shape
        batch_size = min(self._batch_size, num_sample)
        # Trailing samples that do not fill a whole batch are not used.
        total_batch = num_sample // batch_size

        weights = jnp.zeros((num_feat + 1, 1))

        def run_epoch(w):
            return self._update_weights(x, y, w, total_batch, batch_size)

        if self._eps is None:
            # Fixed epoch count: no data-dependent control flow, so this path
            # can also be traced when x / y are secret under SPU.
            for _ in range(self._epochs):
                weights = run_epoch(weights)
        else:
            # Early stopping: the loop condition depends on the weights, which
            # SPU keeps secret, so this path only works in plaintext.
            def cond(state):
                keep_going, _, _, epoch_idx = state
                return jnp.logical_and(keep_going, epoch_idx < self._epochs)

            def body(state):
                _, w, _, epoch_idx = state
                new_w = run_epoch(w)
                diff = jnp.linalg.norm(new_w - w)
                return diff > self._eps, new_w, diff, epoch_idx + 1

            # Explicit dtypes so the carry types match body()'s outputs.
            init = (
                jnp.array(True),
                weights,
                jnp.array(1.0, dtype=weights.dtype),
                jnp.array(0, dtype=jnp.int32),
            )
            weights = jax.lax.while_loop(cond, body, init)[1]

        self._weights = weights
        return self

    def predict_proba(self, x):
        """Probability estimates (raw scores for linear regression).

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Input data for prediction.

        Returns
        -------
        ndarray of shape (n_samples, 1)
            ``sigmoid_sr(x @ w + b)`` for logistic regression, ``x @ w + b``
            for linear regression.
        """
        num_feat = x.shape[1]
        w = self._weights
        assert w.shape[0] == num_feat + 1, f"w shape is mismatch to x={x.shape}"
        assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array"
        # Normalise to a column vector so `w[-1, 0]` also works for 1-D weights.
        w = w.reshape((w.shape[0], 1))

        bias = w[-1, 0]
        pred = jnp.matmul(x, w[:-1]) + bias

        if self._reg_type == RegType.Logistic:
            pred = sigmoid_sr(pred)
        return pred

from sklearn.metrics import roc_auc_score

# Plaintext baseline: train and evaluate without SPU.
plain_model = SGDClassifier(
    epochs=100,
    learning_rate=0.1,
    batch_size=8,
    reg_type='logistic',
    penalty='l2',
    l2_norm=1.0,
    eps=0.01,
)
# X, y should be two-dimension array
predict_prob = plain_model.fit(X.values, y.values.reshape(-1, 1)).predict_proba(X.values)
print(plain_model._weights)
print(roc_auc_score(y.values, predict_prob))

import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
import spu

sim = spsim.Simulator.simple(
    2, spu_pb2.ProtocolKind.CHEETAH, spu_pb2.FieldType.FM64
)


def fit_and_predict(x, y):
    """Train an SGDClassifier on (x, y) and return its predictions for x.

    NOTE: ``eps`` is set, so ``fit`` takes its ``jax.lax.while_loop`` path,
    whose condition depends on secret weights under SPU; the simulator
    reports "While with secret condition is not supported".
    """
    hyper_params = dict(
        epochs=100,
        learning_rate=0.1,
        batch_size=8,
        reg_type='logistic',
        penalty='l2',
        l2_norm=1.0,
        eps=0.01,
    )
    return SGDClassifier(**hyper_params).fit(x, y).predict_proba(x)


# X, y should be two-dimension array
result = spsim.sim_jax(sim, fit_and_predict)(X.values, y.values.reshape(-1, 1))

报错信息 [libspu/kernel/hlo/control_flow.cc:133] While with secret condition is not supported 有几个问题想请教下 Q1: 是否只能通过明文转化实现condition判断? Q2: SPU中是否有类似sf.reveal的API接口? Q3: secret condition转明文处理是否各类情况下都只会泄露1 bit信息?

deadlywing commented 1 year ago

import numpy as np
import jax.numpy as jnp
import jax
from enum import Enum

from sklearn.datasets import load_breast_cancer
from sklearn.preprocessing import MinMaxScaler
import pandas as pd

X, y = load_breast_cancer(return_X_y=True, as_frame=True)
# Rescale every feature into [-2, 2] and keep the column names.
scalar = MinMaxScaler(feature_range=(-2, 2))
cols = X.columns
X = scalar.fit_transform(X)
X = pd.DataFrame(X, columns=cols)

def sigmoid_sr(x):
    """Square-root sigmoid approximation: 0.5 * x / sqrt(1 + x**2) + 0.5.

    Avoids ``exp``; uses only arithmetic and one square root. The result
    lies in (0, 1) and equals 0.5 at x == 0.
    """
    scaled = x / jnp.sqrt(1 + jnp.power(x, 2))
    return 0.5 * scaled + 0.5

class RegType(Enum):
    """Kind of linear model: plain linear regression or logistic regression."""

    Linear = 'linear'
    Logistic = 'logistic'

class Penalty(Enum):
    """Regularisation term that may be added to the SGD gradient."""

    NONE = 'None'
    L1 = 'l1'  # not supported
    L2 = 'l2'

class SGDClassifier:
    """Linear or logistic regression trained with mini-batch SGD (jax.numpy).

    The model is written with ``jax.numpy`` so it can run in plaintext or
    under SPU. Early stopping (``eps`` not None) is implemented with
    ``jax.lax.while_loop`` whose condition is computed from the weights; SPU
    rejects loops with a secret condition ("While with secret condition is
    not supported"), so under SPU pass ``eps=None`` to train for exactly
    ``epochs`` epochs instead.
    """

    def __init__(
        self,
        epochs: int,
        learning_rate: float,
        batch_size: int,
        reg_type: str,
        penalty: str,
        l2_norm: float,
        eps: "float | None" = None,
    ):
        """Validate and store the hyper-parameters.

        Args:
            epochs: maximum number of passes over the training data (> 0).
            learning_rate: SGD step size (> 0).
            batch_size: rows per mini-batch (> 0); capped at the sample count.
            reg_type: a ``RegType`` value, 'linear' or 'logistic'.
            penalty: a ``Penalty`` value; 'l1' is rejected for now.
            l2_norm: L2 strength; must be > 0 when ``penalty == 'l2'``.
            eps: when not None, stop once the L2 norm of one epoch's weight
                change is <= eps (plaintext only). None trains all epochs.
        """
        # parameter check.
        assert epochs > 0, "epochs should >0"
        assert learning_rate > 0, "learning_rate should >0"
        assert batch_size > 0, "batch_size should >0"
        assert penalty != 'l1', "not support L1 penalty for now"
        # `penalty` is a plain str, which never equals an Enum member, so
        # compare against the member's value.
        if penalty == Penalty.L2.value:
            assert l2_norm > 0, "l2_norm should >0 if use L2 penalty"
        assert reg_type in [
            e.value for e in RegType
        ], f"reg_type should in {[e.value for e in RegType]}, but got {reg_type}"
        assert penalty in [
            e.value for e in Penalty
        ], f"penalty should in {[e.value for e in Penalty]}, but got {penalty}"

        self._epochs = epochs
        self._learning_rate = learning_rate
        self._batch_size = batch_size
        self._l2_norm = l2_norm
        self._penalty = Penalty(penalty)
        # TODO: reg_type should not be here.
        self._reg_type = RegType(reg_type)

        # Placeholder until fit() stores the trained (num_feat + 1, 1) weights.
        self._weights = jnp.zeros(())
        self._eps = eps

    def _update_weights(
        self,
        x,  # array-like
        y,  # array-like
        w,  # array-like
        total_batch: int,
        batch_size: int,
    ) -> np.ndarray:
        """Run one epoch of mini-batch gradient descent and return new weights.

        Args:
            x: features, shape (n_samples, num_feat).
            y: targets, shape (n_samples, 1).
            w: weights, shape (num_feat + 1,) or (num_feat + 1, 1); the last
                entry is the bias.
            total_batch: number of full mini-batches to process.
            batch_size: rows per mini-batch.

        Returns:
            Updated weights of shape (num_feat + 1, 1).
        """
        assert x.shape[0] >= total_batch * batch_size, "total batch is too large"
        num_feat = x.shape[1]
        assert w.shape[0] == num_feat + 1, "w shape is mismatch to x"
        assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array"
        w = w.reshape((w.shape[0], 1))

        for idx in range(total_batch):
            begin = idx * batch_size
            end = begin + batch_size
            # Append a column of ones so the last weight acts as the bias.
            x_slice = jnp.concatenate(
                (x[begin:end, :], jnp.ones((batch_size, 1))), axis=1
            )
            y_slice = y[begin:end, :]

            pred = jnp.matmul(x_slice, w)
            if self._reg_type == RegType.Logistic:
                pred = sigmoid_sr(pred)

            err = pred - y_slice
            grad = jnp.matmul(jnp.transpose(x_slice), err)

            if self._penalty == Penalty.L2:
                # Penalise the feature weights only; the bias term is zeroed.
                w_with_zero_bias = jnp.concatenate(
                    (w[:-1], jnp.zeros((1, 1))), axis=0
                )
                grad = grad + w_with_zero_bias * self._l2_norm

            step = (self._learning_rate * grad) / batch_size
            w = w - step

        return w

    def fit(self, x, y):
        """Fit the model with mini-batch Stochastic Gradient Descent.

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Training data.

        y : array-like of shape (n_samples,) or (n_samples, 1)
            Target values.

        Returns
        -------
        self : object
            Returns an instance of self.
        """
        assert len(x.shape) == 2, f"expect x to be 2 dimension array, got {x.shape}"
        # Batches are sliced as y[begin:end, :], so work with a column vector.
        y = y.reshape((-1, 1))

        num_sample, num_feat = x.shape
        batch_size = min(self._batch_size, num_sample)
        # Trailing samples that do not fill a whole batch are not used.
        total_batch = num_sample // batch_size

        weights = jnp.zeros((num_feat + 1, 1))

        def run_epoch(w):
            return self._update_weights(x, y, w, total_batch, batch_size)

        if self._eps is None:
            # Fixed epoch count: no data-dependent control flow, so this path
            # can also be traced when x / y are secret under SPU.
            for _ in range(self._epochs):
                weights = run_epoch(weights)
        else:
            # Early stopping: the loop condition depends on the weights, which
            # SPU keeps secret, so this path only works in plaintext.
            def cond(state):
                keep_going, _, _, epoch_idx = state
                return jnp.logical_and(keep_going, epoch_idx < self._epochs)

            def body(state):
                _, w, _, epoch_idx = state
                new_w = run_epoch(w)
                diff = jnp.linalg.norm(new_w - w)
                return diff > self._eps, new_w, diff, epoch_idx + 1

            # Explicit dtypes so the carry types match body()'s outputs.
            init = (
                jnp.array(True),
                weights,
                jnp.array(1.0, dtype=weights.dtype),
                jnp.array(0, dtype=jnp.int32),
            )
            weights = jax.lax.while_loop(cond, body, init)[1]

        self._weights = weights
        return self

    def predict_proba(self, x):
        """Probability estimates (raw scores for linear regression).

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Input data for prediction.

        Returns
        -------
        ndarray of shape (n_samples, 1)
            ``sigmoid_sr(x @ w + b)`` for logistic regression, ``x @ w + b``
            for linear regression.
        """
        num_feat = x.shape[1]
        w = self._weights
        assert w.shape[0] == num_feat + 1, f"w shape is mismatch to x={x.shape}"
        assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array"
        # Normalise to a column vector so `w[-1, 0]` also works for 1-D weights.
        w = w.reshape((w.shape[0], 1))

        bias = w[-1, 0]
        pred = jnp.matmul(x, w[:-1]) + bias

        if self._reg_type == RegType.Logistic:
            pred = sigmoid_sr(pred)
        return pred

plain_model = SGDClassifier(
    epochs=100,
    learning_rate=0.1,
    batch_size=8,
    reg_type='logistic',
    penalty='l2',
    l2_norm=1.0,
    eps=0.01,
)

plain_model.fit(X.values, y.values.reshape(-1, 1))  # X, y should be two-dimension array
predict_prob = plain_model.predict_proba(X.values)
from sklearn.metrics import roc_auc_score
print(plain_model._weights)
print(roc_auc_score(y.values, predict_prob))

import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
import spu

sim = spsim.Simulator.simple(
    2, spu_pb2.ProtocolKind.CHEETAH, spu_pb2.FieldType.FM64
)


def fit_and_predict(x, y):
    """Train an SGDClassifier on (x, y) and return its predictions for x.

    NOTE: ``eps`` is set, so ``fit`` uses a ``jax.lax.while_loop`` whose
    condition is secret under SPU; the simulator rejects it.
    """
    model = SGDClassifier(
        epochs=100,
        learning_rate=0.1,
        batch_size=8,
        reg_type='logistic',
        penalty='l2',
        l2_norm=1.0,
        eps=0.01,
    )
    model.fit(x, y)
    return model.predict_proba(x)


result = spsim.sim_jax(sim, fit_and_predict)(X.values, y.values.reshape(-1, 1))  # X, y should be two-dimension array

报错信息 [libspu/kernel/hlo/control_flow.cc:133] While with secret condition is not supported 有几个问题想请教下 Q1: 是否只能通过明文转化实现condition判断? Q2: SPU中是否有类似sf.reveal的API接口? Q3: secret condition转明文处理是否各类情况下都只会泄露1 bit信息?

Q1:我理解是的,否则你无法知道是否需要继续

Q2: spu中是可以将密文转化为明文的,取决于你运行的方式; a.如果是spsim的方式,那其实函数的输出直接就是明文; b.如果是类似sf的分布式运行,那可以使用spu.utils.distributed里的get方法。 不过可能early stop确实有点不适合end to end验证,您可以先不实现这个功能吧~

Q3: 不是很理解各类情况是什么意思?

tarantula-leo commented 1 year ago

Q2:现在的early stop功能在spu中可以通过什么方式实现呢?看其中有这个要求 上述代码的实现方式是需要支持While with secret condition,其他的功能代码中都包含(https://github.com/secretflow/spu/blob/main/sml/linear_model/simple_sgd.py) Q3:使用条件判断时,明文状态下是bool值 只有1 bit,但在spu中看报错的信息,应该会Trace判断条件的生成过程,中间不会泄露更多的信息吗?

deadlywing commented 1 year ago

Q2:现在的early stop功能在spu中可以通过什么方式实现呢?看其中有这个要求 上述代码的实现方式是需要支持While with secret condition,其他的功能代码中都包含(https://github.com/secretflow/spu/blob/main/sml/linear_model/simple_sgd.py) Q3:使用条件判断时,明文状态下是bool值 只有1 bit,但在spu中看报错的信息,应该会Trace判断条件的生成过程,中间不会泄露更多的信息吗?

Q2: 如果要实现early stop就需要明密文混合编程,其实现的框架大概会变成:

for _ in range(max_iter):
    # Run one training round in SPU; its output is secret until revealed.
    flag = spu_device(model.fit)(x)  # run program in SPU
    # Reveal only the stop flag so plaintext Python can branch on it.
    flag = ppd.reveal(flag)
    # early stop
    if flag:
        break

但我们这次其实是希望end-to-end的实现,即

model.fit()  # do all training work (every epoch) inside one end-to-end fit call, with no intermediate reveal.

所以,您可以先不用实现early stop,就以一个指定的最大迭代次数即可。

Q3: 额,,还是不理解,spu并不支持while的条件为密文,所以也不会泄漏条件的明文值,至于您说的“条件的生成过程”,是指python报错的trace会指向jax.lax.bitwise_and(t[0],t[3]<self._epochs)么?

deadlywing commented 1 year ago

简单来说,只要您不手动reveal数据,理论上是不会有更多泄漏的。

Thanks

tarantula-leo commented 1 year ago

early stop功能如果不用实现的话,这份代码是不是已经完成了?

import numpy as np
import jax.numpy as jnp

from enum import Enum

from sklearn.datasets import load_breast_cancer
from sklearn.preprocessing import MinMaxScaler
import pandas as pd

X, y = load_breast_cancer(return_X_y=True, as_frame=True)
cols = X.columns
# Rescale every feature into [-2, 2]; keep the result as a DataFrame.
scalar = MinMaxScaler(feature_range=(-2, 2))
X = pd.DataFrame(scalar.fit_transform(X), columns=cols)

def sigmoid_sr(x):
    """Square-root sigmoid approximation: 0.5 * x / sqrt(1 + x**2) + 0.5.

    Avoids ``exp``; uses only arithmetic and one square root. The result
    lies in (0, 1) and equals 0.5 at x == 0.
    """
    scaled = x / jnp.sqrt(1 + jnp.power(x, 2))
    return 0.5 * scaled + 0.5

class RegType(Enum):
    """Kind of linear model: plain linear regression or logistic regression."""

    Linear = 'linear'
    Logistic = 'logistic'

class Penalty(Enum):
    """Regularisation term that may be added to the SGD gradient."""

    NONE = 'None'  # no regularisation
    L1 = 'l1'  # not supported
    L2 = 'l2'

class SGDClassifier:
    """Linear or logistic regression trained with mini-batch SGD (jax.numpy).

    Training always runs exactly ``epochs`` epochs with no data-dependent
    control flow, so ``fit`` can be traced end to end under SPU.
    """

    def __init__(
        self,
        epochs: int,
        learning_rate: float,
        batch_size: int,
        reg_type: str,
        penalty: str,
        l2_norm: float,
    ):
        """Validate and store the hyper-parameters.

        Args:
            epochs: number of passes over the training data (> 0).
            learning_rate: SGD step size (> 0).
            batch_size: rows per mini-batch (> 0); capped at the sample count.
            reg_type: a ``RegType`` value, 'linear' or 'logistic'.
            penalty: a ``Penalty`` value; 'l1' is rejected for now.
            l2_norm: L2 strength; must be > 0 when ``penalty == 'l2'``.
        """
        # parameter check.
        assert epochs > 0, "epochs should >0"
        assert learning_rate > 0, "learning_rate should >0"
        assert batch_size > 0, "batch_size should >0"
        assert penalty != 'l1', "not support L1 penalty for now"
        # `penalty` is a plain str, which never equals an Enum member, so
        # compare against the member's value.
        if penalty == Penalty.L2.value:
            assert l2_norm > 0, "l2_norm should >0 if use L2 penalty"
        assert reg_type in [
            e.value for e in RegType
        ], f"reg_type should in {[e.value for e in RegType]}, but got {reg_type}"
        assert penalty in [
            e.value for e in Penalty
        ], f"penalty should in {[e.value for e in Penalty]}, but got {penalty}"

        self._epochs = epochs
        self._learning_rate = learning_rate
        self._batch_size = batch_size
        self._l2_norm = l2_norm
        self._penalty = Penalty(penalty)
        # TODO: reg_type should not be here.
        self._reg_type = RegType(reg_type)

        # Placeholder until fit() stores the trained (num_feat + 1, 1) weights.
        self._weights = jnp.zeros(())

    def _update_weights(
        self,
        x,  # array-like
        y,  # array-like
        w,  # array-like
        total_batch: int,
        batch_size: int,
    ) -> np.ndarray:
        """Run one epoch of mini-batch gradient descent and return new weights.

        Args:
            x: features, shape (n_samples, num_feat).
            y: targets, shape (n_samples, 1).
            w: weights, shape (num_feat + 1,) or (num_feat + 1, 1); the last
                entry is the bias.
            total_batch: number of full mini-batches to process.
            batch_size: rows per mini-batch.

        Returns:
            Updated weights of shape (num_feat + 1, 1).
        """
        assert x.shape[0] >= total_batch * batch_size, "total batch is too large"
        num_feat = x.shape[1]
        assert w.shape[0] == num_feat + 1, "w shape is mismatch to x"
        assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array"
        w = w.reshape((w.shape[0], 1))

        for idx in range(total_batch):
            begin = idx * batch_size
            end = begin + batch_size
            # Append a column of ones so the last weight acts as the bias.
            x_slice = jnp.concatenate(
                (x[begin:end, :], jnp.ones((batch_size, 1))), axis=1
            )
            y_slice = y[begin:end, :]

            pred = jnp.matmul(x_slice, w)
            if self._reg_type == RegType.Logistic:
                pred = sigmoid_sr(pred)

            err = pred - y_slice
            grad = jnp.matmul(jnp.transpose(x_slice), err)

            if self._penalty == Penalty.L2:
                # Penalise the feature weights only; the bias term is zeroed.
                w_with_zero_bias = jnp.concatenate(
                    (w[:-1], jnp.zeros((1, 1))), axis=0
                )
                grad = grad + w_with_zero_bias * self._l2_norm

            step = (self._learning_rate * grad) / batch_size
            w = w - step

        return w

    def fit(self, x, y):
        """Fit the model with mini-batch Stochastic Gradient Descent.

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Training data.

        y : array-like of shape (n_samples,) or (n_samples, 1)
            Target values.

        Returns
        -------
        self : object
            Returns an instance of self.
        """
        assert len(x.shape) == 2, f"expect x to be 2 dimension array, got {x.shape}"
        # Batches are sliced as y[begin:end, :], so work with a column vector.
        y = y.reshape((-1, 1))

        num_sample, num_feat = x.shape
        batch_size = min(self._batch_size, num_sample)
        # Trailing samples that do not fill a whole batch are not used.
        total_batch = num_sample // batch_size

        weights = jnp.zeros((num_feat + 1, 1))

        # do train: a fixed number of epochs, no data-dependent control flow.
        for _ in range(self._epochs):
            weights = self._update_weights(x, y, weights, total_batch, batch_size)

        self._weights = weights
        return self

    def predict_proba(self, x):
        """Probability estimates (raw scores for linear regression).

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Input data for prediction.

        Returns
        -------
        ndarray of shape (n_samples, 1)
            ``sigmoid_sr(x @ w + b)`` for logistic regression, ``x @ w + b``
            for linear regression.
        """
        num_feat = x.shape[1]
        w = self._weights
        assert w.shape[0] == num_feat + 1, f"w shape is mismatch to x={x.shape}"
        assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array"
        # Normalise to a column vector so `w[-1, 0]` also works for 1-D weights.
        w = w.reshape((w.shape[0], 1))

        bias = w[-1, 0]
        pred = jnp.matmul(x, w[:-1]) + bias

        if self._reg_type == RegType.Logistic:
            pred = sigmoid_sr(pred)
        return pred

import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
import spu

sim = spsim.Simulator.simple(
    2, spu_pb2.ProtocolKind.CHEETAH, spu_pb2.FieldType.FM64
)


def fit_and_predict(x, y):
    """Train an SGDClassifier on (x, y) and return its predictions for x.

    Executed under the SPU simulator via ``spsim.sim_jax`` below.
    """
    hyper_params = dict(
        epochs=3,
        learning_rate=0.1,
        batch_size=8,
        reg_type='logistic',
        penalty='l2',
        l2_norm=1.0,
    )
    return SGDClassifier(**hyper_params).fit(x, y).predict_proba(x)


from sklearn.metrics import roc_auc_score

result = spsim.sim_jax(sim, fit_and_predict)(X.values, y.values.reshape(-1, 1))
print(roc_auc_score(y.values, result))

Output: 0.9919665979599387

deadlywing commented 1 year ago

early stop功能如果不用实现的话,这份代码是不是已经完成了?

import numpy as np
import jax.numpy as jnp

from enum import Enum

from sklearn.datasets import load_breast_cancer
from sklearn.preprocessing import MinMaxScaler
import pandas as pd

X, y = load_breast_cancer(return_X_y=True, as_frame=True)
cols = X.columns
# Rescale every feature into [-2, 2]; keep the result as a DataFrame.
scalar = MinMaxScaler(feature_range=(-2, 2))
X = pd.DataFrame(scalar.fit_transform(X), columns=cols)

def sigmoid_sr(x):
    """Square-root sigmoid approximation: 0.5 * x / sqrt(1 + x**2) + 0.5.

    Avoids ``exp``; uses only arithmetic and one square root. The result
    lies in (0, 1) and equals 0.5 at x == 0.
    """
    scaled = x / jnp.sqrt(1 + jnp.power(x, 2))
    return 0.5 * scaled + 0.5

class RegType(Enum):
    """Kind of linear model: plain linear regression or logistic regression."""

    Linear = 'linear'
    Logistic = 'logistic'

class Penalty(Enum):
    """Regularisation term that may be added to the SGD gradient."""

    NONE = 'None'  # no regularisation
    L1 = 'l1'  # not supported
    L2 = 'l2'

class SGDClassifier:
    """Linear or logistic regression trained with mini-batch SGD (jax.numpy).

    Training always runs exactly ``epochs`` epochs with no data-dependent
    control flow, so ``fit`` can be traced end to end under SPU.
    """

    def __init__(
        self,
        epochs: int,
        learning_rate: float,
        batch_size: int,
        reg_type: str,
        penalty: str,
        l2_norm: float,
    ):
        """Validate and store the hyper-parameters.

        Args:
            epochs: number of passes over the training data (> 0).
            learning_rate: SGD step size (> 0).
            batch_size: rows per mini-batch (> 0); capped at the sample count.
            reg_type: a ``RegType`` value, 'linear' or 'logistic'.
            penalty: a ``Penalty`` value; 'l1' is rejected for now.
            l2_norm: L2 strength; must be > 0 when ``penalty == 'l2'``.
        """
        # parameter check.
        assert epochs > 0, "epochs should >0"
        assert learning_rate > 0, "learning_rate should >0"
        assert batch_size > 0, "batch_size should >0"
        assert penalty != 'l1', "not support L1 penalty for now"
        # `penalty` is a plain str, which never equals an Enum member, so
        # compare against the member's value.
        if penalty == Penalty.L2.value:
            assert l2_norm > 0, "l2_norm should >0 if use L2 penalty"
        assert reg_type in [
            e.value for e in RegType
        ], f"reg_type should in {[e.value for e in RegType]}, but got {reg_type}"
        assert penalty in [
            e.value for e in Penalty
        ], f"penalty should in {[e.value for e in Penalty]}, but got {penalty}"

        self._epochs = epochs
        self._learning_rate = learning_rate
        self._batch_size = batch_size
        self._l2_norm = l2_norm
        self._penalty = Penalty(penalty)
        # TODO: reg_type should not be here.
        self._reg_type = RegType(reg_type)

        # Placeholder until fit() stores the trained (num_feat + 1, 1) weights.
        self._weights = jnp.zeros(())

    def _update_weights(
        self,
        x,  # array-like
        y,  # array-like
        w,  # array-like
        total_batch: int,
        batch_size: int,
    ) -> np.ndarray:
        """Run one epoch of mini-batch gradient descent and return new weights.

        Args:
            x: features, shape (n_samples, num_feat).
            y: targets, shape (n_samples, 1).
            w: weights, shape (num_feat + 1,) or (num_feat + 1, 1); the last
                entry is the bias.
            total_batch: number of full mini-batches to process.
            batch_size: rows per mini-batch.

        Returns:
            Updated weights of shape (num_feat + 1, 1).
        """
        assert x.shape[0] >= total_batch * batch_size, "total batch is too large"
        num_feat = x.shape[1]
        assert w.shape[0] == num_feat + 1, "w shape is mismatch to x"
        assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array"
        w = w.reshape((w.shape[0], 1))

        for idx in range(total_batch):
            begin = idx * batch_size
            end = begin + batch_size
            # Append a column of ones so the last weight acts as the bias.
            x_slice = jnp.concatenate(
                (x[begin:end, :], jnp.ones((batch_size, 1))), axis=1
            )
            y_slice = y[begin:end, :]

            pred = jnp.matmul(x_slice, w)
            if self._reg_type == RegType.Logistic:
                pred = sigmoid_sr(pred)

            err = pred - y_slice
            grad = jnp.matmul(jnp.transpose(x_slice), err)

            if self._penalty == Penalty.L2:
                # Penalise the feature weights only; the bias term is zeroed.
                w_with_zero_bias = jnp.concatenate(
                    (w[:-1], jnp.zeros((1, 1))), axis=0
                )
                grad = grad + w_with_zero_bias * self._l2_norm

            step = (self._learning_rate * grad) / batch_size
            w = w - step

        return w

    def fit(self, x, y):
        """Fit the model with mini-batch Stochastic Gradient Descent.

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Training data.

        y : array-like of shape (n_samples,) or (n_samples, 1)
            Target values.

        Returns
        -------
        self : object
            Returns an instance of self.
        """
        assert len(x.shape) == 2, f"expect x to be 2 dimension array, got {x.shape}"
        # Batches are sliced as y[begin:end, :], so work with a column vector.
        y = y.reshape((-1, 1))

        num_sample, num_feat = x.shape
        batch_size = min(self._batch_size, num_sample)
        # Trailing samples that do not fill a whole batch are not used.
        total_batch = num_sample // batch_size

        weights = jnp.zeros((num_feat + 1, 1))

        # do train: a fixed number of epochs, no data-dependent control flow.
        for _ in range(self._epochs):
            weights = self._update_weights(x, y, weights, total_batch, batch_size)

        self._weights = weights
        return self

    def predict_proba(self, x):
        """Probability estimates (raw scores for linear regression).

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Input data for prediction.

        Returns
        -------
        ndarray of shape (n_samples, 1)
            ``sigmoid_sr(x @ w + b)`` for logistic regression, ``x @ w + b``
            for linear regression.
        """
        num_feat = x.shape[1]
        w = self._weights
        assert w.shape[0] == num_feat + 1, f"w shape is mismatch to x={x.shape}"
        assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array"
        # Normalise to a column vector so `w[-1, 0]` also works for 1-D weights.
        w = w.reshape((w.shape[0], 1))

        bias = w[-1, 0]
        pred = jnp.matmul(x, w[:-1]) + bias

        if self._reg_type == RegType.Logistic:
            pred = sigmoid_sr(pred)
        return pred

import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
import spu

sim = spsim.Simulator.simple(
    2, spu_pb2.ProtocolKind.CHEETAH, spu_pb2.FieldType.FM64
)


def fit_and_predict(x, y):
    """Train an SGDClassifier on (x, y) and return its predictions for x.

    Executed under the SPU simulator via ``spsim.sim_jax`` below.
    """
    hyper_params = dict(
        epochs=3,
        learning_rate=0.1,
        batch_size=8,
        reg_type='logistic',
        penalty='l2',
        l2_norm=1.0,
    )
    return SGDClassifier(**hyper_params).fit(x, y).predict_proba(x)


from sklearn.metrics import roc_auc_score

result = spsim.sim_jax(sim, fit_and_predict)(X.values, y.values.reshape(-1, 1))
print(roc_auc_score(y.values, result))

Output: 0.9919665979599387

那您可以先发pr,我这边review然后运行一下。 需要提交的文件和位置可以参考 https://github.com/secretflow/spu/pull/240

tarantula-leo commented 1 year ago

https://github.com/secretflow/spu/blob/main/sml/linear_model/simple_sgd.py 主体内容是这里的代码,和任务要求的区别应该只差 early stop 功能

deadlywing commented 1 year ago

https://github.com/secretflow/spu/blob/main/sml/linear_model/simple_sgd.py 主体内容是这里的代码,和任务要求的区别应该只差 early stop 功能

嗯嗯,,功能上那份代码是满足的,,但是那份代码主要是我们实现了用作实例的,api设计比较随意。

我们希望api能尽量接近sklearn的lr分类器(这个算法里也不需要线性回归模型的支持,后续我们会有专门的线性回归模型任务)。您可以参考sklearn的LR稍微重构一下api,并预留一些参数。 我们后续预计会增加

  1. 多分类支持
  2. L1,elastic正则支持
  3. Class weight支持 等等
tarantula-leo commented 1 year ago

类似这样么?

class Penalty(Enum):
    """Regularisation options, named after sklearn's ``penalty`` values."""

    NONE = 'None'  # no regularisation
    L1 = 'l1'  # not supported
    L2 = 'l2'
    Elastic = 'elasticnet'  # not supported
deadlywing commented 1 year ago

嗯嗯,像这种能枚举的可以这样。 其余的比如class weight这些就留个口子,确保默认是走的2分类分支就行

Candicepan commented 1 year ago

类似这样么?

class Penalty(Enum):
    """Regularisation options, named after sklearn's ``penalty`` values."""

    NONE = 'None'  # no regularisation
    L1 = 'l1'  # not supported
    L2 = 'l2'
    Elastic = 'elasticnet'  # not supported

赞~辛苦您整理完相关内容之后 以 PR 的形式提交哈~ 期待您的 PR

tarantula-leo commented 1 year ago
import numpy as np
import jax.numpy as jnp

from enum import Enum

def sigmoid_sr(x):
    """Square-root sigmoid approximation: 0.5 * x / sqrt(1 + x**2) + 0.5.

    Avoids ``exp``; uses only arithmetic and one square root. The result
    lies in (0, 1) and equals 0.5 at x == 0.
    """
    scaled = x / jnp.sqrt(1 + jnp.power(x, 2))
    return 0.5 * scaled + 0.5

class Penalty(Enum):
    """Regularisation options, named after sklearn's ``penalty`` values."""

    NONE = 'None'  # no regularisation
    L1 = 'l1'  # not supported
    L2 = 'l2'
    Elastic = 'elasticnet'  # not supported

class Multi_class(Enum):
    """Multi-class strategy, named after sklearn's ``multi_class`` option."""

    Ovr = 'ovr'  # one-vs-rest; the only value accepted for now (binary)
    Multy = 'multinomial'  # not supported

class SGDClassifier:
    """Binary logistic-regression classifier trained with mini-batch SGD.

    The constructor loosely follows sklearn's ``LogisticRegression``. For now
    only ``penalty='l2'``, ``class_weight=None`` and ``multi_class='ovr'``
    are accepted; the other options are reserved placeholders. Training runs
    a fixed number of epochs (no early stopping).
    """

    def __init__(
        self,
        epochs: int,
        learning_rate: float,
        batch_size: int,
        penalty: str,
        l2_norm: float,
        class_weight=None,
        multi_class: str = 'ovr',
    ):
        """Validate and store the hyper-parameters.

        Args:
            epochs: number of passes over the training data (> 0).
            learning_rate: SGD step size (> 0).
            batch_size: rows per mini-batch (> 0); capped at the sample count.
            penalty: regularisation; only 'l2' is supported for now.
            l2_norm: L2 regularisation strength (> 0).
            class_weight: reserved; only None is supported for now.
            multi_class: only 'ovr' (binary classification) for now.
        """
        # parameter check.
        assert epochs > 0, "epochs should >0"
        assert learning_rate > 0, "learning_rate should >0"
        assert batch_size > 0, "batch_size should >0"
        assert penalty == 'l2', "only support L2 penalty for now"
        # `penalty` is a plain str, which never equals an Enum member, so
        # compare against the member's value.
        if penalty == Penalty.L2.value:
            assert l2_norm > 0, "l2_norm should >0 if use L2 penalty"
        assert penalty in [
            e.value for e in Penalty
        ], f"penalty should in {[e.value for e in Penalty]}, but got {penalty}"
        assert class_weight is None, "not support class_weight for now"
        assert multi_class == 'ovr', "only support binary problem for now"

        self._epochs = epochs
        self._learning_rate = learning_rate
        self._batch_size = batch_size
        self._l2_norm = l2_norm
        self._penalty = Penalty(penalty)
        self._class_weight = class_weight
        self._multi_class = Multi_class(multi_class)

        # Placeholder until fit() stores the trained (num_feat + 1, 1) weights.
        self._weights = jnp.zeros(())

    def _update_weights(
        self,
        x,  # array-like
        y,  # array-like
        w,  # array-like
        total_batch: int,
        batch_size: int,
    ) -> np.ndarray:
        """Run one epoch of mini-batch gradient descent and return new weights.

        Args:
            x: features, shape (n_samples, num_feat).
            y: 0/1 targets, shape (n_samples, 1).
            w: weights, shape (num_feat + 1,) or (num_feat + 1, 1); the last
                entry is the bias.
            total_batch: number of full mini-batches to process.
            batch_size: rows per mini-batch.

        Returns:
            Updated weights of shape (num_feat + 1, 1).
        """
        assert x.shape[0] >= total_batch * batch_size, "total batch is too large"
        num_feat = x.shape[1]
        assert w.shape[0] == num_feat + 1, "w shape is mismatch to x"
        assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array"
        w = w.reshape((w.shape[0], 1))

        for idx in range(total_batch):
            begin = idx * batch_size
            end = begin + batch_size
            # Append a column of ones so the last weight acts as the bias.
            x_slice = jnp.concatenate(
                (x[begin:end, :], jnp.ones((batch_size, 1))), axis=1
            )
            y_slice = y[begin:end, :]

            pred = sigmoid_sr(jnp.matmul(x_slice, w))
            err = pred - y_slice
            grad = jnp.matmul(jnp.transpose(x_slice), err)

            if self._penalty == Penalty.L2:
                # Penalise the feature weights only; the bias term is zeroed.
                w_with_zero_bias = jnp.concatenate(
                    (w[:-1], jnp.zeros((1, 1))), axis=0
                )
                grad = grad + w_with_zero_bias * self._l2_norm
            # Penalty.L1 / Penalty.Elastic: reserved (rejected in __init__).

            step = (self._learning_rate * grad) / batch_size
            w = w - step

        return w

    def fit(self, x, y):
        """Fit the model with mini-batch Stochastic Gradient Descent.

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Training data.

        y : array-like of shape (n_samples,) or (n_samples, 1)
            Target values.

        Returns
        -------
        self : object
            Returns an instance of self.
        """
        assert len(x.shape) == 2, f"expect x to be 2 dimension array, got {x.shape}"
        # Batches are sliced as y[begin:end, :], so work with a column vector.
        y = y.reshape((-1, 1))

        num_sample, num_feat = x.shape
        batch_size = min(self._batch_size, num_sample)
        # Trailing samples that do not fill a whole batch are not used.
        total_batch = num_sample // batch_size
        weights = jnp.zeros((num_feat + 1, 1))

        # class_weight (a dict or 'balanced') is reserved: __init__ only
        # accepts None for now, so no sample reweighting happens here.

        # do train: a fixed number of epochs, no data-dependent control flow.
        for _ in range(self._epochs):
            weights = self._update_weights(x, y, weights, total_batch, batch_size)

        self._weights = weights
        return self

    def predict_proba(self, x):
        """Probability estimates for the binary problem.

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Input data for prediction.

        Returns
        -------
        ndarray of shape (n_samples, 1)
            ``sigmoid_sr(x @ w + b)``, the model's estimate for label 1.

        Raises
        ------
        NotImplementedError
            If the multinomial strategy is selected (not supported yet).
        """
        if self._multi_class == Multi_class.Multy:
            # Unreachable via __init__ today; fail loudly instead of
            # silently returning None.
            raise NotImplementedError("not support multi_class problem for now")

        num_feat = x.shape[1]
        w = self._weights
        assert w.shape[0] == num_feat + 1, f"w shape is mismatch to x={x.shape}"
        assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array"
        # Normalise to a column vector so `w[-1, 0]` also works for 1-D weights.
        w = w.reshape((w.shape[0], 1))

        bias = w[-1, 0]
        pred = jnp.matmul(x, w[:-1]) + bias
        return sigmoid_sr(pred)

# Test
import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.preprocessing import MinMaxScaler

import spu
import spu.spu_pb2 as spu_pb2
import spu.utils.simulation as spsim

# Breast-cancer features as a DataFrame, min-max rescaled into [-2, 2] while
# keeping the original column names.
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
cols = X.columns
scalar = MinMaxScaler(feature_range=(-2, 2))
X = pd.DataFrame(scalar.fit_transform(X), columns=cols)

# SPU simulator with world size 2, the CHEETAH protocol and the FM64 field.
sim = spsim.Simulator.simple(
    2,
    spu_pb2.ProtocolKind.CHEETAH,
    spu_pb2.FieldType.FM64,
)

def fit_and_predict(x, y):
    """Train an ``SGDClassifier`` on ``(x, y)`` and score the same ``x``.

    Returns the output of ``predict_proba(x)`` for the freshly fitted model.
    """
    hyper_params = dict(
        epochs=3,
        learning_rate=0.1,
        batch_size=8,
        penalty='l2',
        l2_norm=1.0,
        class_weight=None,
        multi_class='ovr',
    )
    # fit() returns the model itself, so the calls can be chained.
    return SGDClassifier(**hyper_params).fit(x, y).predict_proba(x)

from sklearn.metrics import roc_auc_score

# X and y must both be two-dimensional: y is reshaped into a single column.
result = spsim.sim_jax(sim, fit_and_predict)(
    X.values,
    y.values.reshape(-1, 1),
)
print(result)
print(roc_auc_score(y.values, result))
deadlywing commented 1 year ago

import numpy as np
import jax.numpy as jnp

from enum import Enum


def sigmoid_sr(x):
    """Square-root based approximation of the logistic sigmoid.

    Computes ``0.5 * x / sqrt(1 + x**2) + 0.5``. This maps the real line into
    (0, 1) and equals 0.5 at ``x = 0`` without needing ``exp``. Presumably it
    was chosen because it is cheaper under MPC (TODO confirm).
    """
    return 0.5 * (x / jnp.sqrt(1 + jnp.power(x, 2))) + 0.5

class Penalty(Enum):
    """Regularization penalties accepted by ``SGDClassifier``."""

    NONE = 'None'
    L1 = 'l1'  # not supported
    L2 = 'l2'
    Elastic = 'elasticnet'  # not supported

class MultiClass(Enum):
    """Strategy for targets with more than two classes (scikit-learn terms).

    - ``ovr`` ("one-vs-rest"): one binary classifier per class, each separating
      that class from all the others. For a binary target this reduces to a
      single logistic regression, which is the only case implemented so far.
    - ``multinomial``: one joint model over all classes trained with a softmax
      (cross-entropy) loss. Not supported yet.
    """

    Ovr = 'ovr'
    Multy = 'multinomial'  # not supported


# Backward-compatible alias for the original (non-PEP 8) class name.
Multi_class = MultiClass

class SGDClassifier:
    """Binary logistic regression trained with mini-batch gradient descent.

    Written with ``jax.numpy`` so that fit/predict can be run under SPU
    (e.g. via ``spsim.sim_jax``). Only the L2 penalty and the binary
    (``'ovr'``) setting are implemented for now.
    """

    def __init__(
        self,
        epochs: int,
        learning_rate: float,
        batch_size: int,
        penalty: str,
        l2_norm: float,
        class_weight: None,
        multi_class: str,
    ):
        """Validate and store hyper-parameters.

        Parameters
        ----------
        epochs : int
            Number of passes over the training data (> 0).
        learning_rate : float
            Step size applied to the batch-averaged gradient (> 0).
        batch_size : int
            Mini-batch size (> 0); clipped to the sample count in ``fit``.
        penalty : str
            Regularization type; only ``'l2'`` is supported for now.
        l2_norm : float
            L2 regularization strength (must be > 0 with ``penalty='l2'``).
        class_weight : None
            Not supported yet; must be ``None``.
        multi_class : str
            Only ``'ovr'`` (binary problems) is supported for now.
        """
        # parameter check.
        assert epochs > 0, "epochs should >0"
        assert learning_rate > 0, "learning_rate should >0"
        assert batch_size > 0, "batch_size should >0"
        assert penalty == 'l2', "only support L2 penalty for now"
        # `penalty` is a plain string, so compare it with the enum *value*:
        # comparing a str with the enum member itself is always False.
        if penalty == Penalty.L2.value:
            assert l2_norm > 0, "l2_norm should >0 if use L2 penalty"
        assert penalty in [
            e.value for e in Penalty
        ], f"penalty should in {[e.value for e in Penalty]}, but got {penalty}"
        assert class_weight is None, "not support class_weight for now"
        assert multi_class == 'ovr', "only support binary problem for now"

        self._epochs = epochs
        self._learning_rate = learning_rate
        self._batch_size = batch_size
        self._l2_norm = l2_norm
        self._penalty = Penalty(penalty)
        self._class_weight = class_weight
        self._multi_class = Multi_class(multi_class)

        # Placeholder; ``fit`` stores the trained (num_feat + 1, 1) weights.
        self._weights = jnp.zeros(())

    def _update_weights(
        self,
        x,  # array-like
        y,  # array-like
        w,  # array-like
        total_batch: int,
        batch_size: int,
    ) -> np.ndarray:
        """Run one epoch of mini-batch gradient descent and return new weights.

        ``w`` holds ``num_feat`` coefficients followed by one bias entry. Only
        the first ``total_batch * batch_size`` rows of ``x``/``y`` are visited.
        """
        assert x.shape[0] >= total_batch * batch_size, "total batch is too large"
        num_feat = x.shape[1]
        assert w.shape[0] == num_feat + 1, "w shape is mismatch to x"
        assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array"
        w = w.reshape((w.shape[0], 1))

        for idx in range(total_batch):
            begin = idx * batch_size
            end = begin + batch_size
            # Append a column of ones so the last entry of w acts as the bias.
            x_slice = jnp.concatenate(
                (x[begin:end, :], jnp.ones((batch_size, 1))), axis=1
            )
            y_slice = y[begin:end, :]

            pred = sigmoid_sr(jnp.matmul(x_slice, w))

            # Logistic-loss style gradient X^T (pred - y), summed over the batch.
            err = pred - y_slice
            grad = jnp.matmul(jnp.transpose(x_slice), err)

            if self._penalty == Penalty.L2:
                # Regularize the coefficients only; the bias row gets zero.
                w_with_zero_bias = jnp.concatenate(
                    (w[:num_feat, :], jnp.zeros((1, 1))),
                    axis=0,
                )
                grad = grad + w_with_zero_bias * self._l2_norm
            # L1 / elasticnet are rejected in __init__, so no other branch.

            # Dividing by batch_size turns the summed gradient into a mean.
            step = (self._learning_rate * grad) / batch_size
            w = w - step

        return w

    def fit(self, x, y):
        """Fit the model with mini-batch Stochastic Gradient Descent.

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Training data.
        y : array-like of shape (n_samples, 1) or (n_samples,)
            Binary target values; a 1-D vector is reshaped into a column.

        Returns
        -------
        self : object
            Returns an instance of self.
        """
        assert len(x.shape) == 2, f"expect x to be 2 dimension array, got {x.shape}"
        # The update step slices y as y[begin:end, :], so it must be a column.
        if len(y.shape) == 1:
            y = y.reshape((-1, 1))

        num_sample, num_feat = x.shape
        batch_size = min(self._batch_size, num_sample)
        # Trailing samples that do not fill a whole batch are not visited.
        total_batch = num_sample // batch_size
        weights = jnp.zeros((num_feat + 1, 1))

        # class_weight is not supported yet (__init__ requires it to be None).

        for _ in range(self._epochs):
            weights = self._update_weights(x, y, weights, total_batch, batch_size)

        self._weights = weights
        return self

    def predict_proba(self, x):
        """Probability estimates for the positive class.

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Input data for prediction.

        Returns
        -------
        ndarray of shape (n_samples, 1)
            ``sigmoid_sr`` of the linear score ``x @ coef + bias``.

        Raises
        ------
        NotImplementedError
            If the model was configured for multinomial classification.
        """
        if self._multi_class == Multi_class.Ovr:
            num_feat = x.shape[1]
            w = self._weights
            assert w.shape[0] == num_feat + 1, f"w shape is mismatch to x={x.shape}"
            assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array"
            # reshape returns a new array: keep the result.
            w = w.reshape((w.shape[0], 1))

            # The last row is the bias; the first num_feat rows are coefficients.
            bias = w[-1, 0]
            coef = w[:num_feat, :]
            pred = jnp.matmul(x, coef) + bias
            return sigmoid_sr(pred)
        if self._multi_class == Multi_class.Multy:
            # Multinomial classification is not implemented yet.
            raise NotImplementedError(
                "multinomial classification is not supported yet"
            )

# Test
from sklearn.datasets import load_breast_cancer
from sklearn.preprocessing import MinMaxScaler
import pandas as pd

# Breast-cancer features rescaled into [-2, 2], keeping the column names.
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
scalar = MinMaxScaler(feature_range=(-2, 2))
cols = X.columns
X = scalar.fit_transform(X)
X = pd.DataFrame(X, columns=cols)

import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
import spu

# SPU simulator with world size 2, the CHEETAH protocol and the FM64 field.
sim = spsim.Simulator.simple(
    2, spu_pb2.ProtocolKind.CHEETAH, spu_pb2.FieldType.FM64
)

def fit_and_predict(x, y):
    """Train an ``SGDClassifier`` on ``(x, y)`` and return ``predict_proba(x)``."""
    model = SGDClassifier(
        epochs=3,
        learning_rate=0.1,
        batch_size=8,
        penalty='l2',
        l2_norm=1.0,
        class_weight=None,
        multi_class='ovr',
    )
    model.fit(x, y)
    return model.predict_proba(x)

# X, y should be two-dimension arrays (y is reshaped into a single column).
result = spsim.sim_jax(sim, fit_and_predict)(X.values, y.values.reshape(-1, 1))
print(result)

from sklearn.metrics import roc_auc_score

print(roc_auc_score(y.values, result))

感谢您的提交,正确性是没有问题的。但有几个点需要麻烦您再修改一下:

  1. 我注意到您sigmoid是固定使用sr近似,这点我觉得不是很合理,您可以参考一下sf的设置,提供可选的sigmoid近似方法的超参数(https://github.com/secretflow/secretflow/blob/main/secretflow/utils/sigmoid.py)
  2. 感谢预留了Multi_class作为多分类的入口;辛苦您增加一些注释,解释一下ovr和multinomial的含义,利于后面的使用者。另外,Multi_class的命名似乎有点奇怪(一般有下划线的话就都是小写?或者就干脆驼峰命名?),您可以考虑修改为MultiClass
  3. 感谢您提供了simulation的测试样例,但还缺少了emulation的测试文件

另外,下次您可以正式发起PR了,并将上述代码进行拆分,需要提交的文件和位置可以参考 #240

tarantula-leo commented 1 year ago

PR的文件路径是在原liner_model下更改还是需要新建路径?

deadlywing commented 1 year ago

PR的文件路径是在原liner_model下更改还是需要新建路径?

新建一个路径吧,以后LR相关的内容都会更新在这里,原liner_model还是作为一个样例文件夹

感谢