google / fully-homomorphic-encryption

An FHE compiler for C++
Apache License 2.0

Performance issue about the jaxite #69

Closed dasistwo closed 11 months ago

dasistwo commented 1 year ago

Hi all! I was trying to compare the performance of tfhe-rs and Jaxite, expecting Jaxite to be much faster since it can exploit the GPU, but I found Jaxite to be far slower than tfhe-rs. I want to know whether my configuration is wrong or Jaxite is simply not fully developed yet.

I tested the transpilers for both Jaxite and tfhe-rs, using the hello_world example. I did not use bazel run for the Jaxite test, since bazel run cannot initialize CUDA. (It seems the GPU / TPU tests are not publicly exposed through Bazel, as far as I could tell from checking here.) Instead, I ran the testbench directly with Python.
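To rule out JAX quietly falling back to the CPU when launched this way, a quick check like the sketch below (not part of the testbench; it only uses standard jax queries) shows which backend and devices were picked up:

"""Minimal sketch: confirm which JAX backend and devices are active."""
import jax

# If the CUDA plugin is visible, default_backend() reports the GPU backend and
# devices() lists the V100; otherwise JAX silently falls back to the CPU.
print("default backend:", jax.default_backend())
print("devices:", jax.devices())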

Jaxite takes about 10,000 seconds per evaluation, and the run failed after the first iteration, while tfhe-rs takes about 30 seconds.

user@gpu05:/home/user/fully-homomorphic-encryption$ python3 transpiler/tensorflow/examples/hello_world/hello_world_testbench_python.py
Generating keys
I0909 17:04:13.349542 139736576647680 xla_bridge.py:622] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA
I0909 17:04:13.350527 139736576647680 xla_bridge.py:622] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
Quantized x =  [-128  -42   43  127]
Running FHE circuit
FHE circuit took 9849.094140 seconds
f(-128) = 79
Traceback (most recent call last):
  File "/home/user/.local/lib/python3.10/site-packages/jax/_src/api_util.py", line 581, in shaped_abstractify
    return _shaped_abstractify_handlers[type(x)](x)
KeyError: <class 'jax.Array'>

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/user/fully-homomorphic-encryption/transpiler/tensorflow/examples/hello_world/hello_world_testbench_python.py", line 98, in <module>
    app.run(main)
  File "/home/user/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/user/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/user/fully-homomorphic-encryption/transpiler/tensorflow/examples/hello_world/hello_world_testbench_python.py", line 93, in main
    quantized_result = jnp.append(quantized_result, result_cleartext)
  File "/home/user/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 253, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 161, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/home/user/.local/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 477, in common_infer_params
    avals.append(shaped_abstractify(a))
  File "/home/user/.local/lib/python3.10/site-packages/jax/_src/api_util.py", line 583, in shaped_abstractify
    return _shaped_abstractify_slow(x)
  File "/home/user/.local/lib/python3.10/site-packages/jax/_src/api_util.py", line 572, in _shaped_abstractify_slow
    raise TypeError(
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Cannot interpret value of type <class 'jax.Array'> as an abstract array; it does not have a dtype attribute

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/user/fully-homomorphic-encryption/transpiler/tensorflow/examples/hello_world/hello_world_testbench_python.py", line 98, in <module>
    app.run(main)
  File "/home/user/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/user/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/user/fully-homomorphic-encryption/transpiler/tensorflow/examples/hello_world/hello_world_testbench_python.py", line 93, in main
    quantized_result = jnp.append(quantized_result, result_cleartext)
TypeError: Cannot interpret value of type <class 'jax.Array'> as an abstract array; it does not have a dtype attribute
user@gpu05:/home/user/fully-homomorphic-encryption$ ./bazel-bin/transpiler/tensorflow/examples/hello_world/./hello_world_testbench
Inferring sine for 0
FHE computation in 33 s
Sine value: -0.01
Inferring sine for 1.5707964
FHE computation in 33 s
Sine value: 0.99
Inferring sine for 3.1415927
FHE computation in 32 s
Sine value: -0.01
Inferring sine for 4.712389
FHE computation in 32 s
Sine value: -1.09
Inferring sine for 6.2831855
FHE computation in 32 s
Sine value: -0.12

Both programs were generated from the same netlist file, i.e. they go through the same pipeline and only differ in the transpiler invoked at the very end.

heir-opt --heir-tosa-to-arith ${INPUT_TOSA} | tee >(heir-translate --emit-metadata -o ${OUTPUT_METADATA}) |  heir-translate --emit-verilog -o ${OUTPUT_VERILOG}
YOSYS_SCRIPT="read_verilog ${OUTPUT_VERILOG}; hierarchy -check -top main; techmap; opt; splitnets -ports for_*; abc -lut 3; opt_clean -purge; techmap -map ${LUTMAP_SCRIPT}; opt_clean -purge; flatten; hierarchy -generate lut3 o:Y i:P* i:A i:B i:C; opt_expr; opt; opt_clean -purge; rename -hide */w:*; rename -enumerate */w:*; rename -top ${MODEL_NAME}; clean; write_verilog -noattr ${OUTPUT_NETLIST}"
yosys -p "${YOSYS_SCRIPT}"
# here transpiler is the exported tfhe-rs transpiler binary
transpiler --ir_path ${OUTPUT_NETLIST} --liberty_path ${LIBERTY_CELLS} --heir_metadata_path ${OUTPUT_METADATA} --parallelism=0 --rs_out ${OUTPUT_RUST}
# here transpiler is the exported jaxite transpiler binary
transpiler --ir_path  ${OUTPUT_NETLIST} --optimizer=yosys --liberty_path ${LIBERTY_CELLS} --metadata_path ${OUTPUT_METADATA} --parallelism=0 --py_out ${OUTPUT_PY}

This is the code I used as a testbench for Jaxite.

"""A jaxite testbench for hello_world tensorflow code."""

from collections.abc import Sequence
import functools

from absl import app
from jaxite.jaxite_bool import bool_params
from jaxite.jaxite_bool import jaxite_bool
from jax import Array as ndarray
from jax import numpy as jnp
import timeit

from transpiler.tensorflow.examples.hello_world import hello_world_fhe_lib_python

def bit_slice_to_int(bit_slice: list[bool]) -> int:
  """Given an list of bits, return a base-10 integer."""
  result = 0
  for i, bit in enumerate(bit_slice):
    result |= int(bit) << i
  return result

def int_to_bit_slice(input_int: int) -> list[bool]:
  """Given an integer and bit width, return a bitwise representation."""
  result: list[bool] = [False] * 8
  for i in range(8):
    result[i] = ((input_int >> i) & 1) != 0
  return result

def quantize(arr: ndarray) -> ndarray:
    """
    Quantize an array of jnp.float32 to jnp.int8.

    Args:
    - arr (jnp.ndarray): Input array of jnp.float32.

    Returns:
    - jnp.ndarray: Quantized array of jnp.int8.
    """
    return ((arr / 0.024480115622282) - 128.0).astype(jnp.int8)

def dequantize(arr: ndarray) -> ndarray:
    """
    Dequantize an array of jnp.int8 to jnp.float32.

    Args:
    - arr (jnp.ndarray): Input array of jnp.int8.

    Returns:
    - jnp.ndarray: Dequantized array of jnp.float32.
    """
    return ((arr.astype(jnp.float32) - 5) * 0.00829095672816038)

@functools.cache
def setup():
  print(f'Generating keys')
  boolean_params = bool_params.get_params_for_128_bit_security()
  lwe_rng = bool_params.get_lwe_rng_for_128_bit_security(1)
  rlwe_rng = bool_params.get_rlwe_rng_for_128_bit_security(1)
    # lwe_dimension=800,
    # rlwe_dimension=2,
    # plaintext_modulus=2^32,
    # polynomial_modulus_degree=512,
    # bsk log_base=4, level_count=6
    # ksk log_base=4, level_count=5
  cks = jaxite_bool.ClientKeySet(boolean_params, lwe_rng, rlwe_rng)
  sks = jaxite_bool.ServerKeySet(cks, boolean_params, lwe_rng, rlwe_rng)
  return (boolean_params, lwe_rng, cks, sks)

def main(argv: Sequence[str]) -> None:
  del argv
  (boolean_params, lwe_rng, cks, sks) = setup()
  pi = 3.14159265358979323846
  x_vals = jnp.float32(jnp.linspace(0, 2.0*pi, 4))
  quantized_x_vals = quantize(x_vals)
  quantized_result = ndarray() # type: ignore
  print("Quantized x = ", quantized_x_vals)
  for x in quantized_x_vals:
    x_cleartext = int_to_bit_slice(x)
    x_ciphertext = [jaxite_bool.encrypt(z, cks, lwe_rng) for z in x_cleartext]
    print('Running FHE circuit')
    start = timeit.default_timer()
    result_ciphertext = hello_world_fhe_lib_python.hello_world(
        x_ciphertext,
        sks,
        boolean_params,
    )
    end = timeit.default_timer()
    print(f'FHE circuit took {end - start:1f} seconds')
    result_ciphertext = [jaxite_bool.decrypt(z, cks) for z in result_ciphertext]
    result_cleartext = bit_slice_to_int(result_ciphertext)
    print(f'f({x}) = {result_cleartext}')
    quantized_result = jnp.append(quantized_result, result_cleartext)

  result = dequantize(quantized_result)

if __name__ == '__main__':
  app.run(main)
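As a side note, the TypeError in the trace above appears to come from initializing quantized_result with ndarray(): jax.Array cannot be instantiated directly, so the resulting object has no dtype attribute and jnp.append rejects it. A minimal sketch of the workaround, kept separate from the listing above; the int32 dtype is an assumption:

"""Sketch of the jnp.append fix; the int32 dtype is an assumption."""
from jax import numpy as jnp

# Start from an empty concrete array instead of jax.Array(), which cannot be
# instantiated directly and has no dtype attribute.
quantized_result = jnp.array([], dtype=jnp.int32)
for result_cleartext in (79,):  # 79 is the first decrypted value in the log above
  quantized_result = jnp.append(quantized_result, result_cleartext)
print(quantized_result)  # [79]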

This is the modified BUILD file used to create the py_library.

# Hello World example

load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
load("@rules_rust//rust:defs.bzl", "rust_binary", "rust_library")
load("@rules_python//python:defs.bzl", "py_binary", "py_library")

package(
    default_applicable_licenses = ["@com_google_fully_homomorphic_encryption//:license"],
    default_visibility = ["//visibility:public"],
)

licenses(["notice"])

rust_library(
    name = "hello_world_fhe_lib_rust",
    srcs = ["hello_world_fhe_lib_rust.rs"],
    disable_pipelining = True,
    tags = [
        "manual",
        "notap",
    ],
    deps = [
        "@crate_index//:rayon",
        "@crate_index//:tfhe",
    ],
    rustc_flags = ["--cfg", "lut"]
)

rust_binary(
    name = "hello_world_testbench_rust",
    srcs = [
        "hello_world_testbench_rust.rs",
    ],
    tags = [
        "manual",
        "notap",
    ],
    deps = [
        ":hello_world_fhe_lib_rust",
        "@crate_index//:rayon",
        "@crate_index//:tfhe",
    ],
)

py_library(
    name = "hello_world_fhe_lib_python",
    srcs = ["hello_world_fhe_lib_python.py"],
    tags = [
        "manual",
        "notap",
    ],
    deps = [
        "@transpiler_pip_deps//pypi__jaxite",
    ],
)

py_binary(
    name = "hello_world_testbench_python",
    srcs = [
        "hello_world_testbench_python.py",
    ],
    tags = [
        "manual",
        "notap",
    ],
    deps = [
        ":hello_world_fhe_lib_python",
        "@com_google_absl_py//absl:app",
        "@transpiler_pip_deps//pypi__jaxite",
    ],
)

I'm using Python 3.10.13, an Nvidia V100 GPU, and CUDA 11.8. Please tell me if my testbench or configuration is wrong.
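In case it helps, the exact package versions behind this report can be dumped with a small snippet like the sketch below (the distribution names queried via importlib.metadata are assumptions about how the packages were installed):

"""Sketch: dump interpreter and package versions for this report."""
import sys
from importlib import metadata

print(sys.version)
for pkg in ("jax", "jaxlib", "jaxite"):
  # metadata.version() reads the installed distribution's version string.
  print(pkg, metadata.version(pkg))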

j2kun commented 1 year ago

Jaxite is not yet performant, and most of our effort so far has gone into making it performant on TPU architectures. Even there it does not have performance parity with CPU-parallel tfhe-rs, though we're working on improving it, so for now you'll see very poor performance. While improvements on the TPU side should carry over to GPU, we're not particularly focused on GPU optimizations at the moment.