michaelfeil / infinity

Infinity is a high-throughput, low-latency REST API for serving text embeddings, reranking models, CLIP, CLAP, and ColPali
https://michaelfeil.github.io/infinity/
MIT License

Permit loading of models at different precision at load time for sentence_transformers #331

Open dawu415 opened 2 months ago

dawu415 commented 2 months ago

Feature request

Pass torch_dtype via model_kwargs, as supported by sentence_transformers, when dtype is specified in the infinity_emb v2 CLI and the InferenceEngine type is torch.

This would allow the Transformer model to be loaded at a lower precision directly at load time, rather than converting it after loading, which can trigger an OOM error.

Post-load quantization, e.g. self.half(), would still be needed: the non-Transformer PyTorch modules in some models remain in fp32 and would otherwise cause downstream matrix computations to fail due to mixed-dtype operands.
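
For illustration, here is a minimal sketch of load-time precision selection against the plain sentence_transformers API (the model name is only an example, not from this issue; model_kwargs is forwarded to transformers' from_pretrained in recent sentence_transformers releases):

    import torch
    from sentence_transformers import SentenceTransformer

    # The fp16 weights are materialized directly, so a full fp32 copy
    # never has to fit on the GPU.
    model = SentenceTransformer(
        "BAAI/bge-small-en-v1.5",  # example model, not from this issue
        device="cuda",
        model_kwargs={"torch_dtype": torch.float16},
    )
    # Non-Transformer modules (pooling/dense) may still be fp32, so the
    # post-load half() discussed above remains necessary.
    model.half()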

Motivation

While the current code quantizes the model's precision after loading it, the problem is that loading a full 32-bit float model onto a GPU with little memory can fail outright, since the total size of the model exceeds the GPU's memory limit. A concrete case is small multi-instance GPUs, e.g. a 10 GB slice of an NVIDIA A100 in MIG mode.
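
To put rough numbers on it (illustrative figures, not from this issue): a 3B-parameter model occupies about 3e9 × 4 bytes ≈ 12 GB in fp32, which cannot fit in a 10 GB MIG slice, while the same model loaded in fp16 needs only ~6 GB.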

By specifying the precision at load time, we would be able to load such a model without OOM errors and use it successfully.

Your contribution

Yes. Happy to submit a PR.

My current tests use the following code modification in the __init__ function of the SentenceTransformerPatched class.

        # Only set torch_dtype when the torch engine targets CUDA (or auto) with fp16
        if engine_args.engine == InferenceEngine.torch and \
           engine_args.device in [Device.auto, Device.cuda] and \
           engine_args.dtype == Dtype.float16:
            model_kwargs["torch_dtype"] = torch.float16

but this would need further work to support and handle the other types properly. Some thoughts on what could be done for each dtype, per the torch tensor attributes documented here: https://pytorch.org/docs/stable/tensor_attributes.html

auto -> (leave blank)
float32 -> torch.float
float16 -> torch.half
float8 -> (leave blank) ?
int8 -> (leave blank) ?

The 8-bit options are left blank because quantization for those types is performed after loading; I'm unsure whether there is a better approach here.
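
As a sketch of what that generalization could look like (the mapping dict and its name are mine, not merged code; float8/int8 deliberately fall through to the existing post-load quantization path):

    # Hypothetical mapping: only dtypes with a safe load-time torch
    # equivalent are listed; .get() returns None for everything else.
    _LOAD_TIME_DTYPES = {
        Dtype.float32: torch.float32,
        Dtype.float16: torch.float16,
    }

    if engine_args.engine == InferenceEngine.torch and \
       engine_args.device in [Device.auto, Device.cuda]:
        torch_dtype = _LOAD_TIME_DTYPES.get(engine_args.dtype)
        if torch_dtype is not None:
            model_kwargs["torch_dtype"] = torch_dtype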

greptile-apps[bot] commented 2 months ago

To permit loading of models at different precision at load time for sentence_transformers, follow these steps:

  1. Modify the SentenceTransformerPatched class initialization to include torch_dtype in model_kwargs based on engine_args.dtype.
if engine_args.engine == InferenceEngine.torch and \
   engine_args.device in [Device.auto, Device.cuda]:
    if engine_args.dtype == Dtype.float16:
        model_kwargs["torch_dtype"] = torch.float16
    elif engine_args.dtype == Dtype.float32:
        model_kwargs["torch_dtype"] = torch.float
    # Add other dtype handling as needed
  2. Ensure that model_kwargs is passed correctly when initializing the model.

Refer to the PyTorch tensor attributes documentation for handling other dtypes: https://pytorch.org/docs/stable/tensor_attributes.html


michaelfeil commented 2 months ago

So your idea is to save memory at loading time, before performing model.half()? Potentially a method that could be combined with device. I need to wrap my head around it; maybe add a function at Dtype.
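
For what it's worth, a hedged sketch of that "function at Dtype" idea (a self-contained stand-in; the real Dtype enum in infinity_emb may differ, so treat the member names as assumptions):

    from enum import Enum
    from typing import Optional

    import torch

    class Dtype(Enum):  # stand-in for infinity_emb's Dtype enum
        auto = "auto"
        float32 = "float32"
        float16 = "float16"
        float8 = "float8"
        int8 = "int8"

        def to_torch_dtype(self) -> Optional[torch.dtype]:
            # Returns a dtype usable at load time, or None for dtypes
            # that keep using post-load quantization (auto/float8/int8).
            return {
                Dtype.float32: torch.float32,
                Dtype.float16: torch.float16,
            }.get(self)

The call site in the engine would then shrink to:

    torch_dtype = engine_args.dtype.to_torch_dtype()
    if torch_dtype is not None:
        model_kwargs["torch_dtype"] = torch_dtype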