chhzh123 closed this issue 1 year ago.
Actually, it seems the problem comes with long sequence lengths. After I changed it to 1024, it can run without getting into errors.
Hi @chhzh123,
Yes, 1024 is the default `max_out_tokens` that we reserve for the KV-cache. If you want to produce more tokens, you need to increase it, which you can do simply by passing `max_out_tokens=2048` to the `init_inference` call. Here is the config for this parameter: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/inference/config.py#L246
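A minimal sketch of that call (the `max_out_tokens` kwarg matches the config entry linked above; the helper function name is just for illustration):

```python
def build_inference_engine(model, world_size):
    """Wrap an eval-mode fp16 model with DeepSpeed-Inference kernels,
    reserving KV-cache space for sequences up to 2048 tokens."""
    import torch
    import deepspeed  # imported lazily so the sketch stays importable without a GPU

    return deepspeed.init_inference(
        model,
        mp_size=world_size,
        dtype=torch.float16,
        replace_with_kernel_inject=True,
        max_out_tokens=2048,  # default is 1024; raise it for longer sequences
    )
```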
Thanks, Reza
@RezaYazdaniAminabadi Thanks for your quick reply! It worked after I changed `max_out_tokens` to 2048, but I think it would be better to error out in this case. Otherwise, users have no way of knowing what the exact issue is.
Also, what would be the correct way to load large models for inference? I tried the following code, but it failed when initializing the inference engine.
```python
with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
    mod = OPTModel(config)
```
```
  File "/home/ubuntu/DeepSpeed/deepspeed/module_inject/replace_module.py", line 500, in _replace_module
    _replace_module(child, name, class_name)
  File "/home/ubuntu/DeepSpeed/deepspeed/module_inject/replace_module.py", line 496, in _replace_module
    setattr(r_module, name, linear_policies[child.__class__](child, prev_name + '.' + name,
  File "/home/ubuntu/DeepSpeed/deepspeed/module_inject/replace_module.py", line 401, in _replace
    data = mp_replace.copy(new_weight, child.weight.data)
  File "/home/ubuntu/DeepSpeed/deepspeed/module_inject/replace_module.py", line 96, in copy
    assert not dst.data.is_meta # the torch.Tensor.copy_ method used below will silently fail on meta tensors
AssertionError
```
Hi @chhzh123,
Thanks for the suggestion. Yes, we need to make this more informative. Also, for loading large models, I suggest you use this script that we have on the DSE side for inference: https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/text-generation/inference-test.py

Please let me know if you can load your model using this.

Thanks, Reza
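The assertion fires because the replaced modules try to copy weights out of meta tensors, which hold no data. The usual workaround is to build the model on the meta device and pass a checkpoint descriptor to `init_inference` so DeepSpeed loads the real weights from disk instead. A hedged sketch of that pattern follows; the descriptor keys mirror the JSON used in the DeepSpeedExamples script, but the `"type"` value and file names are assumptions that depend on your checkpoint format:

```python
import json


def write_checkpoint_descriptor(path, shard_files):
    """Write the small JSON file that deepspeed.init_inference(checkpoint=...)
    reads to locate the real weights on disk. Key names mirror the
    DeepSpeedExamples script; the "type" value depends on your checkpoint."""
    desc = {"type": "ds_model", "checkpoints": shard_files, "version": 1.0}
    with open(path, "w") as f:
        json.dump(desc, f)
    return desc


def load_large_model(config, checkpoint_json, world_size):
    """Build the model on the meta device (no weight allocation), then let
    DeepSpeed materialize the weights from the checkpoint instead of copying
    from the empty meta tensors, which is what triggers the assertion."""
    import torch
    import deepspeed
    from transformers import OPTModel

    with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
        model = OPTModel(config)
    engine = deepspeed.init_inference(
        model,
        mp_size=world_size,
        dtype=torch.float16,
        checkpoint=checkpoint_json,  # real weights come from here, not the meta model
        replace_with_kernel_inject=True,
    )
    return engine.module
```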
Describe the bug
It seems the latest master branch has some problems with the inference kernels when running with multiple GPUs. I tested several models including OPT/LLaMA/Vicuna, and none of them could be executed correctly, even with batch size = 1.
To Reproduce
Steps to reproduce the behavior:

```python
import os

import torch
import torch.distributed as dist
import deepspeed
from transformers import OPTModel, AutoConfig

dist.init_process_group("nccl", world_size=int(os.environ["WORLD_SIZE"]))

config = AutoConfig.from_pretrained("facebook/opt-1.3b")
config.use_cache = False
mod = OPTModel(config)
mod.eval()
mod.to(torch.float16)
bs, seq_len = 1, 2048
input_ids = torch.ones(
    bs, seq_len, dtype=torch.long, device=f"cuda:{dist.get_rank()}"
)

# Initialize the DeepSpeed-Inference engine
ds_engine = deepspeed.init_inference(
    mod,
    mp_size=dist.get_world_size(),
    dtype=torch.float16,
    checkpoint=None,
    replace_with_kernel_inject=True,
)
mod = ds_engine.module
mod(input_ids)
```
Expected behavior This script should run without encountering errors.
ds_report output
Please run `ds_report` to give us details about your setup.

System info (please complete the following information):