microsoft / DeepSpeedExamples

Example models using DeepSpeed
Apache License 2.0

How to save memory during inference #791

Kangkang625 opened this issue 8 months ago (status: Open)

Kangkang625 commented 8 months ago

Thanks for the great work! When I run my inference code below with deepspeed --include localhost:0,1,2 inference.py --model opt-iml-30b --dataset WQSP, the run fails with exits with return code = -9. The process seems to be killed for running out of memory before the model is even split across the GPUs: watching htop, every process creates its own copy of the model of more than 60 GB, which exceeds my machine's RAM.

Can the processes share one copy of the model in memory instead of each process creating its own? How can I save memory?

from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed,AutoConfig
import torch
import json
import random 
import time
from tqdm import *
import os
import argparse 
import logging 
import deepspeed

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
parser = argparse.ArgumentParser()
parser.add_argument('--model',type=str)
parser.add_argument('--dataset',type=str)
parser.add_argument('--local_rank',type=str)
parser=parser.parse_args()

logging.info(f"inference using model: {parser.model} and dataset: {parser.dataset}")

f =open(os.path.join("datasets",parser.dataset,f"{parser.dataset}_all_question_with_label.json"),'r',encoding='utf-8')
dataset = json.load(f)
rank = os.environ['LOCAL_RANK']

if rank=='0':
    print('Begin load model...')

model = AutoModelForCausalLM.from_pretrained(parser.model, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(parser.model, use_fast=False)

model = deepspeed.init_inference(
    model=model,      
    mp_size=4,        
    dtype=torch.float16, 
    replace_method="auto", 
    replace_with_kernel_inject=True, 
)

time1=time.time()
n=len(dataset)
for i in tqdm(range(n)):
    prompt = dataset[i]['question']+'\nAnswer:'
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    logits = model.generate(input_ids, do_sample=False, max_length=100)
    ans=tokenizer.decode(logits[0].tolist())
    if rank=="0":
        ans = ans.split('\nAnswer:')[1]
        dataset[i][parser.model]=ans.strip("</s>")

time2 = time.time()
Ekundayo39283 commented 2 months ago

It sounds like the model instantiation is what is exhausting your memory: every launched rank calls from_pretrained, so each process builds its own full copy of the model in host RAM before DeepSpeed shards it across the GPUs. To save memory, try to avoid initializing the full model independently in every process, either by loading the checkpoint more lazily (see the sketch just below) or by relying on a distributed runtime such as PyTorch Distributed or Horovod to coordinate the processes. If memory constraints persist, reducing the batch size or using a smaller model also helps.
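A minimal sketch of the cheapest mitigation, assuming a transformers version that supports the low_cpu_mem_usage flag (it loads the checkpoint without first materializing a randomly initialized copy, so each rank peaks at roughly one model's worth of host memory instead of two):

# Sketch: lower the per-process peak host memory while loading the fp16 weights.
# low_cpu_mem_usage=True skips the random-init pass before the checkpoint is read.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    parser.model,                 # same CLI argument as in the script above
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)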

Here is a more refined version of the full script:

from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed,AutoConfig
import torch
import json
import random 
import time
from tqdm import *
import os
import argparse 
import logging 
import deepspeed

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
parser = argparse.ArgumentParser()
parser.add_argument('--model',type=str)
parser.add_argument('--dataset',type=str)
parser.add_argument('--local_rank',type=str)
parser=parser.parse_args()

logging.info(f"inference using model: {parser.model} and dataset: {parser.dataset}")

f =open(os.path.join("datasets",parser.dataset,f"{parser.dataset}_all_question_with_label.json"),'r',encoding='utf-8')
dataset = json.load(f)
rank = os.environ['LOCAL_RANK']

if rank=='0':
    print('Begin load model...')

# Initialize model and tokenizer
model = AutoModelForCausalLM.from_pretrained(parser.model, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(parser.model, use_fast=False)

# Initialize the model with DeepSpeed inference (tensor parallelism across the GPUs).
# Note: init_inference returns a single InferenceEngine, not a tuple.
model = deepspeed.init_inference(
    model=model,
    mp_size=4,  # should match the number of GPUs the launcher starts
    dtype=torch.float16,
    replace_method="auto",
    replace_with_kernel_inject=True,
)

time1 = time.time()
n = len(dataset)
for i in tqdm(range(n)):
    prompt = dataset[i]['question']+'\nAnswer:'
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    logits = model.generate(input_ids, do_sample=False, max_length=100)
    ans = tokenizer.decode(logits[0].tolist())
    if rank=="0":
        ans = ans.split('\nAnswer:')[1]
        dataset[i][parser.model] = ans.strip("</s>")

time2 = time.time()
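
For a 30B checkpoint, a larger saving is to build the model on the meta device and let DeepSpeed load and shard the weights itself, so no rank ever holds the full fp16 model in host memory. The sketch below follows the pattern used in the DeepSpeedExamples BLOOM inference script; it assumes a DeepSpeed version that provides deepspeed.OnDevice and a checkpoints.json file describing the checkpoint shards in the format init_inference expects, so treat it as a starting point rather than a drop-in replacement:

# Sketch: construct the model without allocating weights (meta device), then let
# deepspeed.init_inference load and partition the real checkpoint across ranks.
import os
import torch
import deepspeed
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "facebook/opt-iml-30b"   # assumption: adjust to your local path or Hub id

config = AutoConfig.from_pretrained(model_name)
with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
    # Weights live on the meta device here, so this allocates almost no host memory.
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)

model = deepspeed.init_inference(
    model,
    mp_size=int(os.getenv("WORLD_SIZE", "1")),  # one rank per GPU started by the launcher
    dtype=torch.float16,
    replace_with_kernel_inject=True,
    checkpoint="checkpoints.json",              # assumption: JSON listing the .bin shards
)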