It seems that Baichuan is not supported yet. You can refer to this repository: https://github.com/gameofdimension/vllm-cn
or to this document: https://vllm.readthedocs.io/en/latest/models/adding_model.html
Thank you! That's quite helpful
Adding this environment variable (replace the path with your modules directory) can make it work, but the results generated by the model are completely incorrect.
PYTHONPATH=/root/.cache/huggingface/modules
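For reference, a minimal sketch of this workaround in Python (assumptions: the Ray workers spawned by vLLM inherit the driver's environment, the remote-code modules live in the default Hugging Face cache location, and the model name is only an example):

import os

# Adjust to your own modules directory.
modules_dir = os.path.expanduser("~/.cache/huggingface/modules")
os.environ["PYTHONPATH"] = modules_dir + os.pathsep + os.environ.get("PYTHONPATH", "")

# Construct the engine only after PYTHONPATH is set, so the workers can import
# the dynamically generated transformers_modules package for the remote-code model.
from vllm import LLM

llm = LLM(model="baichuan-inc/Baichuan-7B",
          trust_remote_code=True,
          tensor_parallel_size=2)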
This is caused by: https://github.com/vllm-project/vllm/blob/58a072be15a4e36bee006d1c9a962e527819cf18/vllm/engine/llm_engine.py#L148

self._run_workers("init_worker",
                  get_all_outputs=True,
                  worker_init_fn=lambda: Worker(
                      self.model_config,
                      self.parallel_config,
                      ...
                  )
This lambda function captures the self object, which holds a reference to config classes loaded from the remote Hugging Face modules.
You can try the following to avoid capturing the self object:
import copy

model_config = copy.deepcopy(self.model_config)
parallel_config = copy.deepcopy(self.parallel_config)
scheduler_config = copy.deepcopy(self.scheduler_config)
self._run_workers("init_worker",
                  get_all_outputs=True,
                  worker_init_fn=lambda: Worker(
                      model_config,
                      parallel_config,
                      scheduler_config,
                      None,
                      None,
                  ))
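As a plain-Python illustration of why copying into local variables helps (independent of vLLM): the closure of the original lambda stores self, so serializing it for the workers also drags along the engine and the remote-module config it references, while the fixed lambda's closure stores only the copied configs.

class FakeEngine:
    def __init__(self):
        self.model_config = object()  # stand-in for a config class loaded via trust_remote_code

    def broken_init_fn(self):
        # The closure cell stores self, so pickling this lambda also pickles
        # the whole engine and everything it references.
        return lambda: self.model_config

    def fixed_init_fn(self):
        model_config = self.model_config
        # The closure cell now stores only the local variable.
        return lambda: model_config

engine = FakeEngine()
print([c.cell_contents for c in engine.broken_init_fn().__closure__])  # -> [<__main__.FakeEngine object at 0x...>]
print([c.cell_contents for c in engine.fixed_init_fn().__closure__])   # -> [<object object at 0x...>]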
For Baichuan, please note that Baichuan-7B uses rotary embeddings while Baichuan-13B uses ALiBi; you can refer to this link: https://github.com/vllm-project/vllm/pull/512
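For context, ALiBi replaces the rotary position rotation with a fixed per-head slope added as an attention bias. A sketch of the standard slope computation from the ALiBi paper (not necessarily the exact code vLLM uses):

import math

def get_alibi_slopes(num_heads: int) -> list:
    # Geometric series of slopes for the closest power of two, plus
    # interpolated slopes when num_heads is not a power of two
    # (Baichuan-13B has 40 heads, for example).
    closest = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-(2 ** -(math.log2(closest) - 3)))
    slopes = [base ** (i + 1) for i in range(closest)]
    if closest < num_heads:
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest) - 3)))
        slopes += [extra_base ** i for i in range(1, 2 * (num_heads - closest), 2)]
    return slopes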
@mklf Can you get normal output using baichuan-7b in distributed mode? In my environment MPT-7B works fine in distributed mode, but baichuan-7b/13b always returns garbage output.
baichuan-7b: normal output using 1 GPU
baichuan-7b: garbage output using tensor_parallel_size=2
No, in this link https://github.com/vllm-project/vllm/pull/512 they mentioned:
Our code is currently only compatible with non-distributed deployments, i.e., setups involving a single GPU and single model.
While our code is operational with distributed deployment using tensor parallelism, the results it produces are not yet accurate. We are actively looking for community help to rectify this issue.
@Sanster I fixed the bug; here is the updated code. You can replace load_weights at https://github.com/vllm-project/vllm/blob/58a072be15a4e36bee006d1c9a962e527819cf18/vllm/model_executor/models/baichuan.py#L259 with the following code (tested with Baichuan-13B):
def load_weights(self,
                 model_name_or_path: str,
                 cache_dir: Optional[str] = None,
                 use_np_cache: bool = False):
    tensor_model_parallel_rank = get_tensor_model_parallel_rank()
    state_dict = self.state_dict()
    for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, use_np_cache):
        if "rotary_emb.inv_freq" in name:
            continue

        is_gate_up_weight = False
        for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
            if weight_name not in name:
                continue
            param = state_dict[name.replace(weight_name, "gate_up_proj")]
            shard_size = param.shape[0] // 2
            loaded_weight = loaded_weight[
                shard_size * tensor_model_parallel_rank:shard_size *
                (tensor_model_parallel_rank + 1)]
            param_slice = param.data[shard_size * stride_id:shard_size *
                                     (stride_id + 1)]
            assert param_slice.shape == loaded_weight.shape
            param_slice.copy_(loaded_weight)
            is_gate_up_weight = True
            break
        if is_gate_up_weight:
            continue

        param = state_dict[name]
        if "W_pack.weight" in name:  # <----- newly added code here
            head_size = self.config.hidden_size // self.config.num_attention_heads
            # Rearrange the packed [Q; K; V] rows so that each head's q/k/v
            # rows become adjacent, making the weight shardable by head.
            loaded_weight = (
                loaded_weight.contiguous()
                .view(
                    3,
                    self.config.num_attention_heads,
                    head_size,
                    self.config.hidden_size,
                )
                .transpose(0, 1)
                .contiguous()
                .view(-1, self.config.hidden_size)
            )
            # Take this rank's contiguous slice of heads.
            shard_size = param.shape[0]
            start = shard_size * tensor_model_parallel_rank
            end = shard_size * (tensor_model_parallel_rank + 1)
            loaded_weight = loaded_weight[start:end]
            # Regroup the shard back into [Q_shard; K_shard; V_shard] blocks,
            # which is the layout the attention forward expects.
            loaded_weight = loaded_weight.view(
                -1, 3, head_size, self.config.hidden_size)
            loaded_weight = loaded_weight.transpose(0, 1)
            loaded_weight = loaded_weight.reshape(-1, self.config.hidden_size)

        load_tensor_parallel_weights(param, loaded_weight, name,
                                     self._column_parallel_weights,
                                     self._row_parallel_weights,
                                     tensor_model_parallel_rank)
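To see why the permutation is needed, here is a small self-contained check with toy sizes (plain PyTorch, independent of vLLM) that the resulting per-rank shard contains exactly that rank's heads in the [Q; K; V] block layout that qkv.chunk(3, dim=-1) in the attention forward expects:

import torch

hidden_size, num_heads, tp_size = 8, 4, 2
head_size = hidden_size // num_heads
heads_per_rank = num_heads // tp_size

# Toy W_pack: the first hidden_size rows are Q, then K, then V, each grouped by head.
w_pack = torch.arange(3 * hidden_size * hidden_size, dtype=torch.float32).view(
    3 * hidden_size, hidden_size)
q_full, k_full, v_full = w_pack.chunk(3, dim=0)

for rank in range(tp_size):
    # Step 1: interleave so each head's q/k/v rows become adjacent.
    interleaved = (w_pack.view(3, num_heads, head_size, hidden_size)
                   .transpose(0, 1).contiguous()
                   .view(-1, hidden_size))
    # Step 2: take this rank's contiguous slice (its heads, with q/k/v together).
    shard_size = 3 * heads_per_rank * head_size
    shard = interleaved[rank * shard_size:(rank + 1) * shard_size]
    # Step 3: regroup the shard into [Q_shard; K_shard; V_shard] blocks.
    shard = (shard.view(-1, 3, head_size, hidden_size)
             .transpose(0, 1).reshape(-1, hidden_size))
    q_s, k_s, v_s = shard.chunk(3, dim=0)
    lo = rank * heads_per_rank * head_size
    hi = (rank + 1) * heads_per_rank * head_size
    assert torch.equal(q_s, q_full[lo:hi])
    assert torch.equal(k_s, k_full[lo:hi])
    assert torch.equal(v_s, v_full[lo:hi])
print("each rank's shard matches its heads' Q/K/V slices")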
We have tried Baichuan 7B API serving on a single GPU; it works and the generation is good. But it is about 4x slower than LLaMA 7B. Is there anything that slows down Baichuan inference?
@mklf Thank you for your reminder. I have also found the issue in "load_weights". The method I used to fix it is similar to yours. Here is my implementation: https://github.com/vllm-project/vllm/pull/598
Fixed by #599
I have the same issue when serving local internlm-7b model with tensor_parallel=4. Any ideas?
This pull request https://github.com/vllm-project/vllm/pull/871 will solve the problem.
It's a long-pending issue in the transformers code. To fix it, never use a model name containing periods (.) when using the trust_remote_code feature; change the name to use an underscore (_) or some other symbol.
The workaround for me was to switch to multiprocessing (disable Ray) and remove '.' from the model name (path) if present, before instantiating the engine, e.g. "weights/Phi3.5mini-instruct" -> "weights/Phi35mini-instruct".
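A sketch of that rename step (the helper name sanitize_model_path is mine, and copying the weights directory is just one way to do it):

import os
import shutil

def sanitize_model_path(path: str) -> str:
    # Drop '.' from the final directory name so that transformers' dynamic module
    # loader (used with trust_remote_code) can import the generated module,
    # e.g. "weights/Phi3.5mini-instruct" -> "weights/Phi35mini-instruct".
    base = os.path.basename(os.path.normpath(path))
    if "." not in base:
        return path
    fixed = os.path.join(os.path.dirname(os.path.normpath(path)), base.replace(".", ""))
    if not os.path.exists(fixed):
        shutil.copytree(path, fixed)  # or os.rename(path, fixed) to move instead of copy
    return fixed

model_path = sanitize_model_path("weights/Phi3.5mini-instruct")
# then pass model_path to the engine, e.g. LLM(model=model_path, trust_remote_code=True)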
I tried to deploy API serving with baichuan-7b, but there is an error: