Closed. deevashwer closed this 1 year ago.

Does the REF2K backend support parallelism and GPU acceleration? If there is parallelism, does it occur at the inference level (i.e., multiple threads accelerate a single inference) or at the batch level (i.e., each thread handles an independent inference)?
Hi @deevashwer
We do not support GPU at this moment, but we do have threading-based DLP (data-level parallelism) and ILP (instruction-level parallelism) support.
Hi @anakinxc,
When I run the REF2K simulation backend with GPT2, I see that it spawns as many threads as there are vCPUs, but CPU utilization stays quite low (~200%) even on a machine with 44 vCPUs. Increasing the batch size doesn't improve utilization much either. The PUMA paper says that 128 threads were used for the ABY3 evaluation of LLaMA. Is it the case that better parallelism support exists for (distributed, ABY3) but not for (simulation, REF2K)?

One way I tried to get around this issue was by implementing parallelism at the application level (outside SPU): I spawned multiple threads, each running its own SPU instance. However, this approach ran into the following errors:
File "/opt/conda/envs/spu/lib/python3.10/site-packages/spu/utils/simulation.py", line 151, in wrapper
executable, output = spu_fe.compile(
File "/opt/conda/envs/spu/lib/python3.10/site-packages/spu/utils/frontend.py", line 150, in compile
ir_text, output = _jax_compilation(
File "/opt/conda/envs/spu/lib/python3.10/site-packages/cachetools/__init__.py", line 739, in wrapper
cache[k] = v
File "/opt/conda/envs/spu/lib/python3.10/site-packages/cachetools/__init__.py", line 217, in __setitem__
cache_setitem(self, key, value)
File "/opt/conda/envs/spu/lib/python3.10/site-packages/cachetools/__init__.py", line 79, in __setitem__
self.popitem()
File "/opt/conda/envs/spu/lib/python3.10/site-packages/cachetools/__init__.py", line 231, in popitem
return (key, self.pop(key))
File "/opt/conda/envs/spu/lib/python3.10/site-packages/cachetools/__init__.py", line 116, in pop
raise KeyError(key)
KeyError: "8762534634115-()-None-[(dtype('float32'), (3200, 3200)), (dtype('float32'), (1, 385, 3200))]"
and
File "/opt/conda/envs/spu/lib/python3.10/site-packages/spu/utils/simulation.py", line 151, in wrapper
executable, output = spu_fe.compile(
File "/opt/conda/envs/spu/lib/python3.10/site-packages/spu/utils/frontend.py", line 150, in compile
ir_text, output = _jax_compilation(
File "/opt/conda/envs/spu/lib/python3.10/site-packages/cachetools/__init__.py", line 739, in wrapper
cache[k] = v
File "/opt/conda/envs/spu/lib/python3.10/site-packages/cachetools/__init__.py", line 217, in __setitem__
cache_setitem(self, key, value)
File "/opt/conda/envs/spu/lib/python3.10/site-packages/cachetools/__init__.py", line 79, in __setitem__
self.popitem()
File "/opt/conda/envs/spu/lib/python3.10/site-packages/cachetools/__init__.py", line 227, in popitem
key = next(iter(self.__order))
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: OrderedDict mutated during iteration
It seems that SPU doesn't support concurrent execution at the moment. Is it easy to add this support (perhaps through replication)?
Hi @deevashwer
Before trying multiple SPU instances, there are some runtime configs you can try first, like these (a minimal example of setting them follows the list):
experimental_disable_mmul_split: True,
experimental_enable_inter_op_par: True,
experimental_enable_intra_op_par: True,
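For reference, a minimal sketch of setting these flags, assuming the RuntimeConfig/Simulator API used elsewhere in this thread (the protocol/field values here are placeholders, not a recommendation):

import spu
import spu.utils.simulation as pps

config = spu.RuntimeConfig(protocol=spu.ProtocolKind.REF2K, field=spu.FieldType.FM64)
# Experimental knobs; availability may vary across SPU versions.
config.experimental_disable_mmul_split = True   # keep large matmuls whole instead of splitting them
config.experimental_enable_inter_op_par = True  # run independent ops concurrently
config.experimental_enable_intra_op_par = True  # use multiple threads within a single op

sim = pps.Simulator(1, config)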
I tried these options, but unfortunately they don't lead to much improvement. It could be because I'm invoking SPU layer by layer on the network rather than on the whole model at once. To get the best parallelism, it would be nice to implement parallelism at the application level. Could you please help me implement multi-SPU-instance support?
Hi @deevashwer
For simulation this should be relatively easy. Just instantiating multiple simulators should be fine.
I tried that while ensuring that each thread had a local instance of the simulator, but I still encountered the two errors I mentioned above. It seems that there is some shared state among the different SPU instances which is accessed by the _jax_compilation function.
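For reference, both tracebacks above end inside the cachetools wrapper around _jax_compilation, and cachetools caches are not thread-safe unless a lock is supplied to the decorator. A minimal sketch of the thread-safe pattern (hypothetical names, not SPU's actual code):

import threading
import cachetools

_cache = cachetools.LRUCache(maxsize=64)
_lock = threading.Lock()

# Without lock=..., concurrent threads mutate the cache's internal ordering
# state, which matches the KeyError and "OrderedDict mutated during
# iteration" errors above.
@cachetools.cached(cache=_cache, lock=_lock)
def expensive_compile(key):
    return key * 2  # stands in for the expensive JAX-to-SPU compilation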
I'll take a look
Hi @deevashwer
I created a very simple example which seems to work fine.
Here is the code
import threading

import numpy as np

import spu.spu_pb2 as spu_pb2
import spu.utils.simulation as ppsim

if __name__ == "__main__":
    """
    You can modify the code below for debug purpose only.
    Please DONT commit it unless it will cause build break.
    """
    sim = ppsim.Simulator.simple(3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64)
    sim1 = ppsim.Simulator.simple(3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64)
    x = np.random.randn(3, 4)
    y = np.random.randn(3, 4)
    fn = lambda x, y: x + y
    fn1 = lambda x, y: x - y

    def run_sim(sim, fn, x, y):
        spu_fn = ppsim.sim_jax(sim, fn)
        z = spu_fn(x, y)
        print(f"spu out = {z}")

    # creating threads, one simulator per thread
    t1 = threading.Thread(target=run_sim, args=(sim, fn, x, y))
    t2 = threading.Thread(target=run_sim, args=(sim1, fn1, x, y))
    t1.start()
    t2.start()
    t1.join()
    t2.join()
    print("Done")
Yeah, the error has only shown up for me when there are many invocations of SPU (on the order of thousands), i.e., either many threads, or a few threads with many invocations per thread. Here's an example that reproduces the error:
import jax
from jax import random, numpy as jnp
from flax import linen as nn
import spu.utils.simulation as pps
import spu
import threading

protocol = spu.ProtocolKind.REF2K
field = spu.FieldType.FM64
config = spu.RuntimeConfig(protocol=protocol, field=field)

# pure function
def dense(params, x):
    return jax.lax.dot_general(x, params['params']['kernel'], (((x.ndim - 1,), (0,)), ((), ())),)

class LinearModel(nn.Module):
    features: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layer = nn.Dense(self.features, use_bias=False, dtype=self.dtype, param_dtype=self.dtype)

    def __call__(self, x):
        params = {"params": self.layer.variables['params']}
        y = self.layer.apply(params, x)
        local = threading.local()
        local.simulator = pps.Simulator(1, config)
        spu_dense = pps.sim_jax(local.simulator, dense)
        spu_y = spu_dense(params, x)
        spu_dense = pps.sim_jax(local.simulator, self.layer.apply)
        spu_y = spu_dense(params, x)
        return spu_y

inp_len = 10000
features = 500
batch = 400

def get_params(key, features):
    return {'params': {'layer': {'kernel': random.normal(key, (inp_len, features))}}}

def get_output(model, params, x, results, thread_idx):
    output = model.apply(params, x)
    results[thread_idx] = output

if __name__ == "__main__":
    models = [LinearModel(features=(features + i)) for i in range(batch)]
    key = random.PRNGKey(0)
    params = [get_params(key, features + i) for i in range(batch)]
    x = random.normal(key, (inp_len,))
    results = [None] * batch
    workers = [threading.Thread(target=get_output, args=(models[i], params[i], x, results, i)) for i in range(batch)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
There are smaller examples as well, but this is the one that consistently reproduces the error on my i9 MacBook Pro.
Hi @deevashwer
I can reproduce this on my end as well... Will investigate this.
Thanks for the repro :D