BlackSamorez / tensor_parallel

Automatically split your PyTorch models on multiple GPUs for training & inference
MIT License

Slow inference performance for large Llama models compared to naive MP #66

Open sgsdxzy opened 1 year ago

sgsdxzy commented 1 year ago

The inference speed of naive model parallel is much better than tensor parallel:

Setup: Llama-30b on 4x 2080Ti 22G

- Naive: 31.64 s
- 4-way TP, main branch: 177.78 s
- 4-way TP, llama branch: 102.22 s

The code for naive inference:

import torch
import time
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = 'models/llama-30b'

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.half, device_map="balanced")

torch.cuda.empty_cache()
model = model.eval()
with torch.no_grad():
    batch = tokenizer(
        "DeepSpeed first included offloading capabilities with ZeRO-Offload, a system ",
        return_tensors="pt"
    )
    batch = {k: v.cuda(0) for k, v in batch.items()}
    print("Start")
    t0 = time.time()
    generated = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=200)
    t1 = time.time()
    print(f"Output generated in {(t1-t0):.2f} seconds")
    print(tokenizer.decode(generated[0]))
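
Because CUDA kernels are launched asynchronously, timing GPU work with time.time() alone can under-report. Below is a minimal sketch of a timing helper that synchronizes before and after the measured call. The name timed is hypothetical. The sync callable is left to the caller because torch.cuda.synchronize() only waits for one device. For model.generate() the difference is probably small, since generation already blocks on intermediate results, but it does matter when timing individual forward passes.

```python
import time


def timed(fn, *args, sync=None, **kwargs):
    """Call fn(*args, **kwargs) and return (result, elapsed_seconds).

    If `sync` is given it is called before starting and after finishing the
    clock, so queued asynchronous GPU work is included in the measurement.
    """
    if sync is not None:
        sync()  # drain work queued before the measurement starts
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if sync is not None:
        sync()  # wait for everything fn launched to finish
    return result, time.perf_counter() - start


# Example usage with the script above (requires torch and the loaded model):
# sync_all = lambda: [torch.cuda.synchronize(d) for d in range(torch.cuda.device_count())]
# generated, seconds = timed(model.generate, batch["input_ids"],
#                            attention_mask=batch["attention_mask"], max_length=200,
#                            sync=sync_all)
```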

The code for TP:

import torch
import time
import json
import tensor_parallel
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import accelerate

model_name = 'models/llama-30b'

tokenizer = AutoTokenizer.from_pretrained(model_name)
with accelerate.init_empty_weights():
    model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(model_name)).half()
    model = tensor_parallel.TensorParallelPreTrainedModel(model)

device_map = tensor_parallel.infer_sharded_device_map(model) # <- The model is on the meta device, but we can still deduce
                                                #    the target devices for each weight using this helper function
# Collect the names of all checkpoint shard files
with open(f"{model_name}/pytorch_model.bin.index.json", "r") as index_file:
    shard_filenames = set(json.load(index_file)["weight_map"].values())

for shard_filename in sorted(shard_filenames):
    # Locate a shard on disk
    shard_path = f"{model_name}/{shard_filename}"
    print(shard_path)

    # Convert model shard
    converted_state_dict = tensor_parallel.convert_state_dict( # <- tensor_parallel helper function.
        torch.load(shard_path),                   #    Creates a tensor_parallel checkpoint from a normal one
        model.tensor_parallel_config,
        world_size=4,
        for_pretrained=True,
    )
    torch.save(converted_state_dict, "/tmp/shard.bin")
    del converted_state_dict

    # Dispatch the shard
    accelerate.load_checkpoint_in_model(
        model,
        checkpoint="/tmp/shard.bin",
        device_map=device_map,
    )

torch.cuda.empty_cache()
model = model.eval()
with torch.no_grad():
    batch = tokenizer(
        "DeepSpeed first included offloading capabilities with ZeRO-Offload, a system ",
        return_tensors="pt"
    )
    batch = {k: v.cuda(0) for k, v in batch.items()}
    print("Start")
    t0 = time.time()
    generated = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=200)
    t1 = time.time()
    print(f"Output generated in {(t1-t0):.2f} seconds")
    print(tokenizer.decode(generated[0]))
BlackSamorez commented 1 year ago

Hi @sgsdxzy! I tried to reproduce this issue in a T4 x2 Kaggle notebook (sadly I don't own 2080Ti 22G x4) and here's what I got:

It's not quite double the speed, but it gets better with larger batches.

About your case: if you're sure those numbers are valid, maybe it's somehow related to the fact that you're using 4 cards. What's the data bandwidth between them? Are all 4 cards getting enough PCIe lanes? tensor_parallel uses raw communication primitives from torch.cuda.nccl, so it's odd that they would be that slow.
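
nvidia-smi's query interface can report the PCIe link for each GPU individually. Below is a minimal sketch that shells out to nvidia-smi and parses its CSV output. parse_pcie_csv and query_pcie_links are hypothetical helpers. The field names are the ones listed by nvidia-smi --help-query-gpu.

```python
import subprocess

# Query fields as listed by `nvidia-smi --help-query-gpu`.
QUERY_FIELDS = "index,name,pcie.link.gen.current,pcie.link.width.current"


def parse_pcie_csv(text):
    """Parse `--format=csv,noheader,nounits` output into one dict per GPU.

    Assumes GPU names contain no commas.
    """
    gpus = []
    for line in text.strip().splitlines():
        index, name, gen, width = (field.strip() for field in line.split(","))
        gpus.append({"index": int(index), "name": name, "gen": int(gen), "width": int(width)})
    return gpus


def query_pcie_links():
    """Run nvidia-smi and return the current PCIe generation and lane width of every GPU."""
    out = subprocess.run(
        ["nvidia-smi", f"--query-gpu={QUERY_FIELDS}", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout
    return parse_pcie_csv(out)
```

GPUs often drop to a lower PCIe generation when idle, so it's best to call query_pcie_links() while a workload is running and check that every card reports the expected width (x16 on the X99 box described below).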

sgsdxzy commented 1 year ago

@BlackSamorez I can confirm that 2-card TP provides a small speedup over 2-card MP. All 4 cards run at PCIe 3.0 x16 on an X99 board. Here's my P2P connectivity test (I have two NVLink bridges, one between GPUs [0,1] and one between [2,3]):

P2P Connectivity Matrix
     D\D     0     1     2     3
     0       1     1     0     0
     1       1     1     0     0
     2       0     0     1     1
     3       0     0     1     1
Unidirectional P2P=Disabled Bandwidth Matrix (GB/s)
   D\D     0      1      2      3
     0 541.72   5.76   5.85   5.87
     1   5.76 542.96   5.82   5.87
     2   5.95   5.94 537.09   5.79
     3   5.89   5.93   5.81 533.16
Unidirectional P2P=Enabled Bandwidth (P2P Writes) Matrix (GB/s)
   D\D     0      1      2      3
     0 531.46  47.09   6.00   5.95
     1  47.11 536.05   5.97   5.95
     2   5.87   5.96 532.47  47.09
     3   5.92   5.90  47.10 532.53
Bidirectional P2P=Disabled Bandwidth Matrix (GB/s)
   D\D     0      1      2      3
     0 533.29   6.11   8.62   8.59
     1   6.12 535.29   8.58   8.57
     2   8.60   8.52 534.05   6.12
     3   8.56   8.57   6.10 534.13
Bidirectional P2P=Enabled Bandwidth Matrix (GB/s)
   D\D     0      1      2      3
     0 533.55  94.10   8.61   8.59
     1  94.13 534.78   8.56   8.59
     2   8.55   8.60 534.17  94.15
     3   8.62   8.59  94.16 533.62
P2P=Disabled Latency Matrix (us)
   GPU     0      1      2      3
     0   1.34  12.44  12.30  12.44
     1  12.44   1.38  21.21  12.68
     2  12.53  12.61   1.33  12.44
     3  12.38  12.30  12.68   1.33

   CPU     0      1      2      3
     0   2.05   5.85   5.74   5.82
     1   5.82   1.95   5.80   5.77
     2   5.63   5.66   1.99   5.58
     3   5.75   5.72   5.67   1.97
P2P=Enabled Latency (P2P Writes) Matrix (us)
   GPU     0      1      2      3
     0   1.33   1.88  12.30  12.45
     1   1.88   1.38  21.18  12.54
     2  12.53  12.53   1.33   1.85
     3  12.38  21.12   1.85   1.33

   CPU     0      1      2      3
     0   2.02   1.63   5.85   5.91
     1   1.64   1.99   5.75   5.91
     2   5.71   5.69   1.99   1.64
     3   6.01   5.80   1.74   2.12

Kaggle T4s don't use NVLink, so that isn't the problem here, and I don't think 4 cards would suddenly hit a communication bottleneck severe enough to drastically reduce performance. I suspect a misconfiguration or a bug. Where would you suggest I look?
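
The connectivity part of the matrix above can also be checked from Python, which is useful when the CUDA samples aren't built. Below is a minimal sketch using torch.cuda.can_device_access_peer. p2p_matrix and format_p2p_matrix are hypothetical helpers. The diagonal is set to 1 to mirror the layout of p2pBandwidthLatencyTest, not queried from the driver.

```python
def p2p_matrix(num_devices=None, can_access=None):
    """Return an n x n list of 0/1 flags: [i][j] == 1 if GPU i can access GPU j peer-to-peer.

    Both arguments default to PyTorch's view of the machine; they are parameters
    only so the helper can be exercised without GPUs.
    """
    if num_devices is None or can_access is None:
        import torch  # imported lazily; only needed when querying real hardware
        if num_devices is None:
            num_devices = torch.cuda.device_count()
        if can_access is None:
            can_access = torch.cuda.can_device_access_peer
    return [
        [1 if i == j else int(bool(can_access(i, j))) for j in range(num_devices)]
        for i in range(num_devices)
    ]


def format_p2p_matrix(matrix):
    """Render the matrix in the same D\\D layout as p2pBandwidthLatencyTest."""
    header = "D\\D " + "".join(f"{j:>3}" for j in range(len(matrix)))
    rows = [f"{i:>3} " + "".join(f"{flag:>3}" for flag in row) for i, row in enumerate(matrix)]
    return "\n".join([header, *rows])
```

On the machine above, print(format_p2p_matrix(p2p_matrix())) should show the same [0,1] / [2,3] pairing, provided PyTorch enumerates the GPUs in the same order as the CUDA sample.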

BlackSamorez commented 1 year ago

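@sgsdxzy Thanks! Could you verify that the correct communication functions are being used? You should be hitting:

* https://github.com/BlackSamorez/tensor_parallel/blob/main/src/tensor_parallel/cross_device_ops.py#L95

* https://github.com/BlackSamorez/tensor_parallel/blob/main/src/tensor_parallel/cross_device_ops.py#L77

during forward passes.

Also, could you please benchmark tensor_parallel on ["cuda:0", "cuda:1"] (NVLink) and ["cuda:0", "cuda:2"] (no NVLink)?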

sgsdxzy commented 1 year ago
@BlackSamorez Here are the results:

| Setup | Naive time (s) | Naive memory per GPU (GB) | TP time (s) | TP memory per GPU (GB) |
|---|---|---|---|---|
| llama-7b, 1 GPU | 10.44 | 14 | - | - |
| llama-7b 8-bit, 1 GPU | 37.42 | 8.3 | - | - |
| llama-7b, 2 GPUs + NVLink | 11.45 | 7.7 | 27.85 | 7.7 |
| llama-7b 8-bit, 2 GPUs + NVLink | 37.99 | 4.7 | 28.23 | 7.7 |
| llama-7b, 2 GPUs w/o NVLink | 12.38 | 7.7 | 27.66 | 7.7 |
| llama-7b 8-bit, 2 GPUs w/o NVLink | 38.92 | 4.7 | 27.66 | 7.7 |
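
Computing the TP/naive time ratio from the llama-7b timings above makes the first point below concrete. Below is a small sketch. The numbers are copied from this comment's table. tp_slowdown is a hypothetical helper.

```python
# (naive seconds, TP seconds) for the 2-GPU llama-7b runs reported above.
timings = {
    "fp16, 2 GPUs + NVLink": (11.45, 27.85),
    "8-bit, 2 GPUs + NVLink": (37.99, 28.23),
    "fp16, 2 GPUs w/o NVLink": (12.38, 27.66),
    "8-bit, 2 GPUs w/o NVLink": (38.92, 27.66),
}


def tp_slowdown(naive_s, tp_s):
    """Ratio of TP time to naive time; > 1 means tensor parallel is slower."""
    return tp_s / naive_s


for setup, (naive_s, tp_s) in timings.items():
    print(f"{setup:26s} TP/naive = {tp_slowdown(naive_s, tp_s):.2f}")
```

By these numbers, TP takes about 2.2 to 2.4 times as long as naive MP in fp16, but only about 0.71 to 0.74 times as long in 8-bit.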

So the problems here are:

  1. TP only provides a speedup for 8-bit and is drastically slower for fp16. Also, TP takes the same time in fp16 and int8, which is suspicious.
  2. Loading in 8-bit does not save VRAM under TP, which could be considered another bug.
  3. NVLink does not affect the result.
  4. I am using the main branch, because the llama branch gives me the following error in 8-bit (it works fine for fp16, reducing the time from 28 s to 17 s):
Traceback (most recent call last):
  File "/home/sgsdxzy/Programs/text-generation-webui/tp_test.py", line 68, in <module>
    generated = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=200)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/transformers/generation/utils.py", line 1437, in generate
    return self.greedy_search(
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/transformers/generation/utils.py", line 2248, in greedy_search
    outputs = self(
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/tensor_parallel/pretrained_model.py", line 88, in forward
    return self.wrapped_model(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/tensor_parallel/tensor_parallel.py", line 130, in forward
    return parallel_apply(self.module_shards, inputs, kwargs_tup, self.devices)[self.output_device_index]
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
    output.reraise()
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/_utils.py", line 644, in reraise
    raise exception
AttributeError: Caught AttributeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 687, in forward
    outputs = self.model(
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 577, in forward
    layer_outputs = decoder_layer(
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 292, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/tensor_parallel/slicer_wrapper.py", line 390, in forward
    output = self.tp_wrapped_module(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 196, in forward
    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 242, in forward
    out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 488, in matmul
    return MatMul8bitLt.apply(A, B, out, bias, state)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 320, in forward
    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
  File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/bitsandbytes/functional.py", line 1698, in transform
    prev_device = pre_call(A.device)
AttributeError: 'NoneType' object has no attribute 'device'

The updated script, for reference:

import torch
import time
import argparse
import json
import tensor_parallel
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, LlamaTokenizer
import accelerate
from transformers.utils.bitsandbytes import replace_8bit_linear

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, required=True)  # path to a local HF checkpoint directory
parser.add_argument('--int8', action='store_true')
parser.add_argument('--mp', type=int, default=1)  # TP world size; <= 1 loads the model with naive model parallelism instead
args = parser.parse_args()

tokenizer = LlamaTokenizer.from_pretrained(args.model)

if args.mp <= 1:
    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.half, load_in_8bit=args.int8, device_map="balanced")
else:
    with accelerate.init_empty_weights():
        model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(args.model)).half()
        model = tensor_parallel.TensorParallelPreTrainedModel(model)
        if args.int8:
            model = replace_8bit_linear(model)
            model.is_loaded_in_8bit = True

    device_map = tensor_parallel.infer_sharded_device_map(model) # <- The model is on the meta device, but we can still deduce
                                                                 #    the target devices for each weight using this helper function

    # Collect the names of all checkpoint shard files
    with open(f"{args.model}/pytorch_model.bin.index.json", "r") as index_file:
        shard_filenames = set(json.load(index_file)["weight_map"].values())

    for shard_filename in sorted(shard_filenames):
        # Locate a shard on disk
        shard_path = f"{args.model}/{shard_filename}"
        print(shard_path)

        # Convert model shard
        converted_state_dict = tensor_parallel.convert_state_dict( # <- tensor_parallel helper function.
            torch.load(shard_path),                                #    Creates a tensor_parallel checkpoint from a normal one
            model.tensor_parallel_config,
            world_size=args.mp,
            for_pretrained=True,
        )
        torch.save(converted_state_dict, "/tmp/shard.bin")
        del converted_state_dict

        # Dispatch the shard
        accelerate.load_checkpoint_in_model(
            model,
            checkpoint="/tmp/shard.bin",
            device_map=device_map,
        )

torch.cuda.empty_cache()
model = model.eval()
with torch.no_grad():
    batch = tokenizer(
        "DeepSpeed first included offloading capabilities with ZeRO-Offload, a system ",
        return_tensors="pt"
    )
    batch = {k: v.cuda(0) for k, v in batch.items()}
    print("Start")
    t0 = time.time()
    generated = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=200)
    t1 = time.time()
    print(f"Output generated in {(t1-t0):.2f} seconds")
    print(tokenizer.decode(generated[0]))
sgsdxzy commented 1 year ago
@BlackSamorez Here are the results for OPT-6.7B, almost the same as for llama-7b:

| Setup | Naive time (s) | Naive memory per GPU (GB) | TP time (s) | TP memory per GPU (GB) |
|---|---|---|---|---|
| OPT-6.7B, 1 GPU | 10.16 | 13.6 | - | - |
| OPT-6.7B 8-bit, 1 GPU | 39.86 | 7.6 | - | - |
| OPT-6.7B, 2 GPUs + NVLink | 9.94 | 7.6 | 23.64 | 7.6 |
| OPT-6.7B 8-bit, 2 GPUs + NVLink | 40.08 | 4.6 | 23.81 | 7.6 |

Are you testing in int8 or fp16? Do you have access to any cards other than dual T4s? Also, I don't think I have a GPU communication problem: the TP provided by DeepSpeed-Inference boosts performance for me on OPT (LLaMA is not well supported yet), and 2-card fp16 is 65% faster than 1-card fp16: https://github.com/oobabooga/text-generation-webui/issues/561#issuecomment-1484933375

sgsdxzy commented 1 year ago

> @sgsdxzy Thanks! Could you verify that correct communication functions are being used? You should be hitting:
>
> * https://github.com/BlackSamorez/tensor_parallel/blob/main/src/tensor_parallel/cross_device_ops.py#L95
>
> * https://github.com/BlackSamorez/tensor_parallel/blob/main/src/tensor_parallel/cross_device_ops.py#L77
>
> during forward passes.
>
> Also could you please benchmark tensor_parallel on ["cuda:0", "cuda:1"] (nvlink) and ["cuda:0", "cuda:2"] (no nvlink)?

I found that NCCLAllGatherFunction is called, but NCCLAllReduceFunction is not.
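
To check which collective actually runs, one option is to wrap the apply entry point of each autograd Function for a while and count calls during a single forward pass. Below is a minimal sketch. count_calls is a hypothetical helper. The usage below assumes that NCCLAllGatherFunction and NCCLAllReduceFunction live in tensor_parallel.cross_device_ops, as the linked file suggests. Calls made through a reference captured before patching (e.g. fn = SomeFunction.apply at import time) won't be counted.

```python
import functools


def count_calls(owner, attr_name):
    """Wrap owner.attr_name so each call is counted; returns (counter, restore).

    Intended for callables invoked on the class or module itself, such as an
    autograd Function's classmethod `apply`.
    """
    had_own_attr = attr_name in vars(owner)
    raw_attr = vars(owner).get(attr_name)  # keeps classmethod/staticmethod wrappers intact
    original = getattr(owner, attr_name)   # resolved callable the wrapper forwards to
    counter = {"calls": 0}

    @functools.wraps(original)
    def wrapper(*args, **kwargs):
        counter["calls"] += 1
        return original(*args, **kwargs)

    setattr(owner, attr_name, wrapper)

    def restore():
        if had_own_attr:
            setattr(owner, attr_name, raw_attr)
        else:
            delattr(owner, attr_name)  # attribute was inherited; drop the override

    return counter, restore
```

Usage, under the assumption above: after from tensor_parallel import cross_device_ops, call counter, restore = count_calls(cross_device_ops.NCCLAllReduceFunction, "apply"), run one forward pass, print counter["calls"], then call restore().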

BlackSamorez commented 1 year ago

@sgsdxzy Hi! First, about int8: you need the latest accelerate (e.g. the main branch from GitHub) to dispatch int8 models with load_checkpoint_in_model. Otherwise the int8 layers are not quantized and behave exactly like fp16. As for everything else, I'll need some time to test it. It could be due to many reasons, including bugs in communication or tensor cores suddenly not kicking in for tensor_parallel.
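
One quick way to tell whether the int8 layers were actually quantized is to measure how much of the model's parameter storage is int8. Below is a minimal sketch. It assumes that quantized weights end up stored as torch.int8 tensors once the layer is on the GPU, which is how bitsandbytes' Linear8bitLt behaved in the 0.37 era. int8_fraction is a hypothetical helper and is not part of tensor_parallel or accelerate.

```python
def int8_fraction(model):
    """Fraction of parameter elements stored as int8 (0.0 if the model has none)."""
    total = quantized = 0
    for param in model.parameters():
        n = param.numel()
        total += n
        # Compared by name so this helper doesn't need to import torch;
        # equivalent to `param.dtype == torch.int8`.
        if str(param.dtype) == "torch.int8":
            quantized += n
    return quantized / total if total else 0.0
```

For a correctly quantized model the result should be well above zero, because the replaced linear layers hold most of the weights. A result of 0.0 on the TP model would be consistent with the layers having been dispatched without quantization.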

sgsdxzy commented 1 year ago

@BlackSamorez I upgraded accelerate to git+https://github.com/huggingface/accelerate, but the VRAM usage and speed are the same.

BlackSamorez commented 1 year ago

@sgsdxzy Now that's weird. This demo works, which means int8 should work fine, since those models wouldn't physically fit in VRAM in fp16. Could you please attach the output of pip freeze from your environment?
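
When only a few packages matter, querying their installed versions directly is a shorter alternative to a full pip freeze. Below is a minimal sketch using the standard library's importlib.metadata. The package list is an assumption about what's relevant here. For packages installed from git, it reports the version string from the package metadata, not the commit hash that pip freeze shows.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages that seem most relevant to this issue (an assumption; extend as needed).
RELEVANT_PACKAGES = ("torch", "transformers", "accelerate", "bitsandbytes", "tensor-parallel")


def package_versions(names=RELEVANT_PACKAGES):
    """Map each distribution name to its installed version, or None if it is missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


for name, ver in package_versions().items():
    print(f"{name}=={ver}" if ver else f"{name}: not installed")
```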

sgsdxzy commented 1 year ago

@BlackSamorez Here it is. This is a conda environment, so tell me if you suspect any specific package whose version isn't listed by pip freeze.

accelerate @ git+https://github.com/huggingface/accelerate@b757b6232516da4ece0fbcfec66855b37523f39a
aiofiles @ file:///home/conda/feedstock_root/build_artifacts/aiofiles_1664378549280/work
aiohttp==3.8.4
aiosignal==1.3.1
aiosqlite @ file:///home/conda/feedstock_root/build_artifacts/aiosqlite_1671461885930/work
altair==4.2.2
anyio @ file:///home/conda/feedstock_root/build_artifacts/anyio_1666191106763/work/dist
appdirs==1.4.4
argon2-cffi @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi_1640817743617/work
argon2-cffi-bindings @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi-bindings_1666850768662/work
astroid @ file:///home/conda/feedstock_root/build_artifacts/astroid_1679923748219/work
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1670263926556/work
async-timeout==4.0.2
attrs @ file:///home/conda/feedstock_root/build_artifacts/attrs_1671632566681/work
autopep8 @ file:///home/conda/feedstock_root/build_artifacts/autopep8_1635267974115/work
Babel @ file:///home/conda/feedstock_root/build_artifacts/babel_1677767029043/work
backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work
backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1618230623929/work
beautifulsoup4 @ file:///home/conda/feedstock_root/build_artifacts/beautifulsoup4_1679322162244/work
bitsandbytes==0.37.2
black @ file:///home/conda/feedstock_root/build_artifacts/black-recipe_1675252854302/work
bleach @ file:///home/conda/feedstock_root/build_artifacts/bleach_1674535352125/work
brotlipy @ file:///home/conda/feedstock_root/build_artifacts/brotlipy_1666764671472/work
certifi==2022.12.7
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1671179353105/work
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1661170624537/work
click @ file:///home/conda/feedstock_root/build_artifacts/click_1666798198223/work
cmake==3.26.1
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work
comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1679481329611/work
contourpy @ file:///home/conda/feedstock_root/build_artifacts/contourpy_1673633665736/work
cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography-split_1679811212387/work
cycler @ file:///home/conda/feedstock_root/build_artifacts/cycler_1635519461629/work
Cython @ file:///home/conda/feedstock_root/build_artifacts/cython_1673054058583/work
daal4py==2023.0.2
datasets==2.11.0
debugpy @ file:///home/conda/feedstock_root/build_artifacts/debugpy_1674522362098/work
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
deepspeed==0.8.3
defusedxml @ file:///home/conda/feedstock_root/build_artifacts/defusedxml_1615232257335/work
dill @ file:///home/conda/feedstock_root/build_artifacts/dill_1666603105584/work
docstring-to-markdown @ file:///home/conda/feedstock_root/build_artifacts/docstring-to-markdown_1679424273982/work
entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1667317341051/work
fastapi==0.95.0
fastjsonschema @ file:///home/conda/feedstock_root/build_artifacts/python-fastjsonschema_1677336799617/work/dist
ffmpy==0.3.0
filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1679932713187/work
fire==0.5.0
flake8 @ file:///home/conda/feedstock_root/build_artifacts/flake8_1669396691980/work
flexgen==0.1.7
flit_core @ file:///home/conda/feedstock_root/build_artifacts/flit-core_1667734568827/work/source/flit_core
fonttools @ file:///home/conda/feedstock_root/build_artifacts/fonttools_1680021152278/work
frozenlist==1.3.3
fsspec==2023.3.0
gmpy2 @ file:///home/conda/feedstock_root/build_artifacts/gmpy2_1666808654411/work
gradio==3.24.1
gradio_client==0.0.5
h11==0.14.0
hjson==3.1.0
httpcore==0.16.3
httpx==0.23.3
huggingface-hub==0.13.3
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1663625384323/work
importlib-metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1679167925176/work
importlib-resources @ file:///home/conda/feedstock_root/build_artifacts/importlib_resources_1676919000169/work
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1679336319192/work
ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1677617093347/work
ipython-genutils==0.2.0
isort @ file:///home/conda/feedstock_root/build_artifacts/isort_1675033873689/work
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1669134318875/work
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1654302431367/work
joblib @ file:///home/conda/feedstock_root/build_artifacts/joblib_1663332044897/work
json5 @ file:///home/conda/feedstock_root/build_artifacts/json5_1600692310011/work
jsonschema @ file:///home/conda/feedstock_root/build_artifacts/jsonschema-meta_1669810440410/work
jupyter-events @ file:///home/conda/feedstock_root/build_artifacts/jupyter_events_1673559782596/work
jupyter-ydoc @ file:///home/conda/feedstock_root/build_artifacts/jupyter_ydoc_1679325289144/work/dist
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1679365123476/work
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1678994169527/work
jupyter_server @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_1679073341944/work
jupyter_server_fileid @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_fileid_1677220209229/work
jupyter_server_terminals @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_terminals_1673491454549/work
jupyter_server_ydoc @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_ydoc_1678043727957/work
jupyterlab @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_1679327603632/work
jupyterlab-code-formatter @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_code_formatter_1679847042826/work
jupyterlab-pygments @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_pygments_1649936611996/work
jupyterlab_server @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_server_1679528718717/work
kiwisolver @ file:///home/conda/feedstock_root/build_artifacts/kiwisolver_1666805701884/work
lazy-object-proxy @ file:///home/conda/feedstock_root/build_artifacts/lazy-object-proxy_1672877787898/work
linkify-it-py==2.0.0
lit==16.0.0
loralib==0.1.1
Markdown==3.4.3
markdown-it-py==2.2.0
MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1674135787083/work
matplotlib @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-suite_1678135565516/work
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work
mccabe @ file:///home/conda/feedstock_root/build_artifacts/mccabe_1643049622439/work
mdit-py-plugins==0.3.3
mdurl==0.1.2
mistune @ file:///home/conda/feedstock_root/build_artifacts/mistune_1675771498296/work
mpmath @ file:///home/conda/feedstock_root/build_artifacts/mpmath_1678228039184/work
multidict==6.0.4
multiprocess==0.70.14
munkres==1.1.4
mypy-extensions @ file:///home/conda/feedstock_root/build_artifacts/mypy_extensions_1675543315189/work
nbclassic @ file:///home/conda/feedstock_root/build_artifacts/nbclassic_1678277563913/work
nbclient @ file:///home/conda/feedstock_root/build_artifacts/nbclient_1669795076334/work
nbconvert @ file:///home/conda/feedstock_root/build_artifacts/nbconvert-meta_1680034059411/work
nbformat @ file:///home/conda/feedstock_root/build_artifacts/nbformat_1679336765223/work
nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1664684991461/work
networkx @ file:///home/conda/feedstock_root/build_artifacts/networkx_1673151334029/work
ninja==1.11.1
notebook @ file:///home/conda/feedstock_root/build_artifacts/notebook_1678109761260/work
notebook_shim @ file:///home/conda/feedstock_root/build_artifacts/notebook-shim_1667478401171/work
numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1675642512762/work
orjson==3.8.9
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1673482170163/work
pandas==1.5.3
pandocfilters @ file:///home/conda/feedstock_root/build_artifacts/pandocfilters_1631603243851/work
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work
pathspec @ file:///home/conda/feedstock_root/build_artifacts/pathspec_1678853982175/work
peft @ git+https://github.com/huggingface/peft.git@445940fb7b5d38390ffb6707e2a989e89fff03b5
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1667297516076/work
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
Pillow @ file:///home/conda/feedstock_root/build_artifacts/pillow_1675487172403/work
pkgutil_resolve_name @ file:///home/conda/feedstock_root/build_artifacts/pkgutil-resolve-name_1633981968097/work
platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1679871349196/work
pluggy @ file:///home/conda/feedstock_root/build_artifacts/pluggy_1667232663820/work
ply==3.11
pooch @ file:///home/conda/feedstock_root/build_artifacts/pooch_1679580333621/work
prometheus-client @ file:///home/conda/feedstock_root/build_artifacts/prometheus_client_1674535637125/work
prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1677600924538/work
psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1667885877572/work
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
PuLP==2.7.0
pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
py-cpuinfo==9.0.0
pyarrow==11.0.0
pybind11==2.10.4
pycodestyle @ file:///home/conda/feedstock_root/build_artifacts/pycodestyle_1669306857274/work
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
pydantic==1.10.7
pydocstyle @ file:///home/conda/feedstock_root/build_artifacts/pydocstyle_1673997095229/work
pydub==0.25.1
pyflakes @ file:///home/conda/feedstock_root/build_artifacts/pyflakes_1669319921641/work
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1672682006896/work
pylint @ file:///home/conda/feedstock_root/build_artifacts/pylint_1679515272965/work
pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1680037383858/work
pyparsing @ file:///home/conda/feedstock_root/build_artifacts/pyparsing_1652235407899/work
PyQt5==5.15.7
PyQt5-sip==12.11.0
pyrsistent @ file:///home/conda/feedstock_root/build_artifacts/pyrsistent_1672681463845/work
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work
python-json-logger @ file:///home/conda/feedstock_root/build_artifacts/python-json-logger_1677079630776/work
python-lsp-jsonrpc @ file:///home/conda/feedstock_root/build_artifacts/python-lsp-jsonrpc_1618530352985/work
python-lsp-server @ file:///home/conda/feedstock_root/build_artifacts/python-lsp-server-meta_1674005136083/work
python-multipart==0.0.6
pytoolconfig @ file:///home/conda/feedstock_root/build_artifacts/pytoolconfig_1675124745143/work
pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1680088766131/work
PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1666772395347/work
pyzmq @ file:///home/conda/feedstock_root/build_artifacts/pyzmq_1679316826707/work
regex==2023.3.23
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1673863902341/work
responses==0.18.0
rfc3339-validator @ file:///home/conda/feedstock_root/build_artifacts/rfc3339-validator_1638811747357/work
rfc3986==1.5.0
rfc3986-validator @ file:///home/conda/feedstock_root/build_artifacts/rfc3986-validator_1598024191506/work
rope @ file:///home/conda/feedstock_root/build_artifacts/rope_1674988456931/work
rwkv==0.7.3
safetensors==0.3.0
scikit-learn @ file:///home/conda/feedstock_root/build_artifacts/scikit-learn_1679675836718/work
scikit-learn-intelex==20230131.200059
scipy==1.10.1
semantic-version==2.10.0
Send2Trash @ file:///home/conda/feedstock_root/build_artifacts/send2trash_1628511208346/work
sentencepiece==0.1.97
sip @ file:///home/conda/feedstock_root/build_artifacts/sip_1675696581052/work
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
sniffio @ file:///home/conda/feedstock_root/build_artifacts/sniffio_1662051266223/work
snowballstemmer @ file:///home/conda/feedstock_root/build_artifacts/snowballstemmer_1637143057757/work
soupsieve @ file:///home/conda/feedstock_root/build_artifacts/soupsieve_1658207591808/work
stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
starlette==0.26.1
sympy @ file:///home/conda/feedstock_root/build_artifacts/sympy_1679342590084/work
tensor-parallel @ file:///home/sgsdxzy/Programs/tensor_parallel
termcolor==2.2.0
terminado @ file:///home/conda/feedstock_root/build_artifacts/terminado_1670253674810/work
threadpoolctl @ file:///home/conda/feedstock_root/build_artifacts/threadpoolctl_1643647933166/work
tinycss2 @ file:///home/conda/feedstock_root/build_artifacts/tinycss2_1666100256010/work
tokenize-rt==5.0.0
tokenizers==0.13.3
toml @ file:///home/conda/feedstock_root/build_artifacts/toml_1604308577558/work
tomli @ file:///home/conda/feedstock_root/build_artifacts/tomli_1644342247877/work
tomlkit @ file:///home/conda/feedstock_root/build_artifacts/tomlkit_1679924068997/work
toolz==0.12.0
torch==2.0.0
torchaudio==2.0.0
torchvision==0.15.0
tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1666788589303/work
tqdm==4.65.0
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1675110562325/work
transformers @ git+https://github.com/huggingface/transformers.git@ee8e80a060d65ab349743ffcb5842365eb0e5606
triton==2.0.0
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1678559861143/work
uc-micro-py==1.0.1
ujson @ file:///home/conda/feedstock_root/build_artifacts/ujson_1675191915931/work
unicodedata2 @ file:///home/conda/feedstock_root/build_artifacts/unicodedata2_1667239886688/work
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1678635778344/work
uvicorn==0.21.1
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1673864653149/work
webencodings==0.5.1
websocket-client @ file:///home/conda/feedstock_root/build_artifacts/websocket-client_1675567828044/work
websockets==10.4
whatthepatch @ file:///home/conda/feedstock_root/build_artifacts/whatthepatch_1675090462655/work
wrapt @ file:///home/conda/feedstock_root/build_artifacts/wrapt_1677485519705/work
xxhash==3.2.0
y-py @ file:///home/conda/feedstock_root/build_artifacts/y-py_1677231008299/work
yapf @ file:///home/conda/feedstock_root/build_artifacts/yapf_1641487982943/work
yarl==1.8.2
ypy-websocket @ file:///home/conda/feedstock_root/build_artifacts/ypy-websocket_1670333059911/work
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1677313463193/work
BlackSamorez commented 1 year ago

@sgsdxzy By the way, here's what I get on my setup with decapoda-research/llama-7b-hf:

Only RTX 3060 x 2 speeds things up. Something's definitely very wrong.

BlackSamorez commented 1 year ago

I've tested pure forward passes and it looks good:

This is on the same GTX 1080 x 4. Maybe something's wrong with the past_key_values processing, which makes generation slow. Will look into it.

sgsdxzy commented 1 year ago

@BlackSamorez, could it be that past_key_values are gathered to cuda:0 and redistributed to each rank on every step?

BlackSamorez commented 1 year ago

> @BlackSamorez, could it be that past_key_values are gathered to cuda:0 and redistributed to each rank on every step?

I'm not sure. There is a different data structure for ungathered tensors called PerDeviceTensors and it's used for past_key_values. They should not be gathered at all. I'll need to verify that it's working as expected.
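A back-of-the-envelope sketch of what the gather-and-redistribute hypothesis would cost if it were true. It assumes the published LLaMA-33B shape (60 layers, hidden size 6656; this is the model distributed as llama-30b), fp16 K/V tensors, batch size 1, and the max_length=200 from the script at the top of the issue. The helper names are hypothetical, and this is not how tensor_parallel accounts for its cache; whether any gather happens at all is what the reply above still needs to verify.

```python
def kv_cache_bytes(num_layers: int, hidden_size: int, seq_len: int,
                   batch_size: int = 1, bytes_per_elem: int = 2) -> int:
    """Total size of the K and V cache across all decoder layers."""
    # Per layer, K and V each have shape [batch, num_heads, seq_len, head_dim],
    # and num_heads * head_dim == hidden_size for LLaMA.
    per_tensor = batch_size * seq_len * hidden_size * bytes_per_elem
    return 2 * num_layers * per_tensor


def gather_scatter_bytes(cache_bytes: int, world_size: int) -> int:
    """Hypothetical bytes moved per decoding step if the sharded cache were
    gathered onto one device and then redistributed to every rank."""
    # The (world_size - 1) remote shards travel to the gathering device
    # and back again: two transfers of that share.
    remote_share = cache_bytes * (world_size - 1) // world_size
    return 2 * remote_share


# Assumed LLaMA-33B ("llama-30b") shape: 60 layers, hidden size 6656.
cache = kv_cache_bytes(num_layers=60, hidden_size=6656, seq_len=200)
moved = gather_scatter_bytes(cache, world_size=4)
print(f"KV cache at 200 tokens: {cache / 1e6:.1f} MB")
print(f"Moved per step on 4 GPUs if gathered: {moved / 1e6:.1f} MB")
```

Because the cache grows by one token per step, such traffic would also grow every step, which would be consistent with forward passes scaling fine while .generate() does not.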

sgsdxzy commented 1 year ago

Have you identified the issue? With 1.2.1, load_in_8bit actually saves VRAM for me, but the performance is still bad.
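For a sense of scale on the VRAM side, here is a sketch of the weight-only footprint of Llama-30b at fp16 versus int8. It assumes the roughly 32.5B parameter count reported in the LLaMA paper and an even split across the 4 GPUs used in this issue. It ignores activations, the KV cache, and modules that load_in_8bit leaves in higher precision, so the real saving is smaller, and it says nothing about speed.

```python
def weight_gb(num_params: float, bytes_per_param: float) -> float:
    """Memory for the weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


LLAMA_30B_PARAMS = 32.5e9  # assumed: parameter count from the LLaMA paper
NUM_GPUS = 4               # matches the 2080Ti x4 setup in this issue

for dtype, nbytes in [("fp16", 2), ("int8", 1)]:
    total = weight_gb(LLAMA_30B_PARAMS, nbytes)
    # Assumes tensor parallelism splits the weights evenly across ranks.
    print(f"{dtype}: {total:.1f} GB total, ~{total / NUM_GPUS:.2f} GB per GPU")
```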

cxxz commented 1 year ago

I also observed a slowdown with tensor_parallel 1.2.1 compared to native performance on a single GPU.

Setup

Llama-7b on 8 x A100 80GB (NVLink)

Prompt

"Count up from 100 to 130"

so the number of newly generated tokens is fixed (155).

Inference Performance

1 GPU without TP: inference time 7.08 s, GPU utilization (nvidia-smi) about 69%
2-way TP: inference time 10.24 s, GPU utilization (nvidia-smi) only about 23%

The only code difference between the two tests is:

### 1-GPU w/o TP
model = LlamaForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16, device_map="sequential")

vs.

### 2-way TP
model = LlamaForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16)
model = TensorParallelPreTrainedModel(model, ["cuda:0", "cuda:1"])

Any hints on what might have gone wrong?
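Normalizing the two totals by the fixed 155 new tokens gives per-token latencies, which are easier to compare with the ms-per-token figures other commenters report. This is plain arithmetic on the reported numbers and assumes each total covers only the generation call.

```python
NEW_TOKENS = 155  # fixed by the "Count up from 100 to 130" prompt

# Total generation time in seconds, as reported above.
runs = {"1 GPU, no TP": 7.08, "2-way TP": 10.24}

for name, seconds in runs.items():
    print(f"{name}: {seconds / NEW_TOKENS * 1e3:.1f} ms per new token")

slowdown = runs["2-way TP"] / runs["1 GPU, no TP"]
print(f"2-way TP is {slowdown:.2f}x slower")
```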

BlackSamorez commented 1 year ago

I've measured the performance of LLaMA 13B on Kaggle 2x T4 and here's what I got:

[Figure: forward passes]

[Figure: generation]

It's definitely a .generate() problem. I'll look into it and, hopefully, release a fix soon.
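One way to check where the time goes is to time a plain forward pass and a full .generate() call separately, with and without TP. The helper below is a hypothetical sketch, not part of tensor_parallel. Because CUDA kernels launch asynchronously, it accepts a sync callback (for example torch.cuda.synchronize) so the clock stops only after the queued GPU work has finished.

```python
import time
from statistics import mean
from typing import Callable, Optional


def bench(fn: Callable[[], object], iters: int = 10, warmup: int = 2,
          sync: Optional[Callable[[], None]] = None) -> float:
    """Return the mean wall-clock seconds per call of fn()."""
    for _ in range(warmup):
        fn()  # warm-up calls (allocator, kernel caches) are not timed
    if sync is not None:
        sync()  # drain work queued by the warm-up before starting the clock
    times = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for the GPU before stopping the clock
        times.append(time.perf_counter() - t0)
    return mean(times)
```

For example, inside torch.no_grad(), `bench(lambda: model(**batch), sync=torch.cuda.synchronize)` times one forward pass and `bench(lambda: model.generate(**batch, max_new_tokens=32), sync=torch.cuda.synchronize)` times the decoding loop; running both for the plain and the TP-wrapped model separates the two costs.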

cxxz commented 1 year ago

Thank you for sharing your findings on the performance of LLaMA 13B on Kaggle 2x T4. Good to know that you've identified the .generate() issue. I appreciate your efforts in looking into it and eagerly await the release of a fix. Keep up the good work!

dmgcsilva commented 1 year ago

Hi @BlackSamorez, have you been able to identify and fix the issue? I'm seeing similar problems: using 2-way or even 4-way TP slows down inference on 2x A100 40GB with NVLink.

eric-mitchell commented 1 year ago

Would love to know if there is any update on this issue @BlackSamorez. tensor_parallel works great for us for training (nice job!), but the inability to actually sample from the model is a dealbreaker for us. We're seeing slow generation for non-llama models too (e.g., Pythia-6.9b).

BlackSamorez commented 1 year ago

@eric-mitchell @dmgcsilva Sadly, I have neither the time nor the resources to properly test and benchmark this right now. I'll do it in a month or so.

152334H commented 1 year ago

Has anyone found an alternative, efficient TP solution yet?

chujiezheng commented 11 months ago

I also found that 4-GPU TP is much slower than 2-GPU TP, while the latter is still a bit faster than 2-GPU pipeline parallelism (PP).

dutsc commented 8 months ago

This work is very meaningful. Following @sgsdxzy, I ran the following tests on RTX 3090s.

Model setup                opt-6.7b, 1 GPU  opt-6.7b, 2 GPUs    opt-1.3b, 1 GPU  opt-1.3b, 2 GPUs    opt-13b, 4 GPUs
Naive per-token time (ms)  21.5             21.5 (single card)  12.5             12.5 (single card)  52.11
Naive memory per GPU (GB)  12.8             12.8                2.9              2.9                 -
TP time (ms)               -                76.89               -                62.1                373.71
TP memory per GPU (GB)     -                6.5                 -                1.6                 6.7
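Dividing the TP rows of the table by the naive rows gives the slowdown and the memory saving for each setup. This sketch assumes the "TP time" row is per token, like the naive row, and uses None for the naive memory of opt-13b, which was not reported.

```python
from typing import Optional, Tuple

# (naive ms/token, TP ms, naive GB/GPU, TP GB/GPU), copied from the table.
# Assumption: the TP time is per token, like the naive time.
ROWS = {
    "opt-6.7b, 2 GPUs": (21.5, 76.89, 12.8, 6.5),
    "opt-1.3b, 2 GPUs": (12.5, 62.1, 2.9, 1.6),
    "opt-13b, 4 GPUs": (52.11, 373.71, None, 6.7),  # naive memory not reported
}


def ratios(naive_ms: float, tp_ms: float, naive_gb: Optional[float],
           tp_gb: float) -> Tuple[float, Optional[float]]:
    """Return (TP slowdown factor, TP memory as a fraction of naive memory)."""
    slowdown = tp_ms / naive_ms
    mem_frac = None if naive_gb is None else tp_gb / naive_gb
    return slowdown, mem_frac


for name, row in ROWS.items():
    slowdown, mem_frac = ratios(*row)
    mem = "n/a" if mem_frac is None else f"{mem_frac:.2f}"
    print(f"{name}: {slowdown:.2f}x slower with TP, memory per GPU ratio {mem}")
```

By these numbers, TP roughly halves per-GPU memory on 2 GPUs but costs between about 3.6x and 7.2x in latency.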

But the performance problem looks the same: TP is slower than naive inference in every setup. Are there any other useful tensor-parallel tools?

sgsdxzy commented 8 months ago

@dutsc I use Aphrodite-engine or vLLM for TP inference.

dutsc commented 8 months ago

Thank you for your answer.