tunib-ai / parallelformers

Parallelformers: An Efficient Model Parallelization Toolkit for Deployment
https://tunib-ai.github.io/parallelformers
Apache License 2.0

INT8 support #39

Open volkerha opened 1 year ago

volkerha commented 1 year ago

Describe a requested feature

I wonder if there's any plan to support 8-bit (int8) inference in parallelformers. Right now, 🤗 Transformers models can already be loaded in 8-bit, e.g.:

```python
from transformers import AutoModelForCausalLM

model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)
```

However, it's not possible to parallelize() the model with parallelformers since only fp16 mode is supported at the moment.
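For reference, the currently supported fp16 path looks roughly like this (a minimal sketch following the parallelformers README; `model_name` is assumed to be defined):

```python
from transformers import AutoModelForCausalLM
from parallelformers import parallelize

# Current behavior: the model is loaded in full precision and then
# parallelized across GPUs, optionally casting to fp16.
model = AutoModelForCausalLM.from_pretrained(model_name)
parallelize(model, num_gpus=2, fp16=True, verbose='detail')
```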

Expected behavior

If 8-bit inference could be supported, it would be good to add another argument analogous to the existing fp16 one, e.g.:

```python
from transformers import AutoModelForCausalLM
from parallelformers import parallelize

model = AutoModelForCausalLM.from_pretrained(model_name)
parallelize(model, num_gpus=2, int8=True, verbose='detail')

# or a single argument for the precision mode, where dtype is one of "int8", "fp16", or "fp32" (default):
# parallelize(model, num_gpus=2, dtype='int8', verbose='detail')
```