secretflow / spu

SPU (Secure Processing Unit) aims to be a provable, measurable secure computation device, which provides computation ability while keeping your private data protected.
https://www.secretflow.org.cn/docs/spu/en/
Apache License 2.0
235 stars 101 forks source link

[Bug]: RuntimeError reported when I use SPU for NN training #869

Closed TY-cc closed 3 weeks ago

TY-cc commented 3 weeks ago

Issue Type

Support

Modules Involved

SPU runtime

Have you reproduced the bug with SPU HEAD?

Yes

Have you searched existing issues?

Yes

SPU Version

0.9.3

OS Platform and Distribution

linux ubuntu 22.04

Python Version

Python 3.10.4

Compiler Version

GCC 11.4

Current Behavior?

I am using SPU for neural network training. The first time I executed the code, it ran normally. The second time, I forcibly interrupted the execution before the training completed. The third time I ran it, it reported a RuntimeError. The code is as follows:

Standalone code to reproduce the issue

#https://secretflow.readthedocs.io/zh-cn/stable/tutorial/nn_with_spu.html code is in this url

################
################
#load the data
from typing import Optional, Tuple

import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import Normalizer

def breast_cancer(
    party_id=None, train: bool = True
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Load the sklearn breast-cancer dataset, optionally split vertically by party.

    Features are min-max scaled with the global min/max of the whole feature
    matrix (not per column), then split 80/20 into train/test with a fixed seed.

    Args:
        party_id: Only consulted when ``train`` is True. ``1`` selects the first
            15 feature columns and no labels. Any other truthy value selects the
            remaining columns together with the labels. A falsy value (``None``
            or ``0``) returns all columns and the labels.
        train: If True return the training split, otherwise the full test split.

    Returns:
        A ``(features, labels)`` pair; ``labels`` is ``None`` for party 1.
    """
    x, y = load_breast_cancer(return_X_y=True)
    # Global (not per-feature) min-max scaling into [0, 1].
    x = (x - np.min(x)) / (np.max(x) - np.min(x))
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, random_state=42
    )

    if not train:
        return x_test, y_test
    if not party_id:
        return x_train, y_train
    if party_id == 1:
        return x_train[:, :15], None
    return x_train[:, 15:], y_train

################
################
# Define the model: a 4-layer MLP with ReLU activations.
from typing import Sequence
import flax.linen as nn

FEATURES = [30, 15, 8, 1]

class MLP(nn.Module):
    """Multi-layer perceptron built from ``nn.Dense`` layers.

    Every width in ``features`` except the last creates a Dense layer followed
    by a ReLU. The last width creates the final Dense layer, whose output is
    returned without an activation.
    """

    # Output width of each Dense layer, in order (e.g. FEATURES).
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        hidden_widths = self.features[:-1]
        for width in hidden_widths:
            hidden = nn.Dense(width)(x)
            x = nn.relu(hidden)
        output_layer = nn.Dense(self.features[-1])
        return output_layer(x)

################
################
# Define the training procedure
import jax.numpy as jnp
import jax

def predict(params, x):
    """Apply a 30-15-8-1 ReLU MLP, parameterised by ``params``, to ``x``.

    The model class and layer widths are redefined locally instead of reusing
    the module-level ``MLP``/``FEATURES``; see the TODO below for why.
    """
    # TODO(junfeng): investigate why need to have a duplicated definition in notebook,
    # which is not the case in a normal python program.
    import flax.linen as nn
    from typing import Sequence

    FEATURES = [30, 15, 8, 1]

    class MLP(nn.Module):
        features: Sequence[int]

        @nn.compact
        def __call__(self, x):
            hidden_widths = self.features[:-1]
            for width in hidden_widths:
                hidden = nn.Dense(width)(x)
                x = nn.relu(hidden)
            output_layer = nn.Dense(self.features[-1])
            return output_layer(x)

    model = MLP(FEATURES)
    return model.apply(params, x)

def loss_func(params, x, y):
    """Return half the mean squared error between the model output and the labels.

    Args:
        params: Flax parameter pytree accepted by ``predict``.
        x: Input batch passed to ``predict``.
        y: Labels for the batch, either 1-D ``(batch,)`` or already shaped
            like the prediction.

    Returns:
        Scalar ``mean((y - pred) ** 2) / 2``.
    """
    pred = predict(params, x)
    # The MLP's last layer has width 1, so ``pred`` is (batch, 1), while the
    # labels produced by ``breast_cancer`` are 1-D (batch,). Subtracting them
    # directly would broadcast to a (batch, batch) matrix and average every
    # label against every prediction. Align the label shape with ``pred`` first.
    y = jnp.reshape(y, jnp.shape(pred))

    def mse(y, pred):
        def squared_error(y, y_pred):
            return jnp.multiply(y - y_pred, y - y_pred) / 2.0

        return jnp.mean(squared_error(y, pred))

    return mse(y, pred)

def train_auto_grad(x1, x2, y, params, n_batch=10, n_epochs=10, step_size=0.01):
    """Train the MLP with mini-batch gradient descent on column-split features.

    ``x1`` and ``x2`` are concatenated column-wise. The rows are then split into
    ``len(x) // n_batch`` nearly equal mini-batches, so ``n_batch`` is the
    approximate batch *size*. Each of the ``n_epochs`` iterations of
    ``jax.lax.fori_loop`` takes one gradient step per mini-batch.

    Args:
        x1, x2: Feature blocks with the same number of rows.
        y: Labels aligned with the rows of ``x1``/``x2``.
        params: Initial Flax parameter pytree.
        n_batch: Approximate number of rows per mini-batch.
        n_epochs: Number of passes over all mini-batches.
        step_size: Learning rate of the plain gradient-descent update.

    Returns:
        The updated parameter pytree.
    """
    x = jnp.concatenate((x1, x2), axis=1)
    # Integer section count (the original passed a float), clamped to at least
    # one so that fewer rows than ``n_batch`` still yields a single batch.
    n_sections = max(1, len(x) // n_batch)
    xs = jnp.array_split(x, n_sections, axis=0)
    ys = jnp.array_split(y, n_sections, axis=0)

    def body_fun(_, loop_carry):
        params = loop_carry
        # Python loop: unrolled when traced, one update per mini-batch.
        for x, y in zip(xs, ys):
            _, grads = jax.value_and_grad(loss_func)(params, x, y)
            params = jax.tree_util.tree_map(
                lambda p, g: p - step_size * g, params, grads
            )
        return params

    return jax.lax.fori_loop(0, n_epochs, body_fun, params)

def model_init(n_batch=10):
    """Initialise MLP(FEATURES) parameters.

    Uses PRNG seed 1 and an all-ones dummy input of shape
    ``(n_batch, FEATURES[0])``.
    """
    rng = jax.random.PRNGKey(1)
    sample_input = jnp.ones((n_batch, FEATURES[0]))
    return MLP(FEATURES).init(rng, sample_input)

################
################
# Use AUC as the validation metric
from sklearn.metrics import roc_auc_score

def validate_model(params, X_test, y_test):
    """Score the model's predictions on ``X_test`` against ``y_test`` with ROC AUC."""
    scores = predict(params, X_test)
    auc_value = roc_auc_score(y_test, scores)
    return auc_value

################
################
# Train a plaintext model
import jax

# Load the vertically split training data: party 1 gets the first 15 feature
# columns, party 2 gets the remaining columns plus the labels.
x1, _ = breast_cancer(party_id=1, train=True)
x2, y = breast_cancer(party_id=2, train=True)

# Hyperparameters
n_batch, n_epochs, step_size = 10, 10, 0.01

# Train the model in plaintext.
init_params = model_init(n_batch)
params = train_auto_grad(
    x1, x2, y, init_params,
    n_batch, n_epochs, step_size,
)

# Evaluate on the held-out test split.
X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test, y_test)
print(f'auc={auc}')

################
################
# Use SPU to train the model
import spu.utils.distributed as ppd
import spu.spu_pb2 as spu_pb2
import argparse
import json

parser = argparse.ArgumentParser(description='distributed driver.')
# NOTE(review): machine-specific absolute default; pass -c/--config elsewhere.
parser.add_argument(
    "-c",
    "--config",
    default="/home/whty/CC/cc_test/2pc.json",
    help="path to the cluster config JSON (must contain 'nodes' and 'devices')",
)
# NOTE(review): --workers and --batch are parsed but not used by this script.
# The help texts previously claimed defaults of 4 and 64, while the real
# defaults are 1; %(default)s keeps the help in sync with the actual value.
parser.add_argument(
    '--workers',
    default=1,
    type=int,
    metavar='N',
    help='number of data loading workers (default: %(default)s)',
)
parser.add_argument(
    '--batch',
    default=1,
    type=int,
    metavar='N',
    help='batchsize (default: %(default)s)',
)
args = parser.parse_args()
with open(args.config, 'r', encoding='utf-8') as file:
    conf = json.load(file)
# Every node listed in the config must already be running; otherwise
# initialisation fails with a gRPC "Connection refused" error.
ppd.init(conf["nodes"], conf["devices"])

# Check the version of your SecretFlow
#print('The version of SecretFlow: {}'.format(sf.__version__))

# In case you have a running secretflow runtime already.
#sf.shutdown()

#sf.init(['alice', 'bob'], address='local')

#alice, bob = sf.PYU('alice'), sf.PYU('bob')
#spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))
#cc_add
# Device handles from the ppd cluster config: parties P1/P2 and the SPU.
alice = ppd.device("P1")
bob = ppd.device("P2")
spu = ppd.device("SPU")

# Each party runs breast_cancer on its own device to get its slice of the data.
x1, _ = alice(breast_cancer)(party_id=1, train=True)
x2, y = bob(breast_cancer)(party_id=2, train=True)
# The initial parameters are created locally on the driver.
init_params = model_init(n_batch)

## cc_add
def pyu_to_spu(x1):
    """Identity function.

    The script wraps it in a device call (e.g. ``spu(pyu_to_spu)(x)``) to move
    ``x`` onto that device unchanged.
    """
    value = x1
    return value

device = spu
# Move each party's data into the SPU by running the identity function there.
x1_ = spu(pyu_to_spu)(x1)
x2_ = spu(pyu_to_spu)(x2)
y_ = spu(pyu_to_spu)(y)
# Route the driver-created initial parameters through P1, then into the SPU.
init_params_pyu = alice(pyu_to_spu)(init_params)
init_params_ = spu(pyu_to_spu)(init_params_pyu)

# The tutorial's variant passes n_batch/n_epochs/step_size explicitly as static
# arguments. Here train_auto_grad's defaults (n_batch=10, n_epochs=10,
# step_size=0.01) are used; they match the plaintext hyperparameters above.

################
################
# Train inside SPU and reveal the resulting parameters.
# Pass ``init_params_`` (the initial parameters already moved into the SPU)
# rather than the driver-local plaintext ``init_params``, so that all four
# inputs reach train_auto_grad the same way (x1_, x2_, y_, init_params_).
params_spu = spu(train_auto_grad)(x1_, x2_, y_, init_params_)
params = ppd.get(params_spu)
print(params)

################
################
#we need to validate the model
# Score the revealed SPU-trained parameters on the plaintext test split.
X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test=X_test, y_test=y_test)
print(f'auc={auc}')

################
################

Relevant log output

terminal output as:
Traceback (most recent call last):
  File "/home/whty/CC/cc_test/spu_nn_examples.py", line 202, in <module>
    params_spu = spu(train_auto_grad)(x1_, x2_, y_, init_params)
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/distributed_impl.py", line 693, in __call__
    results = [future.result() for future in futures]
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/distributed_impl.py", line 693, in <listcomp>
    results = [future.result() for future in futures]
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/concurrent/futures/_base.py", line 451, in result
    return self.__get_result()
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/distributed_impl.py", line 247, in run
    return self._call(self._stub.Run, fn, *args, **kwargs)
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/distributed_impl.py", line 240, in _call
    raise Exception("remote exception", result)
Exception: ('remote exception', Exception('Traceback (most recent call last):\n  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/distributed_impl.py", line 326, in Run\n    ret_objs = fn(self, *args, **kwargs)\n  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/distributed_impl.py", line 589, in builtin_spu_run\n    rt.run(spu_exec)\n  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/api.py", line 44, in run\n    return self._vm.Run(executable.SerializeToString())\nRuntimeError: what: \n\t[external/yacl/yacl/link/transport/channel.cc:427] Get data timeout, key=root-0:P2P-10510:1->0\nstacktrace: \n#0 yacl::link::Context::RecvInternal()+0x773196ae946b\n#1 yacl::link::Context::Recv()+0x773196aeab96\n#2 spu::mpc::cheetah::CheetahDot::Impl::doDotOLESenderRecvStep()+0x773195e4050e\n#3 spu::mpc::cheetah::CheetahDot::Impl::doDotOLE()+0x773195e457ac\n#4 spu::mpc::cheetah::CheetahDot::Impl::DotOLE()+0x773195e45da1\n#5 spu::mpc::cheetah::CheetahDot::DotOLE()+0x773195e45ef2\n#6 std::_Function_handler<>::_M_invoke()+0x773195e11d68\n#7 std::__future_base::_State_baseV2::_M_do_set()+0x773195c43b52\n#8 (unknown)+0x7731d1699ee8\n\n\n'))
TY-cc commented 3 weeks ago

I start node0 and node1 with commands such as `python nodectl.py -c 2pc.json start --node_id node:1 &> node1.log &`. It reports:

[2024-09-26 14:29:27,927] [MainProcess] Traceback (most recent call last):
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/distributed_impl.py", line 326, in Run
    ret_objs = fn(self, *args, **kwargs)
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/distributed_impl.py", line 589, in builtin_spu_run
    rt.run(spu_exec)
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/api.py", line 44, in run
    return self._vm.Run(executable.SerializeToString())
RuntimeError: what: 
    [external/yacl/yacl/link/transport/channel.cc:427] Get data timeout, key=root-0:P2P-10510:1->0
stacktrace: 
#0 yacl::link::Context::RecvInternal()+0x773196ae946b
#1 yacl::link::Context::Recv()+0x773196aeab96
#2 spu::mpc::cheetah::CheetahDot::Impl::doDotOLESenderRecvStep()+0x773195e4050e
#3 spu::mpc::cheetah::CheetahDot::Impl::doDotOLE()+0x773195e457ac
#4 spu::mpc::cheetah::CheetahDot::Impl::DotOLE()+0x773195e45da1
#5 spu::mpc::cheetah::CheetahDot::DotOLE()+0x773195e45ef2
#6 std::_Function_handler<>::_M_invoke()+0x773195e11d68
#7 std::__future_base::_State_baseV2::_M_do_set()+0x773195c43b52
#8 (unknown)+0x7731d1699ee8
tpppppub commented 3 weeks ago

It might be a node crash caused by running out of memory (OOM). You can try with less data or use a server with more RAM.

TY-cc commented 3 weeks ago

Why can it run successfully the first time?

tpppppub commented 3 weeks ago

If you interrupt the training program, the runtime keeps running. Please wait for the runtime to finish the training tasks before running your training program again. Alternatively, kill and relaunch the runtime, then run your training program.

TY-cc commented 3 weeks ago

It reports a new problem.

Traceback (most recent call last):
  File "/home/whty/CC/cc_test/spu_nn_examples.py", line 162, in <module>
    ppd.init(conf["nodes"], conf["devices"])
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/distributed_impl.py", line 1178, in init
    _CONTEXT = HostContext(nodes_def, devices_def)
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/distributed_impl.py", line 1098, in __init__
    self.devices[name] = SPU(
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/distributed_impl.py", line 1013, in __init__
    results = [future.result() for future in futures]
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/distributed_impl.py", line 1013, in <listcomp>
    results = [future.result() for future in futures]
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/concurrent/futures/_base.py", line 451, in result
    return self.__get_result()
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/distributed_impl.py", line 247, in run
    return self._call(self._stub.Run, fn, *args, **kwargs)
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/distributed_impl.py", line 236, in _call
    rsp_data = rebuild_messages(rsp_itr.data for rsp_itr in rsp_gen)
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/distributed_impl.py", line 214, in rebuild_messages
    return b''.join([msg for msg in msgs])
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/distributed_impl.py", line 214, in <listcomp>
    return b''.join([msg for msg in msgs])
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/distributed_impl.py", line 236, in <genexpr>
    rsp_data = rebuild_messages(rsp_itr.data for rsp_itr in rsp_gen)
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/grpc/_channel.py", line 543, in __next__
    return self._next()
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/grpc/_channel.py", line 969, in _next
    raise self
grpc._channel._MultiThreadedRendezvous: <_MultiThreadedRendezvous of RPC that terminated with:
        status = StatusCode.UNAVAILABLE
        details = "failed to connect to all addresses; last error: UNKNOWN: ipv4:127.0.0.1:61320: Failed to connect to remote host: connect: Connection refused (111)"
        debug_error_string = "UNKNOWN:Error received from peer  {grpc_message:"failed to connect to all addresses; last error: UNKNOWN: ipv4:127.0.0.1:61320: Failed to connect to remote host: connect: Connection refused (111)", grpc_status:14, created_time:"2024-09-26T16:41:19.208832718+08:00"}"
tpppppub commented 3 weeks ago

Make sure you have all nodes up.

TY-cc commented 3 weeks ago

Yes, the nodes had not started. Thanks for your reply!

TY-cc commented 3 weeks ago

Another question: why is the AUC different? In plaintext, auc=0.9927939731411726; in SPU, auc=0.9954143465443825. But the two values are equal in the example code.