
Trusted Deep Learning Toolkit: distributed pytorch on k8s-sgx #4

Open gc-fu opened 1 year ago

gc-fu commented 1 year ago

As a PPML developer, I want to run distributed PyTorch on k8s with SGX support (using Gramine), in order to provide a secure environment for our customers to run trusted deep learning applications.

PyTorch modifications

| PyTorch modifications | Fix | Additional Info |
| --- | --- | --- |
| hipGetDeviceCount() raises an error when no GPU is available | Apply a patch to PyTorch to suppress the error: https://github.com/pytorch/pytorch/pull/80405 | This problem has been fixed in PyTorch 1.13.0 |

Gramine modifications

| Gramine modifications | Fix | Additional Info |
| --- | --- | --- |
| Returns an error when the MSG_MORE flag is used | Apply a patch to Gramine to ignore this flag | https://github.com/analytics-zoo/gramine/pull/6 |
| ioctl has problems | Apply a patch | https://github.com/analytics-zoo/gramine/pull/7 |
| Gramine cannot handle signals correctly | Needs a fix later | https://github.com/gramineproject/gramine/issues/1034 |

GLOO modifications

| GLOO modifications | Fix | Additional Info |
| --- | --- | --- |
| getifaddrs uses the NETLINK socket domain, which is unsupported in Gramine | Apply a patch to GLOO to read the network interface to use from an environment variable (see the sketch below) | https://github.com/analytics-zoo/gloo/pull/1 |
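
The patched Gloo picks the network interface from an environment variable. The issue does not name the variable, so the minimal sketch below uses `GLOO_SOCKET_IFNAME`, the variable PyTorch's Gloo backend already understands, purely as an illustration of the idea.

```python
# Minimal sketch: pin the Gloo backend to a single network interface chosen
# via an environment variable before initializing the process group.
# GLOO_SOCKET_IFNAME and the interface name "eth0" are illustrative
# assumptions, not taken from the analytics-zoo/gloo patch itself.
import os
import torch.distributed as dist

os.environ.setdefault("GLOO_SOCKET_IFNAME", "eth0")  # assumed interface name

# Expects MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE to be provided by the
# launcher (e.g. torchrun or the k8s job spec).
dist.init_process_group(backend="gloo", init_method="env://")
```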

Other modifications

| Other modifications | Fix | Additional Info |
| --- | --- | --- |
| Enable runtime domain configuration in Gramine | Add `sys.enable_extra_runtime_domain_names_conf = true` in the manifest | None |
| The `datasets` package provided by huggingface uses a flock file lock | Apply a patch to use an fcntl file lock instead (see the sketch after this table) | Gramine only supports fcntl file locks |
| pyarrow error caused by the aws version | Downgrade PyArrow to 6.0.1 | |
| Insufficient memory during training | Increase SGX_MEM_SIZE | This is closely related to the batch size: a larger batch size generally requires a larger SGX_MEM_SIZE. Once we have EDMM, this can be ignored |
| The default k8s CPU management policy is share | Change the CPU management policy to static according to here | A related issue |
| Change the topology manager policy | Change the CPU topology policy to best-effort or single-numa-node according to here | Ref |
| Disable Hyper-Threading on the servers | Disable Hyper-Threading on the server and configure kubelet accordingly | Hyper-Threading may cause security problems such as side-channel attacks, so it is not supported by Gramine |
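
Since Gramine only supports fcntl (POSIX) locks, the `datasets` patch replaces the flock-based file lock. The snippet below is not the actual patch, just a minimal sketch of what an fcntl-based lock looks like.

```python
# Minimal sketch of an fcntl (POSIX) file lock, the only lock type Gramine
# supports. This illustrates the flock -> fcntl switch described in the table
# above; it is not the actual huggingface/datasets patch.
import fcntl

class FcntlFileLock:
    def __init__(self, path: str):
        self._fd = open(path, "w")

    def __enter__(self):
        fcntl.lockf(self._fd, fcntl.LOCK_EX)   # blocking exclusive lock
        return self

    def __exit__(self, exc_type, exc, tb):
        fcntl.lockf(self._fd, fcntl.LOCK_UN)   # release, then close
        self._fd.close()

# Usage: serialize access to a shared cache file across worker processes.
with FcntlFileLock("/tmp/datasets_cache.lock"):
    pass  # critical section
```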

Hyper-threading: it seems that the use of Hyper-Threading also affects the performance of distributed training. In native mode, Kubernetes tries to allocate logical cores on the same physical cores. For instance, if the user requests 24 logical cores and each physical core has two threads, the 24 logical cores will be distributed onto 12 physical cores by default.

This behavior is described here

In SGX mode with the Gramine LibOS, Hyper-Threading appears to have no effect. We tried to allocate 12 logical cores on 6 physical cores, but as a result only 6 logical cores could function. A comparison can be seen in the following two figures:

[Two figures comparing logical-core utilization in native mode and in SGX mode]
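
A quick way to reproduce this observation (not from the issue, just a simple diagnostic) is to query the scheduler affinity from inside the workload:

```python
# Diagnostic sketch: report how many logical cores this process can be
# scheduled onto. Running it natively and then under Gramine-SGX is one way
# to observe the "12 logical cores requested, only 6 usable" behavior.
import os

usable = os.sched_getaffinity(0)          # CPUs this process may run on (Linux)
print(f"online logical cores : {os.cpu_count()}")
print(f"usable logical cores : {len(usable)} -> {sorted(usable)}")
```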

Enable TCP_TLS during computation

Check here

Optimization

Intel-openmp

Intel-openmp is currently not supported in SGX mode. The related error:

openat(AT_FDCWD, "/dev/shm/__KMP_REGISTERED_LIB_1_0", O_RDWR|O_CREAT|O_EXCL|0xa0000, 0666) = -2

Related Gramine PR: https://github.com/gramineproject/gramine/pull/827

Gramine-patched openmp

As recommended here, this patched OpenMP can bring better performance inside SGX enclaves.

However, after setting this library in LD_PRELOAD, PyTorch training in native mode hits a segmentation fault.

Besides, we have to set this LD_PRELOAD environment variable in bash.manifest.template, which means we cannot change it in the image after the image is built.

Intel PyTorch extension

IPEX has almost no acceleration effect on PERT training, and in SGX mode it causes errors due to fork:

bash: warning: setlocale: LC_ALL: cannot change locale (C.UTF-8)
 Illegal instruction (core dumped)
/usr/lib/python3/dist-packages/requests/__init__.py:91: RequestsDependencyWarning: urllib3 (1.26.12) or chardet (3.0.4) doesn't match a supported version!
  RequestsDependencyWarning)
Traceback (most recent call last):
  File "/ppml/examples/pert_ipex.py", line 12, in <module>
    import intel_extension_for_pytorch as ipex
  File "/usr/local/lib/python3.7/dist-packages/intel_extension_for_pytorch/__init__.py", line 24, in <module>
    from . import cpu
  File "/usr/local/lib/python3.7/dist-packages/intel_extension_for_pytorch/cpu/__init__.py", line 2, in <module>
    from . import runtime
  File "/usr/local/lib/python3.7/dist-packages/intel_extension_for_pytorch/cpu/runtime/__init__.py", line 3, in <module>
    from .multi_stream import MultiStreamModule, get_default_num_streams, \
  File "/usr/local/lib/python3.7/dist-packages/intel_extension_for_pytorch/cpu/runtime/multi_stream.py", line 43, in <module>
    class MultiStreamModule(nn.Module):
  File "/usr/local/lib/python3.7/dist-packages/intel_extension_for_pytorch/cpu/runtime/multi_stream.py", line 90, in MultiStreamModule
    cpu_pool: CPUPool = CPUPool(node_id=0),
  File "/usr/local/lib/python3.7/dist-packages/intel_extension_for_pytorch/cpu/runtime/cpupool.py", line 32, in __init__
    self.core_ids = get_core_list_of_node_id(node_id)
  File "/usr/local/lib/python3.7/dist-packages/intel_extension_for_pytorch/cpu/runtime/runtime_utils.py", line 20, in get_core_list_of_node_id
    num_of_nodes = get_num_nodes()
  File "/usr/local/lib/python3.7/dist-packages/intel_extension_for_pytorch/cpu/runtime/runtime_utils.py", line 4, in get_num_nodes
    return int(subprocess.check_output('lscpu | grep Socket | awk \'{print $2}\'', shell=True))
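
The last frame above boils down to a shell invocation of lscpu, which requires spawning a child process; per the fork problem noted above, this is what breaks under SGX. A minimal repro of the same call:

```python
# Minimal repro of IPEX's get_num_nodes() as shown in the traceback: it
# shells out to lscpu to count CPU sockets. Spawning the child process
# requires fork/exec, which is what fails in this SGX setup, so importing
# intel_extension_for_pytorch crashes before any training code runs.
import subprocess

num_sockets = int(
    subprocess.check_output("lscpu | grep Socket | awk '{print $2}'", shell=True)
)
print(num_sockets)
```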

jemalloc

jemalloc can bring better performance in native mode.

In SGX mode, however, it causes the training speed to slow down gradually, eventually becoming slower than not using jemalloc at all.

After applying jemalloc and intel-openmp in native mode, the execution time for 4 nodes dropped from 5450s to 4125s.
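
For reference, jemalloc and Intel OpenMP are typically enabled by preloading their shared libraries when the trainer starts. The launcher below is only a sketch under that assumption; the library paths and the `train.py` entry point are hypothetical placeholders, not paths from this image.

```python
# Sketch of how jemalloc and intel-openmp are commonly enabled in native
# mode: preload their shared libraries when launching the training script.
# The .so paths and "train.py" are hypothetical placeholders.
import os
import subprocess
import sys

env = dict(os.environ)
env["LD_PRELOAD"] = ":".join([
    "/usr/lib/x86_64-linux-gnu/libjemalloc.so",   # hypothetical jemalloc path
    "/opt/intel/lib/libiomp5.so",                 # hypothetical intel-openmp path
])

subprocess.run([sys.executable, "train.py", *sys.argv[1:]], env=env, check=True)
```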

Test cases and test data

Test case 0

No optimizations. Environment: docker, native mode, TLS enabled. Baseline: none. 4-node distributed training, 60k samples.

5450.44s 10.5797

Test case 1

Test the performance improvement from the jemalloc and intel-omp environment variable settings. Environment: docker, native mode, TLS enabled. Baseline: results from test case 0. 4-node distributed training, 60k samples.

4125.75s 13.9611

Test case 2

Test the performance improvement from jemalloc alone. Environment: docker, SGX mode, TLS enabled. Baseline: performance data previously measured on k8s. 4-node distributed training, 60k samples.

7115.37s 8.1436

Performance degradation appeared: training becomes progressively slower. One training epoch on k8s takes roughly 8500s.

Test case 3

Test the performance improvement from jemalloc plus the Gramine-patched OpenMP. Environment: docker, SGX mode, TLS enabled. Baseline: performance data from test case 2. 4-node distributed training, 60k samples.

7015.12s 8.24. The same performance degradation as in test case 2: training becomes progressively slower.

Test case 4

Use the OpenMP parameters `export OMP_SCHEDULE=STATIC` and `export OMP_PROC_BIND=CLOSE`. Environment: docker, SGX mode, TLS enabled. Baseline: k8s test data. 4-node distributed training, 60k samples.

8520.7s 6.76
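
The same two OpenMP knobs can also be set from Python, as long as they are exported before the OpenMP runtime initializes (that is, before importing torch). A small sketch:

```python
# The two OpenMP settings from test case 4, set programmatically. They must
# be in the environment before the OpenMP runtime initializes, so set them
# before importing torch.
import os

os.environ["OMP_SCHEDULE"] = "STATIC"   # static loop scheduling
os.environ["OMP_PROC_BIND"] = "CLOSE"   # bind threads close to the master thread

import torch  # noqa: E402  (imported after the env setup on purpose)
```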

Test case 5

Use gradient accumulation: perform one all_reduce after every four batches. Environment: docker, SGX mode, TLS enabled. Baseline: k8s test data. 4-node distributed training, 60k samples.

6749.17s 8.53467
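
A minimal sketch of the gradient-accumulation scheme in test case 5: skip DDP's gradient synchronization for three micro-batches using `no_sync()` and let the all_reduce happen on the fourth. `ddp_model`, `loader`, `optimizer` and `loss_fn` are placeholders, not code from this issue.

```python
# Gradient accumulation as in test case 5: local accumulation for four
# micro-batches, one all_reduce (the normal DDP sync) on the fourth.
import contextlib

ACCUM_STEPS = 4

def train_epoch(ddp_model, loader, optimizer, loss_fn):
    ddp_model.train()
    optimizer.zero_grad()
    for step, (inputs, labels) in enumerate(loader):
        last_micro_batch = (step + 1) % ACCUM_STEPS == 0
        # no_sync() skips DDP's gradient all_reduce for this backward pass;
        # on the last micro-batch the all_reduce happens as usual.
        ctx = contextlib.nullcontext() if last_micro_batch else ddp_model.no_sync()
        with ctx:
            loss = loss_fn(ddp_model(inputs), labels) / ACCUM_STEPS
            loss.backward()
        if last_micro_batch:
            optimizer.step()
            optimizer.zero_grad()
```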

Perf

The performance data can be acquired here

jason-dai commented 1 year ago

Have you reproduced the performance result of https://github.com/analytics-zoo/bytedance-pytorch-examples/pull/6/files#diff-b335630551682c19a781afebcf4d07bf978fb1f8ac04c6bf87428ed5106870f5R269 in native mode?

gc-fu commented 1 year ago

> Have you reproduced the performance result of https://github.com/analytics-zoo/bytedance-pytorch-examples/pull/6/files#diff-b335630551682c19a781afebcf4d07bf978fb1f8ac04c6bf87428ed5106870f5R269 in native mode?

I did this experiment before using the same settings, but didn't get the same speedup ratio on the ICX machine (the speedup ratio was lower than the data listed there), probably due to different CPU settings. I will redo this experiment today and add the detailed data to this issue.

gc-fu commented 1 year ago

> Have you reproduced the performance result of https://github.com/analytics-zoo/bytedance-pytorch-examples/pull/6/files#diff-b335630551682c19a781afebcf4d07bf978fb1f8ac04c6bf87428ed5106870f5R269 in native mode?

I have tried this experiment again. It turns out that Yang's result was acquired using 80 physical cores, and I have reproduced the "1 Node (local) optimized" result in a container on an ICX machine. The original data listed in the table was acquired from an experiment using only 13 physical cores, which explains why I did not get the same speedup in my experiment.

An interesting phenomenon is that if you use a large number of CPU cores without any optimization options for either OpenMP or Intel-OpenMP, training runs slower than with a small number of CPU cores (also without any optimization options).

Therefore, I think the correct comparison for calculating the speedup ratio is OpenMP with its corresponding optimization options versus Intel-OpenMP and jemalloc together with the optimization options from bigdl-nano.

I will provide detailed data later in this repo.