sgsdxzy opened this issue 1 year ago
Hi @sgsdxzy!
I tried to reproduce this issue in a T4 x2 Kaggle notebook (sadly I don't own a 2080Ti 22G x4 setup) and here's what I got:
It's not quite double the speed, but it gets better on larger batches.
About your case: if you're sure those numbers are valid, maybe it's somehow connected to the fact that you're using 4 cards. What's the data bandwidth between them? Are all 4 cards getting enough PCI-E lanes?
In this case tensor_parallel is using raw communication primitives from torch.cuda.nccl, so it's weird that they are that slow.
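As a quick sanity check from PyTorch's side, something like this (a minimal sketch using standard torch.cuda / torch.distributed introspection) should confirm whether peer access and NCCL look healthy:

```python
# Minimal sketch: check NCCL availability and pairwise peer access as PyTorch sees it.
import torch
import torch.distributed as dist

print("NCCL available:", dist.is_nccl_available())
print("NCCL version:", torch.cuda.nccl.version())

n = torch.cuda.device_count()
for i in range(n):
    for j in range(n):
        if i != j:
            print(f"P2P {i} -> {j}:", torch.cuda.can_device_access_peer(i, j))
```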
@BlackSamorez I can confirm that with 2 cards, TP provides a small speedup over 2-card MP. The 4 cards are all running at PCIe 3.0 x16 on an X99 board. Here's my P2P connectivity test (I have two NVLinks, between [0,1] and [2,3]):
P2P Connectivity Matrix
D\D 0 1 2 3
0 1 1 0 0
1 1 1 0 0
2 0 0 1 1
3 0 0 1 1
Unidirectional P2P=Disabled Bandwidth Matrix (GB/s)
D\D 0 1 2 3
0 541.72 5.76 5.85 5.87
1 5.76 542.96 5.82 5.87
2 5.95 5.94 537.09 5.79
3 5.89 5.93 5.81 533.16
Unidirectional P2P=Enabled Bandwidth (P2P Writes) Matrix (GB/s)
D\D 0 1 2 3
0 531.46 47.09 6.00 5.95
1 47.11 536.05 5.97 5.95
2 5.87 5.96 532.47 47.09
3 5.92 5.90 47.10 532.53
Bidirectional P2P=Disabled Bandwidth Matrix (GB/s)
D\D 0 1 2 3
0 533.29 6.11 8.62 8.59
1 6.12 535.29 8.58 8.57
2 8.60 8.52 534.05 6.12
3 8.56 8.57 6.10 534.13
Bidirectional P2P=Enabled Bandwidth Matrix (GB/s)
D\D 0 1 2 3
0 533.55 94.10 8.61 8.59
1 94.13 534.78 8.56 8.59
2 8.55 8.60 534.17 94.15
3 8.62 8.59 94.16 533.62
P2P=Disabled Latency Matrix (us)
GPU 0 1 2 3
0 1.34 12.44 12.30 12.44
1 12.44 1.38 21.21 12.68
2 12.53 12.61 1.33 12.44
3 12.38 12.30 12.68 1.33
CPU 0 1 2 3
0 2.05 5.85 5.74 5.82
1 5.82 1.95 5.80 5.77
2 5.63 5.66 1.99 5.58
3 5.75 5.72 5.67 1.97
P2P=Enabled Latency (P2P Writes) Matrix (us)
GPU 0 1 2 3
0 1.33 1.88 12.30 12.45
1 1.88 1.38 21.18 12.54
2 12.53 12.53 1.33 1.85
3 12.38 21.12 1.85 1.33
CPU 0 1 2 3
0 2.02 1.63 5.85 5.91
1 1.64 1.99 5.75 5.91
2 5.71 5.69 1.99 1.64
3 6.01 5.80 1.74 2.12
I think Kaggle T4s don't use NVLink, so that's not the problem here, and I don't think 4 cards would suddenly hit a communication bottleneck and drastically reduce performance. I think it's more of a misconfiguration or a bug. Where would you suggest I look?
@sgsdxzy Thanks! Could you verify that the correct communication functions are being used? You should be hitting:
* https://github.com/BlackSamorez/tensor_parallel/blob/main/src/tensor_parallel/cross_device_ops.py#L95
* https://github.com/BlackSamorez/tensor_parallel/blob/main/src/tensor_parallel/cross_device_ops.py#L77

during forward passes.
Also, could you please benchmark tensor_parallel on ["cuda:0", "cuda:1"] (NVLink) and ["cuda:0", "cuda:2"] (no NVLink)?
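One way to check which collectives are actually hit (just a sketch using the standard PyTorch profiler; `model` and `batch` refer to your benchmark script) is to profile a single forward pass and look at which communication kernels show up:

```python
# Sketch: profile one forward pass and inspect which NCCL collectives actually run.
import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with torch.no_grad():
        model(**batch)  # model/batch as defined in the benchmark script below

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=25))
```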
@BlackSamorez Here are the results:

| Model setup | llama-7b 1gpu | llama-7b 8bit 1gpu | llama-7b 2gpu+nvlink | llama-7b 8bit 2gpu+nvlink | llama-7b 2gpu w/o nvlink | llama-7b 8bit 2gpu w/o nvlink |
|---|---|---|---|---|---|---|
| Naive time (s) | 10.44 | 37.42 | 11.45 | 37.99 | 12.38 | 38.92 |
| Naive memory per gpu (GB) | 14 | 8.3 | 7.7 | 4.7 | 7.7 | 4.7 |
| TP time (s) | - | - | 27.85 | 28.23 | 27.66 | 27.66 |
| TP memory per gpu (GB) | - | - | 7.7 | 7.7 | 7.7 | 7.7 |
And here's the problem I'm hitting:
Traceback (most recent call last):
File "/home/sgsdxzy/Programs/text-generation-webui/tp_test.py", line 68, in <module>
generated = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=200)
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/transformers/generation/utils.py", line 1437, in generate
return self.greedy_search(
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/transformers/generation/utils.py", line 2248, in greedy_search
outputs = self(
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/tensor_parallel/pretrained_model.py", line 88, in forward
return self.wrapped_model(*args, **kwargs)
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/tensor_parallel/tensor_parallel.py", line 130, in forward
return parallel_apply(self.module_shards, inputs, kwargs_tup, self.devices)[self.output_device_index]
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
output.reraise()
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/_utils.py", line 644, in reraise
raise exception
AttributeError: Caught AttributeError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
output = module(*input, **kwargs)
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 687, in forward
outputs = self.model(
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 577, in forward
layer_outputs = decoder_layer(
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 292, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/tensor_parallel/slicer_wrapper.py", line 390, in forward
output = self.tp_wrapped_module(*args, **kwargs)
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 196, in forward
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 242, in forward
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 488, in matmul
return MatMul8bitLt.apply(A, B, out, bias, state)
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 320, in forward
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
File "/home/sgsdxzy/mambaforge/envs/textgen/lib/python3.10/site-packages/bitsandbytes/functional.py", line 1698, in transform
prev_device = pre_call(A.device)
AttributeError: 'NoneType' object has no attribute 'device'
The updated script, for reference:
import torch
import time
import argparse
import json
import tensor_parallel
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, LlamaTokenizer
import accelerate
from transformers.utils.bitsandbytes import replace_8bit_linear
from accelerate.hooks import remove_hook_from_module

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str)
parser.add_argument('--int8', action='store_true')
parser.add_argument('--mp', type=int)
args = parser.parse_args()

tokenizer = LlamaTokenizer.from_pretrained(args.model)

if args.mp <= 1:
    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.half, load_in_8bit=args.int8, device_map="balanced")
else:
    with accelerate.init_empty_weights():
        model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(args.model)).half()
    model = tensor_parallel.TensorParallelPreTrainedModel(model)
    if args.int8:
        model = replace_8bit_linear(model)
        model.is_loaded_in_8bit = True

    device_map = tensor_parallel.infer_sharded_device_map(model)  # <- The model is on the meta device, but we can still deduce
                                                                  #    the target devices for each weight using this helper function

    # Get the checkpoint shards
    with open(f"{args.model}/pytorch_model.bin.index.json", "r") as index_file:
        shard_filenames = set(json.load(index_file)["weight_map"].values())

    for shard_filename in sorted(shard_filenames):
        # Download a shard
        shard_path = f"{args.model}/{shard_filename}"
        print(shard_path)

        # Convert the model shard
        converted_state_dict = tensor_parallel.convert_state_dict(  # <- tensor_parallel helper function.
            torch.load(shard_path),                                 #    Creates a tensor_parallel checkpoint from a normal one
            model.tensor_parallel_config,
            world_size=args.mp,
            for_pretrained=True,
        )
        torch.save(converted_state_dict, "/tmp/shard.bin")
        del converted_state_dict

        # Dispatch the shard
        accelerate.load_checkpoint_in_model(
            model,
            checkpoint="/tmp/shard.bin",
            device_map=device_map,
        )

torch.cuda.empty_cache()
model = model.eval()

with torch.no_grad():
    batch = tokenizer(
        "DeepSpeed first included offloading capabilities with ZeRO-Offload, a system ",
        return_tensors="pt"
    )
    batch = {k: v.cuda(0) for k, v in batch.items()}
    print("Start")
    t0 = time.time()
    generated = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=200)
    t1 = time.time()
    print(f"Output generated in {(t1-t0):.2f} seconds")
    print(tokenizer.decode(generated[0]))
@BlackSamorez here are the results for OPT-6.7B, almost the same as llama-7b:

| Model setup | OPT-6.7B 1gpu | OPT-6.7B 8bit 1gpu | OPT-6.7B 2gpu+nvlink | OPT-6.7B 8bit 2gpu+nvlink |
|---|---|---|---|---|
| Naive time (s) | 10.16 | 39.86 | 9.94 | 40.08 |
| Naive memory per gpu (GB) | 13.6 | 7.6 | 7.6 | 4.6 |
| TP time (s) | - | - | 23.64 | 23.81 |
| TP memory per gpu (GB) | - | - | 7.6 | 7.6 |
Are you testing in int8 or fp16? Can you get any cards other than dual T4s? And I don't think I have a GPU communication problem, since the TP provided by deepspeed-inference is boosting performance for me on OPT (llama is not well supported yet): 2-card fp16 is 65% faster than 1-card fp16. https://github.com/oobabooga/text-generation-webui/issues/561#issuecomment-1484933375
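For reference, the deepspeed-inference setup I'm comparing against looks roughly like this (a sketch based on DeepSpeed 0.8.x's init_inference API and launcher, not my exact script):

```python
# Sketch: 2-way tensor parallelism via deepspeed-inference kernel injection.
# Run with:  deepspeed --num_gpus 2 ds_infer.py
import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-6.7b", torch_dtype=torch.half)
model = deepspeed.init_inference(
    model,
    mp_size=2,                        # tensor-parallel world size
    dtype=torch.half,
    replace_with_kernel_inject=True,  # use DeepSpeed's fused inference kernels
)
```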
> @sgsdxzy Thanks! Could you verify that the correct communication functions are being used? You should be hitting:
> * https://github.com/BlackSamorez/tensor_parallel/blob/main/src/tensor_parallel/cross_device_ops.py#L95
> * https://github.com/BlackSamorez/tensor_parallel/blob/main/src/tensor_parallel/cross_device_ops.py#L77
>
> during forward passes.
> Also, could you please benchmark tensor_parallel on ["cuda:0", "cuda:1"] (NVLink) and ["cuda:0", "cuda:2"] (no NVLink)?
I found that NCCLAllGatherFunction is called, but not NCCLAllReduceFunction.
@sgsdxzy Hi!
Firstly, about int8: you need the latest accelerate (e.g. the main branch from GitHub) to dispatch int8 models with load_checkpoint_in_model. Otherwise int8 layers are not quantized and behave exactly like fp16.
About everything else: I'll need some time to test it. It could be due to a lot of reasons, including bugs in communications or tensor cores suddenly not kicking in for tensor_parallel.
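If it helps, here's a rough way to check whether the int8 layers were actually quantized after dispatch (a sketch; it assumes bitsandbytes' Linear8bitLt is the layer class in use, as in your traceback):

```python
# Sketch: verify that Linear8bitLt weights were actually quantized to int8
# after load_checkpoint_in_model, instead of silently staying in fp16.
import torch
import bitsandbytes as bnb

def report_unquantized(model):
    for name, module in model.named_modules():
        if isinstance(module, bnb.nn.Linear8bitLt) and module.weight.dtype != torch.int8:
            print(f"{name}: weight dtype is {module.weight.dtype} (not quantized)")

report_unquantized(model)
```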
@BlackSamorez I upgraded accelerate to git+https://github.com/huggingface/accelerate, however the VRAM usage and speed are the same.
@sgsdxzy Now that's weird. This demo works, which means that int8 should work fine, since those models won't physically fit in VRAM in fp16.
Could you please attach the result of pip freeze in your environment?
@BlackSamorez it's here. This is a conda environment; tell me if you suspect any specific package whose version isn't listed by pip freeze:
accelerate @ git+https://github.com/huggingface/accelerate@b757b6232516da4ece0fbcfec66855b37523f39a
aiofiles @ file:///home/conda/feedstock_root/build_artifacts/aiofiles_1664378549280/work
aiohttp==3.8.4
aiosignal==1.3.1
aiosqlite @ file:///home/conda/feedstock_root/build_artifacts/aiosqlite_1671461885930/work
altair==4.2.2
anyio @ file:///home/conda/feedstock_root/build_artifacts/anyio_1666191106763/work/dist
appdirs==1.4.4
argon2-cffi @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi_1640817743617/work
argon2-cffi-bindings @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi-bindings_1666850768662/work
astroid @ file:///home/conda/feedstock_root/build_artifacts/astroid_1679923748219/work
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1670263926556/work
async-timeout==4.0.2
attrs @ file:///home/conda/feedstock_root/build_artifacts/attrs_1671632566681/work
autopep8 @ file:///home/conda/feedstock_root/build_artifacts/autopep8_1635267974115/work
Babel @ file:///home/conda/feedstock_root/build_artifacts/babel_1677767029043/work
backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work
backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1618230623929/work
beautifulsoup4 @ file:///home/conda/feedstock_root/build_artifacts/beautifulsoup4_1679322162244/work
bitsandbytes==0.37.2
black @ file:///home/conda/feedstock_root/build_artifacts/black-recipe_1675252854302/work
bleach @ file:///home/conda/feedstock_root/build_artifacts/bleach_1674535352125/work
brotlipy @ file:///home/conda/feedstock_root/build_artifacts/brotlipy_1666764671472/work
certifi==2022.12.7
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1671179353105/work
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1661170624537/work
click @ file:///home/conda/feedstock_root/build_artifacts/click_1666798198223/work
cmake==3.26.1
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work
comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1679481329611/work
contourpy @ file:///home/conda/feedstock_root/build_artifacts/contourpy_1673633665736/work
cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography-split_1679811212387/work
cycler @ file:///home/conda/feedstock_root/build_artifacts/cycler_1635519461629/work
Cython @ file:///home/conda/feedstock_root/build_artifacts/cython_1673054058583/work
daal4py==2023.0.2
datasets==2.11.0
debugpy @ file:///home/conda/feedstock_root/build_artifacts/debugpy_1674522362098/work
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
deepspeed==0.8.3
defusedxml @ file:///home/conda/feedstock_root/build_artifacts/defusedxml_1615232257335/work
dill @ file:///home/conda/feedstock_root/build_artifacts/dill_1666603105584/work
docstring-to-markdown @ file:///home/conda/feedstock_root/build_artifacts/docstring-to-markdown_1679424273982/work
entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1667317341051/work
fastapi==0.95.0
fastjsonschema @ file:///home/conda/feedstock_root/build_artifacts/python-fastjsonschema_1677336799617/work/dist
ffmpy==0.3.0
filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1679932713187/work
fire==0.5.0
flake8 @ file:///home/conda/feedstock_root/build_artifacts/flake8_1669396691980/work
flexgen==0.1.7
flit_core @ file:///home/conda/feedstock_root/build_artifacts/flit-core_1667734568827/work/source/flit_core
fonttools @ file:///home/conda/feedstock_root/build_artifacts/fonttools_1680021152278/work
frozenlist==1.3.3
fsspec==2023.3.0
gmpy2 @ file:///home/conda/feedstock_root/build_artifacts/gmpy2_1666808654411/work
gradio==3.24.1
gradio_client==0.0.5
h11==0.14.0
hjson==3.1.0
httpcore==0.16.3
httpx==0.23.3
huggingface-hub==0.13.3
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1663625384323/work
importlib-metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1679167925176/work
importlib-resources @ file:///home/conda/feedstock_root/build_artifacts/importlib_resources_1676919000169/work
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1679336319192/work
ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1677617093347/work
ipython-genutils==0.2.0
isort @ file:///home/conda/feedstock_root/build_artifacts/isort_1675033873689/work
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1669134318875/work
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1654302431367/work
joblib @ file:///home/conda/feedstock_root/build_artifacts/joblib_1663332044897/work
json5 @ file:///home/conda/feedstock_root/build_artifacts/json5_1600692310011/work
jsonschema @ file:///home/conda/feedstock_root/build_artifacts/jsonschema-meta_1669810440410/work
jupyter-events @ file:///home/conda/feedstock_root/build_artifacts/jupyter_events_1673559782596/work
jupyter-ydoc @ file:///home/conda/feedstock_root/build_artifacts/jupyter_ydoc_1679325289144/work/dist
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1679365123476/work
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1678994169527/work
jupyter_server @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_1679073341944/work
jupyter_server_fileid @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_fileid_1677220209229/work
jupyter_server_terminals @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_terminals_1673491454549/work
jupyter_server_ydoc @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_ydoc_1678043727957/work
jupyterlab @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_1679327603632/work
jupyterlab-code-formatter @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_code_formatter_1679847042826/work
jupyterlab-pygments @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_pygments_1649936611996/work
jupyterlab_server @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_server_1679528718717/work
kiwisolver @ file:///home/conda/feedstock_root/build_artifacts/kiwisolver_1666805701884/work
lazy-object-proxy @ file:///home/conda/feedstock_root/build_artifacts/lazy-object-proxy_1672877787898/work
linkify-it-py==2.0.0
lit==16.0.0
loralib==0.1.1
Markdown==3.4.3
markdown-it-py==2.2.0
MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1674135787083/work
matplotlib @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-suite_1678135565516/work
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work
mccabe @ file:///home/conda/feedstock_root/build_artifacts/mccabe_1643049622439/work
mdit-py-plugins==0.3.3
mdurl==0.1.2
mistune @ file:///home/conda/feedstock_root/build_artifacts/mistune_1675771498296/work
mpmath @ file:///home/conda/feedstock_root/build_artifacts/mpmath_1678228039184/work
multidict==6.0.4
multiprocess==0.70.14
munkres==1.1.4
mypy-extensions @ file:///home/conda/feedstock_root/build_artifacts/mypy_extensions_1675543315189/work
nbclassic @ file:///home/conda/feedstock_root/build_artifacts/nbclassic_1678277563913/work
nbclient @ file:///home/conda/feedstock_root/build_artifacts/nbclient_1669795076334/work
nbconvert @ file:///home/conda/feedstock_root/build_artifacts/nbconvert-meta_1680034059411/work
nbformat @ file:///home/conda/feedstock_root/build_artifacts/nbformat_1679336765223/work
nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1664684991461/work
networkx @ file:///home/conda/feedstock_root/build_artifacts/networkx_1673151334029/work
ninja==1.11.1
notebook @ file:///home/conda/feedstock_root/build_artifacts/notebook_1678109761260/work
notebook_shim @ file:///home/conda/feedstock_root/build_artifacts/notebook-shim_1667478401171/work
numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1675642512762/work
orjson==3.8.9
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1673482170163/work
pandas==1.5.3
pandocfilters @ file:///home/conda/feedstock_root/build_artifacts/pandocfilters_1631603243851/work
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work
pathspec @ file:///home/conda/feedstock_root/build_artifacts/pathspec_1678853982175/work
peft @ git+https://github.com/huggingface/peft.git@445940fb7b5d38390ffb6707e2a989e89fff03b5
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1667297516076/work
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
Pillow @ file:///home/conda/feedstock_root/build_artifacts/pillow_1675487172403/work
pkgutil_resolve_name @ file:///home/conda/feedstock_root/build_artifacts/pkgutil-resolve-name_1633981968097/work
platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1679871349196/work
pluggy @ file:///home/conda/feedstock_root/build_artifacts/pluggy_1667232663820/work
ply==3.11
pooch @ file:///home/conda/feedstock_root/build_artifacts/pooch_1679580333621/work
prometheus-client @ file:///home/conda/feedstock_root/build_artifacts/prometheus_client_1674535637125/work
prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1677600924538/work
psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1667885877572/work
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
PuLP==2.7.0
pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
py-cpuinfo==9.0.0
pyarrow==11.0.0
pybind11==2.10.4
pycodestyle @ file:///home/conda/feedstock_root/build_artifacts/pycodestyle_1669306857274/work
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
pydantic==1.10.7
pydocstyle @ file:///home/conda/feedstock_root/build_artifacts/pydocstyle_1673997095229/work
pydub==0.25.1
pyflakes @ file:///home/conda/feedstock_root/build_artifacts/pyflakes_1669319921641/work
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1672682006896/work
pylint @ file:///home/conda/feedstock_root/build_artifacts/pylint_1679515272965/work
pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1680037383858/work
pyparsing @ file:///home/conda/feedstock_root/build_artifacts/pyparsing_1652235407899/work
PyQt5==5.15.7
PyQt5-sip==12.11.0
pyrsistent @ file:///home/conda/feedstock_root/build_artifacts/pyrsistent_1672681463845/work
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work
python-json-logger @ file:///home/conda/feedstock_root/build_artifacts/python-json-logger_1677079630776/work
python-lsp-jsonrpc @ file:///home/conda/feedstock_root/build_artifacts/python-lsp-jsonrpc_1618530352985/work
python-lsp-server @ file:///home/conda/feedstock_root/build_artifacts/python-lsp-server-meta_1674005136083/work
python-multipart==0.0.6
pytoolconfig @ file:///home/conda/feedstock_root/build_artifacts/pytoolconfig_1675124745143/work
pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1680088766131/work
PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1666772395347/work
pyzmq @ file:///home/conda/feedstock_root/build_artifacts/pyzmq_1679316826707/work
regex==2023.3.23
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1673863902341/work
responses==0.18.0
rfc3339-validator @ file:///home/conda/feedstock_root/build_artifacts/rfc3339-validator_1638811747357/work
rfc3986==1.5.0
rfc3986-validator @ file:///home/conda/feedstock_root/build_artifacts/rfc3986-validator_1598024191506/work
rope @ file:///home/conda/feedstock_root/build_artifacts/rope_1674988456931/work
rwkv==0.7.3
safetensors==0.3.0
scikit-learn @ file:///home/conda/feedstock_root/build_artifacts/scikit-learn_1679675836718/work
scikit-learn-intelex==20230131.200059
scipy==1.10.1
semantic-version==2.10.0
Send2Trash @ file:///home/conda/feedstock_root/build_artifacts/send2trash_1628511208346/work
sentencepiece==0.1.97
sip @ file:///home/conda/feedstock_root/build_artifacts/sip_1675696581052/work
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
sniffio @ file:///home/conda/feedstock_root/build_artifacts/sniffio_1662051266223/work
snowballstemmer @ file:///home/conda/feedstock_root/build_artifacts/snowballstemmer_1637143057757/work
soupsieve @ file:///home/conda/feedstock_root/build_artifacts/soupsieve_1658207591808/work
stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
starlette==0.26.1
sympy @ file:///home/conda/feedstock_root/build_artifacts/sympy_1679342590084/work
tensor-parallel @ file:///home/sgsdxzy/Programs/tensor_parallel
termcolor==2.2.0
terminado @ file:///home/conda/feedstock_root/build_artifacts/terminado_1670253674810/work
threadpoolctl @ file:///home/conda/feedstock_root/build_artifacts/threadpoolctl_1643647933166/work
tinycss2 @ file:///home/conda/feedstock_root/build_artifacts/tinycss2_1666100256010/work
tokenize-rt==5.0.0
tokenizers==0.13.3
toml @ file:///home/conda/feedstock_root/build_artifacts/toml_1604308577558/work
tomli @ file:///home/conda/feedstock_root/build_artifacts/tomli_1644342247877/work
tomlkit @ file:///home/conda/feedstock_root/build_artifacts/tomlkit_1679924068997/work
toolz==0.12.0
torch==2.0.0
torchaudio==2.0.0
torchvision==0.15.0
tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1666788589303/work
tqdm==4.65.0
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1675110562325/work
transformers @ git+https://github.com/huggingface/transformers.git@ee8e80a060d65ab349743ffcb5842365eb0e5606
triton==2.0.0
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1678559861143/work
uc-micro-py==1.0.1
ujson @ file:///home/conda/feedstock_root/build_artifacts/ujson_1675191915931/work
unicodedata2 @ file:///home/conda/feedstock_root/build_artifacts/unicodedata2_1667239886688/work
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1678635778344/work
uvicorn==0.21.1
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1673864653149/work
webencodings==0.5.1
websocket-client @ file:///home/conda/feedstock_root/build_artifacts/websocket-client_1675567828044/work
websockets==10.4
whatthepatch @ file:///home/conda/feedstock_root/build_artifacts/whatthepatch_1675090462655/work
wrapt @ file:///home/conda/feedstock_root/build_artifacts/wrapt_1677485519705/work
xxhash==3.2.0
y-py @ file:///home/conda/feedstock_root/build_artifacts/y-py_1677231008299/work
yapf @ file:///home/conda/feedstock_root/build_artifacts/yapf_1641487982943/work
yarl==1.8.2
ypy-websocket @ file:///home/conda/feedstock_root/build_artifacts/ypy-websocket_1670333059911/work
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1677313463193/work
@sgsdxzy By the way, here's what I get on my setup with decapoda-research/llama-7b-hf:
* GTX 1080 x 2 tensor_parallel: 25.87 seconds
* GTX 1080 x 2 sequential: 16.11 seconds
* GTX 1080 x 3 tensor_parallel: 22.40 seconds
* GTX 1080 x 3 sequential: 18.14 seconds
* GTX 1080 x 4 tensor_parallel: 75.03 seconds
* GTX 1080 x 4 sequential: 19.41 seconds
* RTX 3060 x 2 tensor_parallel: 15.25 seconds
* RTX 3060 x 2 sequential: 19.90 seconds
* RTX 3060 x 2 + GTX 1080 x 2 tensor_parallel: 73.69 seconds
* RTX 3060 x 2 + GTX 1080 x 2 sequential: 31.91 seconds
* RTX 3060 x 2 + GTX 1080 x 4 tensor_parallel: 123.15 seconds
* RTX 3060 x 2 + GTX 1080 x 4 sequential: 29.55 seconds

Only RTX 3060 x 2 speeds things up. Something's definitely very wrong.
I've tested pure forward passes and it looks good:
* Done 10 passes with batch_size=8 lenght=512 in 45.55 seconds with tensor_parallel
* Done 10 passes with batch_size=8 lenght=512 in 194.08 seconds sequential

on the same GTX 1080 x 4. Maybe something's wrong with past_key_values processing, which makes generation slow. Will look into it.
@BlackSamorez is it that past_key_values are gathered to cuda:0 and redistributed to each rank every time?
> @BlackSamorez is it that past_key_values are gathered to cuda:0 and redistributed to each rank every time?
I'm not sure. There is a different data structure for ungathered tensors called PerDeviceTensors, and it's used for past_key_values. They should not be gathered at all. I'll need to verify that it's working as expected.
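A quick external check (just a sketch, reusing `model` and `batch` from your script; PerDeviceTensors is the structure mentioned above) would be to run one cached forward pass and look at what type the returned past_key_values actually are:

```python
# Sketch: if past_key_values come back as plain tensors on one device rather
# than ungathered PerDeviceTensors, the cache is being gathered every step.
import torch

with torch.no_grad():
    out = model(batch["input_ids"], use_cache=True)

first = out.past_key_values[0][0]
print(type(first))
if torch.is_tensor(first):
    print("cache was gathered onto", first.device)
```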
Have you identified the issue?
With 1.2.1, load_in_8bit
actually saves VRAM for me, but the performance is still bad.
I also observed a slowdown with tensor_parallel 1.2.1 compared to native performance on a single GPU.
* Model: Llama-7b on 8 x A100 80GB (NVLink)
* Prompt: "Count up from 100 to 130", so the number of newly generated tokens is a fixed value (155)
* 1 GPU w/o TP: inference time 7.08 s, GPU utilization (nvidia-smi) about 69%
* 2-way TP: inference time 10.24 s, GPU utilization (nvidia-smi) only about 23%
The only code difference between the two tests is:
### 1-GPU w/o TP
model = LlamaForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16, device_map="sequential")
vs.
### 2-way TP
model = LlamaForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16)
model = TensorParallelPreTrainedModel(model, ["cuda:0", "cuda:1"])
Any hints on what might have gone wrong?
I've measured the performance of LLaMA 13B on Kaggle 2x T4 and here's what I got:
It's definitely a .generate() problem. I'll look into it and, hopefully, release a fix soon.
Thank you for sharing your findings on the performance of LLaMA 13B on Kaggle 2x T4. Good to know that you've identified the .generate() issue. I appreciate your efforts in looking into it and eagerly await the release of a fix. Keep up the good work!
Hi @BlackSamorez, have you been able to identify and fix the issue? I am having similar issues, where using 2-way or even 4-way TP slows down inference, while using 2x A100 40GB w/ NVLink.
Would love to know if there is any update on this issue @BlackSamorez. tensor_parallel works great for us for training (nice job!), but the inability to actually sample from the model is a dealbreaker for us. We're seeing slow generation for non-llama models too (e.g., Pythia-6.9b).
@eric-mitchell @dmgcsilva Sadly, I have neither the time nor the resources to properly test and benchmark this right now. I'll do it in a month or so.
Has anyone found an alternative efficient TP solution yet?
I also found that 4-GPU TP is much slower than 2-GPU TP, while the latter is still a bit faster than 2-GPU PP.
This work is very meaningful. I followed @sgsdxzy and ran the following test on 3090s:

| Model setup | opt-6.7b 1gpu | opt-6.7b 2gpu | opt-1.3b 1gpu | opt-1.3b 2gpu | opt-13b 4gpu |
|---|---|---|---|---|---|
| Naive per-token time (ms) | 21.5 | 21.5 (single card) | 12.5 | 12.5 (single card) | 52.11 |
| Naive memory per gpu (GB) | 12.8 | 12.8 | 2.9 | 2.9 | - |
| TP per-token time (ms) | - | 76.89 | - | 62.1 | 373.71 |
| TP memory per gpu (GB) | - | 6.5 | - | 1.6 | 6.7 |

But the performance problem seems to be the same. Are there any other useful tensor parallel tools?
@dutsc I use Aphrodite-engine or vLLM for TP inference.
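For reference, a minimal vLLM tensor-parallel run looks roughly like this (a sketch based on vLLM's public API; the model path is a placeholder):

```python
# Sketch: 2-way tensor-parallel generation with vLLM.
from vllm import LLM, SamplingParams

llm = LLM(model="/path/to/llama-7b", tensor_parallel_size=2)
sampling = SamplingParams(max_tokens=200)
outputs = llm.generate(["DeepSpeed first included offloading capabilities with"], sampling)
print(outputs[0].outputs[0].text)
```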
Thank you for your answer.
The inference speed of naive model parallel is much better than tensor parallel:
Setup: Llama-30b on 2080Ti 22G x4
* Naive: 31.64 s
* 4-way TP, main branch: 177.78 s
* 4-way TP, llama branch: 102.22 s
The code for naive inference:
The code for TP: