secretflow / spu

SPU (Secure Processing Unit) aims to be a provable, measurable secure computation device, which provides computation ability while keeping your private data protected.
https://www.secretflow.org.cn/docs/spu/en/
Apache License 2.0

REF2K: Parallelism and GPU support #307

Closed deevashwer closed 1 year ago

deevashwer commented 1 year ago

Does the REF2K backend support parallelism and GPU acceleration? If there's parallelism, does it occur at the inference level (i.e., multiple threads accelerate single inference) or the batch level (i.e., each thread handles an independent inference)?

anakinxc commented 1 year ago

Hi @deevashwer

We do not support GPU at the moment, but we do have threading-based DLP and ILP support.

deevashwer commented 1 year ago

Hi @anakinxc,

When I run the REF2K simulation backend with GPT2, I see that it spawns as many threads as num_vcpu, but CPU utilization stays quite low (~200%) even though I'm using a machine with 44 vCPUs. I also tried increasing the batch size, and utilization doesn't get much better. The PUMA paper says that you used 128 threads for the ABY3 evaluation of LLaMA. Is it the case that better parallelism support exists for (distributed, ABY3), but not for (simulation, REF2K)?

One way I tried to get around this issue was by implementing parallelism at the application level (outside SPU). Basically, I spawned multiple threads, each running its own SPU instance, but this approach ran into the following errors:

  File "/opt/conda/envs/spu/lib/python3.10/site-packages/spu/utils/simulation.py", line 151, in wrapper
    executable, output = spu_fe.compile(
  File "/opt/conda/envs/spu/lib/python3.10/site-packages/spu/utils/frontend.py", line 150, in compile
    ir_text, output = _jax_compilation(
  File "/opt/conda/envs/spu/lib/python3.10/site-packages/cachetools/__init__.py", line 739, in wrapper
    cache[k] = v
  File "/opt/conda/envs/spu/lib/python3.10/site-packages/cachetools/__init__.py", line 217, in __setitem__
    cache_setitem(self, key, value)
  File "/opt/conda/envs/spu/lib/python3.10/site-packages/cachetools/__init__.py", line 79, in __setitem__
    self.popitem()
  File "/opt/conda/envs/spu/lib/python3.10/site-packages/cachetools/__init__.py", line 231, in popitem
    return (key, self.pop(key))
  File "/opt/conda/envs/spu/lib/python3.10/site-packages/cachetools/__init__.py", line 116, in pop
    raise KeyError(key)
KeyError: "8762534634115-()-None-[(dtype('float32'), (3200, 3200)), (dtype('float32'), (1, 385, 3200))]"

and

  File "/opt/conda/envs/spu/lib/python3.10/site-packages/spu/utils/simulation.py", line 151, in wrapper
    executable, output = spu_fe.compile(
  File "/opt/conda/envs/spu/lib/python3.10/site-packages/spu/utils/frontend.py", line 150, in compile
    ir_text, output = _jax_compilation(
  File "/opt/conda/envs/spu/lib/python3.10/site-packages/cachetools/__init__.py", line 739, in wrapper
    cache[k] = v
  File "/opt/conda/envs/spu/lib/python3.10/site-packages/cachetools/__init__.py", line 217, in __setitem__
    cache_setitem(self, key, value)
  File "/opt/conda/envs/spu/lib/python3.10/site-packages/cachetools/__init__.py", line 79, in __setitem__
    self.popitem()
  File "/opt/conda/envs/spu/lib/python3.10/site-packages/cachetools/__init__.py", line 227, in popitem
    key = next(iter(self.__order))
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: OrderedDict mutated during iteration

It seems that SPU doesn't support concurrent execution at the moment. Is it easy to add this support (perhaps through replication)?
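
Both tracebacks end inside the cache's `__setitem__` and `popitem`, which is what an unsynchronized LRU cache looks like when several threads insert and evict at once. The following is a minimal standard-library sketch of that idea, not SPU's or cachetools' actual code: `LockedLRUCache` and `hammer` are hypothetical names, and the point is only that guarding every cache mutation with one lock keeps concurrent eviction consistent.

```python
import threading
from collections import OrderedDict


class LockedLRUCache:
    """Tiny LRU cache whose reads and writes are serialized by one lock."""

    def __init__(self, maxsize):
        self.maxsize = maxsize
        self._data = OrderedDict()
        self._lock = threading.Lock()

    def get_or_compute(self, key, compute):
        with self._lock:
            if key in self._data:
                self._data.move_to_end(key)  # mark as most recently used
                return self._data[key]
        value = compute()  # run the expensive step outside the lock
        with self._lock:
            self._data[key] = value
            self._data.move_to_end(key)
            while len(self._data) > self.maxsize:
                self._data.popitem(last=False)  # evict least recently used
        return value

    def __len__(self):
        with self._lock:
            return len(self._data)


def hammer(cache, worker_id, n_calls):
    """Insert many distinct keys so that evictions happen constantly."""
    for i in range(n_calls):
        cache.get_or_compute(f"{worker_id}-{i}", lambda: i * i)


cache = LockedLRUCache(maxsize=8)
threads = [threading.Thread(target=hammer, args=(cache, w, 1000)) for w in range(16)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(f"entries after concurrent use: {len(cache)}")
```

For reference, the `cached` decorator in cachetools accepts a `lock` argument for the same purpose; the thread does not say whether the eventual SPU fix took that route.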

anakinxc commented 1 year ago

Hi @deevashwer

Before trying multiple SPU instances, there are some runtime configs you can try first, such as these:

experimental_disable_mmul_split: True,
experimental_enable_inter_op_par: True,
experimental_enable_intra_op_par: True,
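
As a rough sketch of where such flags might go: the reproduction later in this thread builds its config with `spu.RuntimeConfig(protocol=..., field=...)`, so one plausible way to apply them is as extra keyword arguments. This assumes the installed SPU version exposes these three fields on `RuntimeConfig`. The helper `build_runtime_config` is hypothetical, and a `SimpleNamespace` stands in for the real config class so the snippet runs without SPU installed.

```python
from types import SimpleNamespace

# The three flags listed above, collected as keyword arguments.
EXPERIMENTAL_FLAGS = {
    "experimental_disable_mmul_split": True,
    "experimental_enable_inter_op_par": True,
    "experimental_enable_intra_op_par": True,
}


def build_runtime_config(config_cls, protocol, field, **flags):
    """Construct a config object, forwarding the flags as keyword arguments."""
    return config_cls(protocol=protocol, field=field, **flags)


# Stand-in values only. With SPU installed one would presumably pass
# spu.RuntimeConfig, spu.ProtocolKind.REF2K and spu.FieldType.FM64 instead.
config = build_runtime_config(SimpleNamespace, "REF2K", "FM64", **EXPERIMENTAL_FLAGS)
print(config)
```

The resulting config would then be handed to the simulator in the same way as any other runtime config.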

deevashwer commented 1 year ago

I tried these options, but unfortunately they did not lead to much improvement. It could be because I'm invoking SPU layer by layer on the network, as opposed to all at once. To get the best parallelism, it would be nice to implement parallelism at the application level. Could you please help me implement support for multiple SPU instances?

anakinxc commented 1 year ago

Hi @deevashwer

For simulation this should be relatively easy. Instantiating multiple simulators should be fine.

deevashwer commented 1 year ago

I tried that while ensuring that each thread had a local instance of the simulator, but I still encountered the two errors I mentioned above. It seems that there is some shared state among the different SPU instances which is accessed by the _jax_compilation function.

anakinxc commented 1 year ago

I'll take a look

anakinxc commented 1 year ago

Hi @deevashwer

I created a very simple example which seems to work fine.

Here is the code:

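# Imports added for completeness; the SPU module paths are assumed, since the original snippet omitted them.
import threading

import numpy as np
import spu.spu_pb2 as spu_pb2
import spu.utils.simulation as ppsim

if __name__ == "__main__":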
    """
    You can modify the code below for debug purpose only.
    Please DONT commit it unless it will cause build break.
    """

    sim = ppsim.Simulator.simple(3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64)
    sim1 = ppsim.Simulator.simple(3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64)

    x = np.random.randn(3, 4)
    y = np.random.randn(3, 4)
    fn = lambda x, y: x + y
    fn1 = lambda x, y: x - y

    def run_sim(sim, fn, x, y):
        spu_fn = ppsim.sim_jax(sim, fn)
        z = spu_fn(x, y)
        print(f"spu out = {z}")

    # creating thread
    t1 = threading.Thread(target=run_sim, args=(sim, fn, x, y))
    t2 = threading.Thread(target=run_sim, args=(sim1, fn1, x, y))

    t1.start()
    t2.start()

    t1.join()
    t2.join()

    print("Done")

deevashwer commented 1 year ago

Yeah, the error has only shown up for me when there are many invocations of SPU (on the order of thousands), so either with many threads or with a few threads and many invocations per thread. Here's an example that reproduces the error:

import jax
from jax import random, numpy as jnp
from flax import linen as nn
import spu.utils.simulation as pps
import spu
import threading

protocol = spu.ProtocolKind.REF2K
field = spu.FieldType.FM64
config = spu.RuntimeConfig(protocol=protocol, field=field)

# pure function
def dense(params, x):
    return jax.lax.dot_general(x, params['params']['kernel'], (((x.ndim - 1,), (0,)), ((), ())),)

class LinearModel(nn.Module):
    features: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layer = nn.Dense(self.features, use_bias=False, dtype=self.dtype, param_dtype=self.dtype)

    def __call__(self, x):
        params = {"params": self.layer.variables['params']}

        y = self.layer.apply(params, x)
        local = threading.local()
        local.simulator = pps.Simulator(1, config)

        spu_dense = pps.sim_jax(local.simulator, dense)
        spu_y = spu_dense(params, x)

        spu_dense = pps.sim_jax(local.simulator, self.layer.apply)
        spu_y = spu_dense(params, x)

        return spu_y

inp_len = 10000
features = 500
batch = 400

def get_params(key, features):
    return {'params': {'layer': {'kernel': random.normal(key, (inp_len, features))}}}

def get_output(model, params, x, results, thread_idx):
    output = model.apply(params, x)
    results[thread_idx] = output

if __name__ == "__main__":
    models = [LinearModel(features=(features + i)) for i in range(batch)]

    key = random.PRNGKey(0)
    params = [get_params(key, features + i) for i in range(batch)]
    x = random.normal(key, (inp_len,))

    results = [None] * batch
    workers = [threading.Thread(target=get_output, args=(models[i], params[i], x, results, i)) for i in range(batch)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

There are smaller examples as well, but this is the one that consistently reproduces the error on my i9 MacBook Pro.

anakinxc commented 1 year ago

Hi @deevashwer

I can reproduce this on my end as well. Will investigate this.

Thanks for the repro :D

anakinxc commented 1 year ago

Hi @deevashwer

The bug should have been fixed with this change.