zama-ai / concrete-ml

Concrete ML: Privacy Preserving ML framework using Fully Homomorphic Encryption (FHE), built on top of Concrete, with bindings to traditional ML frameworks.

Is there a way to use GPU acceleration in the GPT-2 LoRA fine-tuning use case example? #918

Open · filippo-merlo opened this issue 1 month ago

filippo-merlo commented 1 month ago

When I try to use compile_model with CUDA as the specified device, I encounter the following error. Is there a way to resolve this, or is the lora.py code not yet compatible with running on a GPU?

The tutorial I am following: https://github.com/zama-ai/concrete-ml/tree/release/1.7.x/use_case_examples/lora_finetuning

```
Traceback (most recent call last):
  File "/lora_finetuning/lorafinetunegpt2.py", line 135, in <module>
    hybrid_model.compile_model(inputset, n_bits=16, device="cuda")
  File "/src/concrete/ml/torch/hybrid_model.py", line 516, in compile_model
    self.private_q_modules[name] = compile_torch_model(
  File "/src/concrete/ml/torch/compile.py", line 342, in compile_torch_model
    return _compile_torch_or_onnx_model(
  File "/src/concrete/ml/torch/compile.py", line 224, in _compile_torch_or_onnx_model
    quantized_module = build_quantized_module(
  File "/src/concrete/ml/torch/compile.py", line 124, in build_quantized_module
    numpy_model = NumpyModule(model, dummy_input_for_tracing)
  File "/src/concrete/ml/torch/numpy_module.py", line 51, in __init__
    ) = get_equivalent_numpy_forward_from_torch(
  File "/src/concrete/ml/onnx/convert.py", line 153, in get_equivalent_numpy_forward_from_torch
    torch.onnx.export(
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/onnx/utils.py", line 1612, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/onnx/utils.py", line 1134, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/onnx/utils.py", line 1010, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/onnx/utils.py", line 914, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/jit/_trace.py", line 1310, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/jit/_trace.py", line 138, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/jit/_trace.py", line 129, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
```

This is the only part of the code I modified:

```python
# Create the HybridFHEModel with the specified remote modules
hybrid_model = HybridFHEModel(lora_training, module_names=remote_names)

# Prepare input data for calibration
input_tensor = torch.randint(0, 2, (PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE)) * (tokenizer.vocab_size - 1)
label_tensor = torch.randint(0, 2, (PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE)) * (tokenizer.vocab_size - 1)

input_tensor = input_tensor.to("cuda")
label_tensor = label_tensor.to("cuda")

inputset = (input_tensor, label_tensor)

# Calibrate and compile the model
hybrid_model.model.toggle_calibrate(enable=True)
hybrid_model.compile_model(inputset, n_bits=16, device="cuda")
```

jfrery commented 1 month ago

Hi @filippo-merlo,

The issue you're seeing comes from mixing up two separate GPU modes (see the sketch after this list):

  1. Standard torch GPU acceleration (triggered by tensor.to("cuda"))
  2. FHE GPU acceleration (activated by compile(..., device="cuda"))
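Concretely, the traceback above comes from mode 1: the `.to("cuda")` calls put the calibration tensors on the GPU while the model being traced through `torch.onnx` stays on the CPU. A minimal sketch of a working configuration (untested, reusing the variable names from the snippet above and reverting to the tutorial's CPU setup):

```python
# Keep the calibration tensors on CPU: compile_model traces the model with
# torch.onnx, and mixing CPU weights with CUDA inputs raises the
# "Expected all tensors to be on the same device" error seen above.
input_tensor = torch.randint(0, 2, (PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE)) * (tokenizer.vocab_size - 1)
label_tensor = torch.randint(0, 2, (PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE)) * (tokenizer.vocab_size - 1)
inputset = (input_tensor, label_tensor)  # no .to("cuda")

hybrid_model.model.toggle_calibrate(enable=True)
# Compile on the default (CPU) device; as explained below, the remote linear
# layers are not yet GPU-accelerated in FHE anyway.
hybrid_model.compile_model(inputset, n_bits=16)
```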

For now, torch GPU acceleration is not fully supported in the LoRA use case: the bottleneck is the FHE computation rather than the cleartext CPU computation. That said, there may be cases where it is useful, and it would be fairly easy to support if you are interested in opening a PR for it; a rough sketch of the missing piece follows.
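Supporting torch GPU acceleration in the hybrid model would mostly be device bookkeeping at the boundary between the local (GPU) sub-modules and the remote/FHE ones, which execute on CPU. A hypothetical sketch, not an existing Concrete ML helper:

```python
import torch

def move_to_module_device(module: torch.nn.Module, *tensors: torch.Tensor) -> tuple:
    """Hypothetical helper: move activations onto the device of a module's
    parameters, so GPU (local) and CPU (remote/FHE) sub-modules can be mixed
    in a single hybrid forward pass."""
    device = next(module.parameters()).device
    return tuple(t.to(device) for t in tensors)
```

Calling something like this before each sub-module's forward would let the cleartext parts run on GPU while still handing CPU tensors to the FHE parts.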

Regarding FHE GPU acceleration: the LoRA fine-tuning use case is about learning private weights while the base model parameters stay on a third-party server. For this, only the linear layers need to run remotely under FHE, and those are not yet GPU-accelerated. For now, GPU acceleration pays off mainly for end-to-end FHE computations that include non-linear parts.
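For contrast, here is a minimal sketch of the end-to-end case where the FHE `device="cuda"` flag is meant to help: compiling a small model with a non-linear layer fully to FHE (check the exact `compile_torch_model` signature against your Concrete ML version):

```python
import torch
from concrete.ml.torch.compile import compile_torch_model

# A small model with a non-linear activation, the part where FHE GPU
# acceleration is currently most useful.
model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.ReLU(),
    torch.nn.Linear(5, 2),
)
inputset = torch.randn(100, 10)  # calibration data stays on CPU

# device="cuda" selects the FHE execution backend; it does not move torch
# tensors, so no .to("cuda") is needed here.
quantized_module = compile_torch_model(model, inputset, n_bits=6, device="cuda")
```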