Open dawu415 opened 2 months ago
To permit loading models at a different precision at load time for sentence_transformers, follow these steps:

1. Extend the SentenceTransformerPatched class initialization to include torch_dtype in model_kwargs, based on engine_args.dtype:
```python
if engine_args.engine == InferenceEngine.torch and engine_args.device in [Device.auto, Device.cuda]:
    if engine_args.dtype == Dtype.float16:
        model_kwargs["torch_dtype"] = torch.float16
    elif engine_args.dtype == Dtype.float32:
        model_kwargs["torch_dtype"] = torch.float
    # Add other dtype handling as needed
```
2. Ensure model_kwargs is passed through correctly when initializing the model.

Refer to the PyTorch tensor attributes documentation for handling other dtypes: https://pytorch.org/docs/stable/tensor_attributes.html
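For step 2, a minimal sketch of what that wiring could look like. This is illustrative only: the real SentenceTransformerPatched in infinity_emb takes engine_args rather than an explicit dtype argument, the class and model id below are examples, and it assumes a sentence_transformers version whose constructor accepts model_kwargs (as this issue presumes).

```python
import torch
from sentence_transformers import SentenceTransformer


class SentenceTransformerPatchedSketch(SentenceTransformer):
    """Hypothetical, simplified stand-in for SentenceTransformerPatched."""

    def __init__(self, model_name_or_path: str, torch_dtype=None, **kwargs):
        model_kwargs = kwargs.pop("model_kwargs", {}) or {}
        if torch_dtype is not None:
            # Ask transformers to load the checkpoint weights directly in the
            # requested precision instead of casting after loading.
            model_kwargs["torch_dtype"] = torch_dtype
        super().__init__(model_name_or_path, model_kwargs=model_kwargs, **kwargs)


# Example: load the checkpoint in fp16 from the start.
model = SentenceTransformerPatchedSketch(
    "BAAI/bge-small-en-v1.5", torch_dtype=torch.float16
)
```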
So your idea is to save memory at loading time, before performing model.half()? Potentially a method that could be combined with device. Need to wrap my head around it; maybe add a function to Dtype.
Feature request
Pass torch_dtype in model_kwargs, as supported by sentence_transformers, when dtype is specified in the infinity_emb v2 CLI and the InferenceEngine type is torch. This would allow loading the Transformer model at a lower precision directly at load time instead of post-loading; loading at full precision first can cause an OOM error.
Post-load quantization, e.g. self.half(), would still be needed, since the non-Transformer PyTorch modules in some models are still in fp32, which would cause downstream matrix computations to fail due to mixed-dtype operands.
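As a rough illustration of why both parts matter, here is a sketch of the combined approach. It assumes a sentence_transformers version that accepts model_kwargs; the model id is only an example.

```python
import torch
from sentence_transformers import SentenceTransformer

# Load the transformer weights in fp16 at load time, so the full fp32
# checkpoint never needs to fit into GPU memory.
model = SentenceTransformer(
    "BAAI/bge-small-en-v1.5",
    device="cuda",
    model_kwargs={"torch_dtype": torch.float16},
)

# Modules not loaded through transformers (e.g. a Dense projection module)
# may still hold fp32 parameters; listing them makes the mismatch visible.
print({n: p.dtype for n, p in model.named_parameters() if p.dtype != torch.float16})

# The post-load cast is therefore still needed to avoid mixed-dtype matmul
# errors downstream.
model.half()
```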
Motivation
While the current code quantizes the precision after loading a model, the issue is that loading a full 32-bit float model onto a GPU with a small amount of memory can fail, since the total size of the model would exceed the GPU's memory limit. A use case for this is small multi-instance GPUs, e.g. a 10 GB MIG instance on an NVIDIA A100.
By specifying the precision at load time, we would be able to load the model without OOM errors and use it successfully.
Your contribution
Yes. Happy to submit a PR.
Current tests on my end use code modifications like those shown above, in the init function of the SentenceTransformerPatched class, but further work would be needed for better support/handling of other types. Some thoughts on what can be done for each dtype, per the torch tensor attributes here (https://pytorch.org/docs/stable/tensor_attributes.html):
auto -> (leave blank)
float32 -> torch.float
float16 -> torch.half
float8 -> (leave blank) ?
int8 -> (leave blank) ?
The 8-bit options were left blank since I noticed that quantization is performed later for these types - unsure if there is a better approach here.
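A possible shape for that mapping, picking up the idea of a helper function for Dtype mentioned above. This is a sketch only: the function name and string keys are illustrative, and the auto/float8/int8 entries deliberately resolve to None so the existing post-load quantization path keeps handling them.

```python
import torch

# Hypothetical mapping from infinity_emb's dtype choices to a torch dtype for
# load-time casting. None means "leave torch_dtype unset": auto keeps the
# checkpoint default, and the 8-bit paths keep using post-load quantization.
_TORCH_DTYPE_BY_NAME = {
    "auto": None,
    "float32": torch.float32,
    "float16": torch.float16,
    "float8": None,
    "int8": None,
}


def torch_dtype_for(dtype_name: str):
    """Return the torch dtype to put into model_kwargs, or None to skip it."""
    return _TORCH_DTYPE_BY_NAME.get(dtype_name)
```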