[Bug]: RuntimeError reported when I use SPU for neural-network training #869

Closed TY-cc closed 3 weeks ago

TY-cc commented 3 weeks ago

SPU runtime

linux ubuntu 22.04

Python 3.10.4

GCC 11.4

Current Behavior?

I am using SPU for neural-network training. The first time I ran the code, it ran normally. The second time, I forcibly interrupted it before the training finished. The third time I ran it, it raised a RuntimeError. The code is as follows:

Standalone code to reproduce the issue

# code is in this url

#load the data
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import Normalizer

def breast_cancer(
    party_id=None, train: bool = True
) -> tuple[np.ndarray, np.ndarray | None]:
    """Load the breast-cancer dataset, optionally as one party's vertical slice.

    Features are min-max scaled using the global min/max of the whole feature
    matrix, then split 80/20 into train/test with a fixed seed (42).

    Args:
        party_id: falsy (e.g. ``None``) for the full training set; ``1`` for
            the first 15 feature columns without labels; any other truthy
            value for the remaining feature columns together with the labels.
        train: return the training split when true, the test split otherwise.
            ``party_id`` is ignored for the test split.

    Returns:
        ``(features, labels)``; ``labels`` is ``None`` for party 1.
    """
    x, y = load_breast_cancer(return_X_y=True)
    # Global (not per-column) min-max scaling.
    x = (x - np.min(x)) / (np.max(x) - np.min(x))
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, random_state=42
    )

    if not train:
        return x_test, y_test
    if not party_id:
        return x_train, y_train
    if party_id == 1:
        # Party 1 holds only the first half of the features, no labels.
        return x_train[:, :15], None
    # Any other party holds the remaining features plus the labels.
    return x_train[:, 15:], y_train

# Define the model: a 4-layer MLP with ReLU activations.
from typing import Sequence
import flax.linen as nn

# Output width of each Dense layer of MLP, in order; the last entry is the
# width of the final (linear) output layer. model_init also uses FEATURES[0]
# as the width of its dummy input batch.
FEATURES = [30, 15, 8, 1]

class MLP(nn.Module):
    """Fully connected network with ReLU after every Dense layer except the last.

    Attributes:
        features: output width of each Dense layer, in order; the last entry
            is the width of the linear output layer.
    """

    features: Sequence[int]

    # The Dense submodules are created inline here rather than in setup(),
    # which Flax only allows inside a method wrapped with @nn.compact.
    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        x = nn.Dense(self.features[-1])(x)
        return x

# Define the training routine.
import jax.numpy as jnp
import jax

def predict(params, x):
    """Run the MLP forward pass with the given parameters.

    Args:
        params: parameter pytree, e.g. as produced by ``model_init``.
        x: input batch; assumed shape ``(batch, 30)`` — TODO confirm.

    Returns:
        Output of the final Dense layer (width ``FEATURES[-1]``, i.e. 1).
    """
    # TODO(junfeng): investigate why a duplicated definition is needed in a
    # notebook, which is not the case in a normal Python program.
    from typing import Sequence
    import flax.linen as nn

    FEATURES = [30, 15, 8, 1]

    class MLP(nn.Module):
        features: Sequence[int]

        # Submodules are created inline, so the method must be wrapped with
        # @nn.compact; otherwise Flax rejects the inline nn.Dense calls.
        @nn.compact
        def __call__(self, x):
            for feat in self.features[:-1]:
                x = nn.relu(nn.Dense(feat)(x))
            x = nn.Dense(self.features[-1])(x)
            return x

    return MLP(FEATURES).apply(params, x)

def loss_func(params, x, y):
    """Half mean-squared error of the MLP's predictions against ``y``.

    Args:
        params: model parameter pytree.
        x: input batch.
        y: targets, one per row of ``x`` (1-D labels are accepted).

    Returns:
        Scalar ``mean((y - pred) ** 2 / 2)``.
    """
    pred = predict(params, x)
    # The final Dense layer has width 1, so pred is (batch, 1). The labels
    # are 1-D (sklearn returns y as (n_samples,)). Subtracting them directly
    # would broadcast to (batch, batch) and average over every
    # label/prediction pair, so align pred to y's shape first.
    pred = jnp.reshape(pred, jnp.shape(y))

    diff = y - pred
    return jnp.mean(jnp.multiply(diff, diff) / 2.0)

def train_auto_grad(x1, x2, y, params, n_batch=10, n_epochs=10, step_size=0.01):
    """Train the MLP with plain mini-batch gradient descent.

    The two feature blocks are concatenated column-wise and split into
    consecutive mini-batches of roughly ``n_batch`` rows. ``n_epochs`` passes
    of gradient descent are then run inside ``jax.lax.fori_loop``.

    Args:
        x1, x2: feature blocks with the same number of rows.
        y: labels, one per row.
        params: initial parameter pytree.
        n_batch: target number of rows per mini-batch (a batch *size*, not a
            batch count).
        n_epochs: number of passes over all mini-batches.
        step_size: learning rate.

    Returns:
        The trained parameter pytree.
    """
    x = jnp.concatenate((x1, x2), axis=1)
    # array_split expects an integer section count. Keep at least one section
    # so that fewer than n_batch rows still trains instead of raising.
    n_sections = max(1, len(x) // n_batch)
    xs = jnp.array_split(x, n_sections, axis=0)
    ys = jnp.array_split(y, n_sections, axis=0)

    def body_fun(_, loop_carry):
        params = loop_carry
        for x_batch, y_batch in zip(xs, ys):
            _, grads = jax.value_and_grad(loss_func)(params, x_batch, y_batch)
            params = jax.tree_util.tree_map(
                lambda p, g: p - step_size * g, params, grads
            )
        return params

    return jax.lax.fori_loop(0, n_epochs, body_fun, params)

def model_init(n_batch=10):
    """Create initial MLP parameters using a fixed PRNG seed (1).

    The Dense layers are shaped by tracing an all-ones dummy batch of shape
    ``(n_batch, FEATURES[0])``.
    """
    net = MLP(FEATURES)
    rng = jax.random.PRNGKey(1)
    dummy_input = jnp.ones((n_batch, FEATURES[0]))
    return net.init(rng, dummy_input)

# Use AUC as the validation metric.
from sklearn.metrics import roc_auc_score

def validate_model(params, X_test, y_test):
    """Score ``X_test`` with the model and return the ROC AUC against ``y_test``."""
    return roc_auc_score(y_true=y_test, y_score=predict(params, X_test))

# Train a plaintext baseline model.
import jax

# Load the data: party 1's feature columns, and party 2's columns plus labels.
x1, _ = breast_cancer(party_id=1, train=True)
x2, y = breast_cancer(party_id=2, train=True)

# Hyperparameters (n_batch is the mini-batch *size*).
n_batch = 10
n_epochs = 10
step_size = 0.01

# Train the model.
init_params = model_init(n_batch)
params = train_auto_grad(x1, x2, y, init_params, n_batch, n_epochs, step_size)

# Test the model. `auc` is reassigned by the SPU run later in this script,
# so report the plaintext baseline here.
X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test, y_test)
print(f"plaintext auc={auc}")

# Use SPU to train the model.
import spu.utils.distributed as ppd
import spu.spu_pb2 as spu_pb2
import argparse
import json

parser = argparse.ArgumentParser(description='distributed driver.')
parser.add_argument("-c", "--config", default="/home/whty/CC/cc_test/2pc.json")
# NOTE(review): --workers and --batch are parsed but never used below.
parser.add_argument('--workers', default=1, type=int, metavar='N', help='number of data loading workers (default: %(default)s)')
parser.add_argument('--batch', default=1, type=int, metavar='N', help='batchsize (default: %(default)s)')
args = parser.parse_args()
with open(args.config, 'r') as file:
    conf = json.load(file)
# The nodes listed in the config must already be running, otherwise init fails
# with a gRPC "Connection refused" error. If a previous driver run was
# interrupted, the nodes may still be executing it; restart them before
# rerunning, or later SPU calls can fail with "Get data timeout".
ppd.init(conf["nodes"], conf["devices"])

# Plaintext devices for the two data owners, plus the MPC device.
alice, bob = ppd.device("P1"), ppd.device("P2")
spu = ppd.device("SPU")

# Each party loads only its own vertical slice of the training data.
x1, _ = alice(breast_cancer)(party_id=1, train=True)
x2, y = bob(breast_cancer)(party_id=2, train=True)
init_params = model_init(n_batch)


def pyu_to_spu(x1):
    """Identity function; running it on a device is used to move its argument there."""
    return x1


# Move the features, labels and initial parameters onto the SPU.
x1_, x2_, y_ = spu(pyu_to_spu)(x1), spu(pyu_to_spu)(x2), spu(pyu_to_spu)(y)
init_params_pyu = alice(pyu_to_spu)(init_params)
init_params_ = spu(pyu_to_spu)(init_params_pyu)

# Train once on the SPU. The hyperparameters are passed as static
# (compile-time) arguments.
params_spu = spu(train_auto_grad, static_argnames=['n_batch', 'n_epochs', 'step_size'])(
    x1_, x2_, y_, init_params_, n_batch=n_batch, n_epochs=n_epochs, step_size=step_size
)

# Reveal the trained parameters to the driver and evaluate them in plaintext.
params = ppd.get(params_spu)

X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test, y_test)
# NOTE(review): SPU presumably trains in fixed-point arithmetic, so a small
# AUC difference from the plaintext baseline is expected — confirm.
print(f"spu auc={auc}")


Relevant log output

terminal output as:
Traceback (most recent call last):
  File "/home/whty/CC/cc_test/", line 202, in <module>
    params_spu = spu(train_auto_grad)(x1_, x2_, y_, init_params)
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/", line 693, in __call__
    results = [future.result() for future in futures]
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/", line 693, in <listcomp>
    results = [future.result() for future in futures]
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/concurrent/futures/", line 451, in result
    return self.__get_result()
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/concurrent/futures/", line 403, in __get_result
    raise self._exception
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/concurrent/futures/", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/", line 247, in run
    return self._call(self._stub.Run, fn, *args, **kwargs)
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/", line 240, in _call
    raise Exception("remote exception", result)
Exception: ('remote exception', Exception('Traceback (most recent call last):\n  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/", line 326, in Run\n    ret_objs = fn(self, *args, **kwargs)\n  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/", line 589, in builtin_spu_run\n\n  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/", line 44, in run\n    return self._vm.Run(executable.SerializeToString())\nRuntimeError: what: \n\t[external/yacl/yacl/link/transport/] Get data timeout, key=root-0:P2P-10510:1->0\nstacktrace: \n#0 yacl::link::Context::RecvInternal()+0x773196ae946b\n#1 yacl::link::Context::Recv()+0x773196aeab96\n#2 spu::mpc::cheetah::CheetahDot::Impl::doDotOLESenderRecvStep()+0x773195e4050e\n#3 spu::mpc::cheetah::CheetahDot::Impl::doDotOLE()+0x773195e457ac\n#4 spu::mpc::cheetah::CheetahDot::Impl::DotOLE()+0x773195e45da1\n#5 spu::mpc::cheetah::CheetahDot::DotOLE()+0x773195e45ef2\n#6 std::_Function_handler<>::_M_invoke()+0x773195e11d68\n#7 std::__future_base::_State_baseV2::_M_do_set()+0x773195c43b52\n#8 (unknown)+0x7731d1699ee8\n\n\n'))
TY-cc commented 3 weeks ago

I started node0 and node1 with `python -c 2pc.json start --node_id node:1 &> node1.log &`, and it reports:

[2024-09-26 14:29:27,927] [MainProcess] Traceback (most recent call last):
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/", line 326, in Run
    ret_objs = fn(self, *args, **kwargs)
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/", line 589, in builtin_spu_run
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/", line 44, in run
    return self._vm.Run(executable.SerializeToString())
RuntimeError: what: 
    [external/yacl/yacl/link/transport/] Get data timeout, key=root-0:P2P-10510:1->0
#0 yacl::link::Context::RecvInternal()+0x773196ae946b
#1 yacl::link::Context::Recv()+0x773196aeab96
#2 spu::mpc::cheetah::CheetahDot::Impl::doDotOLESenderRecvStep()+0x773195e4050e
#3 spu::mpc::cheetah::CheetahDot::Impl::doDotOLE()+0x773195e457ac
#4 spu::mpc::cheetah::CheetahDot::Impl::DotOLE()+0x773195e45da1
#5 spu::mpc::cheetah::CheetahDot::DotOLE()+0x773195e45ef2
#6 std::_Function_handler<>::_M_invoke()+0x773195e11d68
#7 std::__future_base::_State_baseV2::_M_do_set()+0x773195c43b52
#8 (unknown)+0x7731d1699ee8
tpppppub commented 3 weeks ago

It might be a node crash caused by running out of memory (OOM). You can try with less data, or use a server with more RAM.

TY-cc commented 3 weeks ago

Why can it run successfully the first time?

tpppppub commented 3 weeks ago

If you interrupt the training program, the runtime keeps running. Either wait for the runtime to finish the training tasks and then run your training program again, or kill and relaunch the runtime and then run your training program.

TY-cc commented 3 weeks ago

Now it reports a new problem:

Traceback (most recent call last):
  File "/home/whty/CC/cc_test/", line 162, in <module>
    ppd.init(conf["nodes"], conf["devices"])
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/", line 1178, in init
    _CONTEXT = HostContext(nodes_def, devices_def)
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/", line 1098, in __init__
    self.devices[name] = SPU(
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/", line 1013, in __init__
    results = [future.result() for future in futures]
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/", line 1013, in <listcomp>
    results = [future.result() for future in futures]
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/concurrent/futures/", line 451, in result
    return self.__get_result()
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/concurrent/futures/", line 403, in __get_result
    raise self._exception
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/concurrent/futures/", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/", line 247, in run
    return self._call(self._stub.Run, fn, *args, **kwargs)
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/", line 236, in _call
    rsp_data = rebuild_messages( for rsp_itr in rsp_gen)
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/", line 214, in rebuild_messages
    return b''.join([msg for msg in msgs])
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/", line 214, in <listcomp>
    return b''.join([msg for msg in msgs])
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/spu/utils/", line 236, in <genexpr>
    rsp_data = rebuild_messages( for rsp_itr in rsp_gen)
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/grpc/", line 543, in __next__
    return self._next()
  File "/home/whty/anaconda3/envs/spu/lib/python3.10/site-packages/grpc/", line 969, in _next
    raise self
grpc._channel._MultiThreadedRendezvous: <_MultiThreadedRendezvous of RPC that terminated with:
        status = StatusCode.UNAVAILABLE
        details = "failed to connect to all addresses; last error: UNKNOWN: ipv4: Failed to connect to remote host: connect: Connection refused (111)"
        debug_error_string = "UNKNOWN:Error received from peer  {grpc_message:"failed to connect to all addresses; last error: UNKNOWN: ipv4: Failed to connect to remote host: connect: Connection refused (111)", grpc_status:14, created_time:"2024-09-26T16:41:19.208832718+08:00"}"
tpppppub commented 3 weeks ago

Make sure you have all nodes up.

TY-cc commented 3 weeks ago

Yes, the nodes had not been started. Thanks for your reply!

TY-cc commented 3 weeks ago

Another question: why is the AUC different? In plaintext, AUC is auc=0.9927939731411726; in SPU, AUC is auc=0.9954143465443825. But the two are equal in the example code.