I am also encountering the same issue, which prevents us from using DeepSpeed in any setting with batch size > 1.
This currently blocks deployment in a production setting.
I am using PyTorch + Python with the GPT-J model.
With do_sample=False and batch size 4, sending the same input four times (e.g. "what is the"), I get the same response back 4 times when I don't use DeepSpeed.
But if I do the DeepSpeed init as in the example above, the first response is still the same as I would get without DeepSpeed, while the other 3 responses are gibberish / bad / different from the non-DeepSpeed output.
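As a quick sanity check, this is the invariant that breaks (a minimal sketch with illustrative names, assuming a model and tokenizer are already loaded as in the scripts below, with a pad token set):

import torch

prompts = ["what is the"] * 4  # identical prompts in one batch
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(0)
with torch.no_grad():
    out = model.generate(**inputs, do_sample=False, max_new_tokens=32)
texts = tokenizer.batch_decode(out, skip_special_tokens=True)
# Greedy decoding of identical inputs must produce identical outputs.
assert all(t == texts[0] for t in texts), "batch elements diverged"

Without DeepSpeed this assertion passes; with kernel injection it fails for every element after the first.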
I might need to direct this issue to the appropriate contributors for a fast reply: @RezaYazdaniAminabadi, @cmikeh2.
@RezaYazdaniAminabadi Any updates regarding this issue?
Hi @AlekseyKorshuk,
This appears to have been fixed with https://github.com/microsoft/DeepSpeed/pull/2212, and I'm seeing the expected behavior with your reproduction script. Can you test with master to see if everything is working as expected on your end?
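For reference, one way to pick up master (assuming a plain pip environment; by default DeepSpeed JIT-compiles its CUDA ops on first use, so no pre-build step is needed):

pip install git+https://github.com/microsoft/DeepSpeed.git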
@cmikeh2 Looks like the issue is fixed, based on a quick evaluation. Thank you! I will make a detailed assessment during this week and close the issue if everything is OK.
@cmikeh2 Trying to run the same code but with GPT-J-6B. Before inference, memory usage was 13761MiB / 24564MiB on an A5000, and during inference I got the following:
Batch size: 4
Input size: torch.Size([4, 370])
Free memory : 10964500480 (Bytes) Total memory: 25434587136 (Bytes) Setting maximum total tokens (input + output) to 414
# Skipping some logs
RuntimeError: CUDA out of memory. Tried to allocate 568.00 MiB (GPU 0; 23.69 GiB total capacity; 11.57 GiB already allocated; 228.56 MiB free; 11.57 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
So ~10 GB of free CUDA memory is not enough to run a batch size of 4 and handle only 414 total tokens? I suspect something is wrong with the memory reservation/allocation.
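For rough context, a back-of-envelope estimate of what the KV cache should actually need (a sketch assuming the standard GPT-J-6B config of 28 layers and hidden size 4096 in fp16; this is not DeepSpeed's internal accounting):

layers, hidden, bytes_fp16 = 28, 4096, 2
kv_per_token = 2 * layers * hidden * bytes_fp16   # K and V across all layers: ~448 KiB per token
batch, total_tokens = 4, 414
kv_cache = batch * total_tokens * kv_per_token
print(f"{kv_cache / 2**30:.2f} GiB")              # ~0.71 GiB for the whole batch

The fp16 weights alone are ~11.3 GiB (6.05B params x 2 bytes), consistent with the 11.57 GiB already allocated, so the remaining ~10 GiB should comfortably hold a ~0.7 GiB KV cache plus activations.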
Here is an updated version of the code to run:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import deepspeed
from datasets import load_dataset

dataset = load_dataset("ChaiML/user_model_inputs")

model_id = "hakurei/litv2-6B-rev2"
VERBOSE = True
BATCH_SIZE = 4

GENERATION_KWARGS = {
    "max_new_tokens": 64,
    "eos_token_id": 198,
    "do_sample": False,  # greedy decoding; temperature/top_k/top_p have no effect here
    "pad_token_id": 198,
    "temperature": 0.72,
    "top_k": 0,
    "top_p": 0.725,
    "repetition_penalty": 1.13,
}

torch_model = AutoModelForCausalLM.from_pretrained(model_id).half().eval().to(0)
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
tokenizer.pad_token_id = 198


def call_model(model, input_texts, desc="", verbose=False):
    inputs = tokenizer(input_texts, return_tensors="pt", padding="longest").to(0)
    if verbose:
        print(desc)
        print(f"Batch size: {len(input_texts)}")
        print(f"Input size: {inputs.input_ids.size()}")
    outputs = model.generate(**inputs, **GENERATION_KWARGS)
    outputs_set = []
    for i, output in enumerate(outputs):
        # Drop the (left-padded) prompt tokens, keep only the generated continuation.
        text_output = tokenizer.decode(output[len(inputs.input_ids[i]):], skip_special_tokens=True)
        outputs_set.append(text_output)
        if verbose:
            print(f"#{i}: {text_output}")
    return outputs_set


# call_model(
#     model=torch_model,
#     input_texts=dataset["train"]["text"][:BATCH_SIZE],
#     desc="Torch",
#     verbose=VERBOSE,
# )

ds_model = deepspeed.init_inference(
    model=torch_model,
    mp_size=1,
    dtype=torch.float16,
    replace_method="auto",
    replace_with_kernel_inject=True,
)

call_model(
    model=ds_model,
    input_texts=dataset["train"]["text"][:BATCH_SIZE],
    desc="Deepspeed",
    verbose=VERBOSE,
)
Moreover, switching model_id back to gpt2 again produces outputs that are gibberish / bad / different from the outputs without DeepSpeed.
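For completeness, this is how I compare the two backends (a sketch reusing call_model from the script above; the torch baseline has to run before deepspeed.init_inference, because as far as I understand kernel injection replaces modules on the passed model in place):

input_texts = dataset["train"]["text"][:BATCH_SIZE]
torch_outputs = call_model(torch_model, input_texts, desc="Torch", verbose=False)

ds_model = deepspeed.init_inference(
    model=torch_model,
    mp_size=1,
    dtype=torch.float16,
    replace_method="auto",
    replace_with_kernel_inject=True,
)
ds_outputs = call_model(ds_model, input_texts, desc="Deepspeed", verbose=False)

for i, (t, d) in enumerate(zip(torch_outputs, ds_outputs)):
    print(f"#{i} match: {t == d}")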
Can you please try this PR to see if the issue is resolved? Thanks, Reza
@cmikeh2 @RezaYazdaniAminabadi I still get an OOM with DeepSpeed, while plain PyTorch can handle the same input. Here is code to reproduce:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import deepspeed

model_id = "hakurei/litv2-6B-rev2"
# model_id = "gpt2"
VERBOSE = True
BATCH_SIZE = 4

GENERATION_KWARGS = {
    "max_new_tokens": 64,
    "eos_token_id": 198,
    "do_sample": False,  # greedy decoding; temperature/top_k/top_p have no effect here
    "pad_token_id": 198,
    "temperature": 0.72,
    "top_k": 0,
    "top_p": 0.725,
    "repetition_penalty": 1.13,
}

torch_model = AutoModelForCausalLM.from_pretrained(model_id).half().eval().to(0)
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
tokenizer.pad_token_id = 198


def call_model(model, input_texts, desc="", verbose=False):
    inputs = tokenizer(input_texts, return_tensors="pt", padding="longest").to(0)
    if verbose:
        print(desc)
        print(f"Batch size: {len(input_texts)}")
        print(f"Input size: {inputs.input_ids.size()}")
    outputs = model.generate(**inputs, **GENERATION_KWARGS)
    outputs_set = []
    for i, output in enumerate(outputs):
        # Drop the (left-padded) prompt tokens, keep only the generated continuation.
        text_output = tokenizer.decode(output[len(inputs.input_ids[i]):], skip_special_tokens=False)
        outputs_set.append(text_output)
        if verbose:
            print(f"#{i}: {text_output.strip()}")
    return outputs_set


LONG_EXAMPLE = (
    "User: How are you?\nBot: This is demo response, it is quite long to check everything.\n" * 40
    + "User: How are you?\nBot:"
)
print(f"Long example length: {len(LONG_EXAMPLE)}")
input_texts = [LONG_EXAMPLE] * BATCH_SIZE

call_model(
    model=torch_model,
    input_texts=input_texts,
    desc="Torch",
    verbose=VERBOSE,
)

# Free the GPU before handing the model to DeepSpeed.
torch_model.cpu()
torch.cuda.empty_cache()

ds_model = deepspeed.init_inference(
    model=torch_model,
    mp_size=1,
    dtype=torch.float16,
    replace_method="auto",
    replace_with_kernel_inject=True,
)

call_model(
    model=ds_model,
    input_texts=input_texts,
    desc="Deepspeed",
    verbose=VERBOSE,
)
Here is the RuntimeError:
# Skipping logs...
Free memory : 10761076736 (Bytes) Total memory: 25434587136 (Bytes) Setting maximum total tokens (input + output) to 2296
# Skipping logs...
RuntimeError: CUDA out of memory. Tried to allocate 716.00 MiB (GPU 0; 23.69 GiB total capacity; 11.65 GiB already allocated; 146.56 MiB free; 11.65 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
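The error message suggests max_split_size_mb. For completeness, this is how that allocator option can be set and how the usual PyTorch memory stats can be inspected (standard torch.cuda APIs; I doubt fragmentation is the root cause, given DeepSpeed itself estimated room for 2296 total tokens):

import os
# Must be set before the first CUDA allocation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

import torch
# ... run the script above; then, after generation (or after catching the OOM):
print(torch.cuda.memory_summary(device=0, abbreviated=True))
print(torch.cuda.max_memory_allocated(0), torch.cuda.memory_reserved(0))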
Hi @AlekseyKorshuk,
I pushed a fix that should prevent the OOM issue. Can you please try it on your side? Thanks, Reza
@RezaYazdaniAminabadi @cmikeh2 Sharing that after manual evaluation everything works fine. I'm going to A/B test DeepSpeed inference a bit later. I'm closing the issue for now and will reopen it if something doesn't work. Thank you for fixing all the problems.
Describe the bug
Responses from transformers models are not relevant with long inputs and batch size > 1. This issue concerns GPT-like models, while a separate issue covers RoBERTa. In the example code I used a small gpt2 model to test everything, but the issue also exists with large models such as GPT-J, and a similar issue exists with OPT.
Please contact me for more details if needed; if I can help somehow with resolving this issue, it would be my pleasure.
To Reproduce
Code to reproduce and its output: see the scripts and logs in the thread above.
Expected behavior
With do_sample=False, all outputs in a batch are expected to be identical given identical inputs. Currently only the first output is relevant and matches the torch output; check the output section under the example code.