filippo-merlo opened this issue 1 month ago
Hi @filippo-merlo,
The issue you have comes from confusing two separate GPU modes:

- torch GPU acceleration: `tensor.to("cuda")`
- FHE GPU acceleration: `compile(..., device="cuda")`

For now, torch GPU acceleration is not fully supported in the LoRA use case, since the bottleneck is the FHE computation rather than the cleartext CPU computation. That said, there might be cases where it's useful, and it should be fairly easy to support if you are ever interested in opening a PR for it.
About FHE GPU acceleration: the LoRA fine-tuning use case is about learning private weights, with the base model parameters on a third-party server. To do this, only the linear layers need to run remotely in FHE, and those parts are not yet GPU-accelerated. For now, GPU acceleration is more useful for end-to-end FHE computations that include non-linear parts.
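To make the split concrete, here is a toy sketch in plain Python (no Concrete ML or torch involved). It assumes the standard LoRA formulation y = W x + scale * B (A x), where the frozen base weight W stays on the server side and the low-rank matrices A and B are the private weights being learned. `remote_linear` is a hypothetical cleartext stand-in for the server-side FHE linear layer, not a Concrete ML function.

```python
"""Toy LoRA forward pass showing which part would run remotely in FHE.

Assumption: only the frozen base linear layer W @ x is outsourced (in FHE);
the private low-rank update B @ (A @ x) is computed locally in the clear.
"""


def matvec(matrix, vector):
    """Multiply a matrix given as a list of rows by a vector (list of numbers)."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def remote_linear(base_weight, x):
    """Hypothetical stand-in for the server-side FHE linear layer (cleartext here)."""
    return matvec(base_weight, x)


def lora_forward(base_weight, lora_a, lora_b, x, scale=1.0):
    """Compute y = W x + scale * B (A x)."""
    base_out = remote_linear(base_weight, x)   # remote: frozen base weights
    delta = matvec(lora_b, matvec(lora_a, x))  # local: private LoRA weights
    return [b + scale * d for b, d in zip(base_out, delta)]


# 2x2 identity base weight, rank-1 adapter (A is 1x2, B is 2x1).
W = [[1, 0], [0, 1]]
A = [[1, 1]]
B = [[1], [2]]
y = lora_forward(W, A, B, [3, 4])
print(y)  # [10.0, 18.0]
```

Only `remote_linear` touches the base weights; everything involving A and B stays local, which matches the reply's point that only the linear layers need to go through FHE.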
When I try to use `compile_model` with CUDA as the specified device, I encounter the following error. Is there a way to resolve this, or is the `lora.py` code not yet compatible with running on a GPU?

The tutorial I am following: https://github.com/zama-ai/concrete-ml/tree/release/1.7.x/use_case_examples/lora_finetuning
```
Traceback (most recent call last):
  File "/lora_finetuning/lorafinetunegpt2.py", line 135, in <module>
    hybrid_model.compile_model(inputset, n_bits=16, device="cuda")
  File "/src/concrete/ml/torch/hybrid_model.py", line 516, in compile_model
    self.private_q_modules[name] = compile_torch_model(
                                   ^^^^^^^^^^^^^^^^^^^^
  File "/src/concrete/ml/torch/compile.py", line 342, in compile_torch_model
    return _compile_torch_or_onnx_model(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/src/concrete/ml/torch/compile.py", line 224, in _compile_torch_or_onnx_model
    quantized_module = build_quantized_module(
                       ^^^^^^^^^^^^^^^^^^^^^^^
  File "/src/concrete/ml/torch/compile.py", line 124, in build_quantized_module
    numpy_model = NumpyModule(model, dummy_input_for_tracing)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/src/concrete/ml/torch/numpy_module.py", line 51, in __init__
    ) = get_equivalent_numpy_forward_from_torch(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/src/concrete/ml/onnx/convert.py", line 153, in get_equivalent_numpy_forward_from_torch
    torch.onnx.export(
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/onnx/utils.py", line 1612, in _export
    graph, params_dict, torch_out = _model_to_graph(
                                    ^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/onnx/utils.py", line 1134, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/onnx/utils.py", line 1010, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/onnx/utils.py", line 914, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/jit/_trace.py", line 1310, in _get_trace_graph
    outs = ONNXTracedModule(
           ^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/jit/_trace.py", line 138, in forward
    graph, out = torch._C._create_graph_by_tracing(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/jit/_trace.py", line 129, in wrapper
    outs.append(self.inner(*trace_inputs))
                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
```
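The error is raised inside `F.linear` during ONNX tracing: the layer's parameters and its input did not live on the same device (one on `cuda:0`, one on the CPU). A small pre-flight check can surface this before `compile_model` starts tracing. The sketch below uses a hypothetical helper, `device_mismatch`, which is not part of Concrete ML; it only relies on each object exposing `.device.type`, as `torch.Tensor` does, so it is demonstrated with stand-in objects instead of real tensors.

```python
"""Pre-flight device check, sketched without importing torch.

`device_mismatch` is a hypothetical helper, not a Concrete ML API. It only
needs objects exposing `.device.type` ("cpu", "cuda", ...), as torch tensors
and parameters do.
"""
from types import SimpleNamespace


def device_mismatch(parameters, inputs):
    """Return (parameter_devices, input_devices) if they differ, else None."""
    param_devices = {p.device.type for p in parameters}
    input_devices = {t.device.type for t in inputs}
    if len(param_devices | input_devices) > 1:
        return sorted(param_devices), sorted(input_devices)
    return None


# Stand-ins for tensors; with torch you would pass e.g. a module's
# .parameters() and the calibration inputset instead.
cpu_tensor = SimpleNamespace(device=SimpleNamespace(type="cpu"))
cuda_tensor = SimpleNamespace(device=SimpleNamespace(type="cuda"))

print(device_mismatch([cpu_tensor], [cpu_tensor]))                # None
print(device_mismatch([cpu_tensor], [cuda_tensor, cuda_tensor]))  # (['cpu'], ['cuda'])
```

With the snippet from this issue, and assuming `lora_training` is a regular `torch.nn.Module` left on the CPU, `device_mismatch(lora_training.parameters(), inputset)` should flag the mismatch before compilation.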
This is the only part of the code I modified:
```python
# Create the HybridFHEModel with the specified remote modules
hybrid_model = HybridFHEModel(lora_training, module_names=remote_names)

# Prepare input data for calibration
input_tensor = torch.randint(0, 2, (PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE)) * (tokenizer.vocab_size - 1)
label_tensor = torch.randint(0, 2, (PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE)) * (tokenizer.vocab_size - 1)
input_tensor = input_tensor.to("cuda")
label_tensor = label_tensor.to("cuda")

inputset = (input_tensor, label_tensor)

# Calibrate and compile the model
hybrid_model.model.toggle_calibrate(enable=True)
hybrid_model.compile_model(inputset, n_bits=16, device="cuda")
```
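Following the reply at the top of this thread (torch GPU acceleration is not supported in the LoRA use case yet, and the remote linear layers are not GPU-accelerated in FHE either), one workaround is to keep the calibration tensors on the CPU, i.e. simply drop the two `.to("cuda")` lines. If the tensors already live on the GPU, a helper like the hypothetical `inputset_to_cpu` below moves them back. It only relies on a `.to(device)` method, as `torch.Tensor` has; `FakeTensor` is a stand-in so the sketch runs without torch.

```python
"""Move a calibration inputset back to the CPU before compiling (sketch)."""
from types import SimpleNamespace


def inputset_to_cpu(inputset):
    """Return a tuple with every tensor of `inputset` moved to the CPU."""
    return tuple(t.to("cpu") for t in inputset)


class FakeTensor:
    """Minimal stand-in for torch.Tensor: only `.device.type` and `.to()`."""

    def __init__(self, device_type):
        self.device = SimpleNamespace(type=device_type)

    def to(self, device):
        # torch accepts strings like "cuda:0"; keep only the device type here.
        return FakeTensor(str(device).split(":")[0])


inputset = (FakeTensor("cuda"), FakeTensor("cuda"))
inputset = inputset_to_cpu(inputset)
print([t.device.type for t in inputset])  # ['cpu', 'cpu']

# In the issue's script this would become (not executed here; compiling for
# the CPU is an assumption based on the maintainer's reply):
#   inputset = inputset_to_cpu(inputset)
#   hybrid_model.model.toggle_calibrate(enable=True)
#   hybrid_model.compile_model(inputset, n_bits=16, device="cpu")
```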