X-PLUG / mPLUG-Owl

mPLUG-Owl: The Powerful Multi-modal Large Language Model Family

Why does it take 30 minutes for a model to generate text #125

Closed · Zhoues closed this 1 year ago

Zhoues commented 1 year ago

device: an A100-80G GPU; code: 8; time: 7

Is anyone else also confused by this problem?

FuxiaoLiu commented 1 year ago

device: an A100-80G GPU; code: 8; time: 7

Is anyone else also confused by this problem?

I was also confused by this at first. I edited the code a little bit (link) so that we can run inference in batches more quickly.

Zhoues commented 1 year ago

device: an A100-80G GPU; code: 8; time: 7 Is anyone else also confused by this problem?

I was also confused by this at first. I edited the code a little bit (link) so that we can run inference in batches more quickly.

Thank you very much!! I solved this problem yesterday, which is why I closed the issue. Anyway, thank you from the bottom of my heart!!

srivivtcs commented 1 year ago

Hi, could you please share the method you used to speed up the inference?

Zhoues commented 1 year ago

Hi, could you please share the method you used to speed up the inference?

Hello, I am very glad to share my method with you. The only thing you need to do is change the tokenizer.

from:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(pretrained_ckpt)

to:

from transformers.models.llama.tokenization_llama import LlamaTokenizer
tokenizer = LlamaTokenizer.from_pretrained(pretrained_ckpt)
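
As a side note, AutoTokenizer defaults to the "fast" tokenizer, and for LLaMA-style checkpoints that ship only a SentencePiece model it converts it on load, which can take a very long time on some transformers/tokenizers versions; the slow LlamaTokenizer skips that conversion. A quick way to confirm this on your machine (a minimal sketch; pretrained_ckpt is assumed to be the same checkpoint as above):

import time

from transformers import AutoTokenizer
from transformers.models.llama.tokenization_llama import LlamaTokenizer

pretrained_ckpt = 'MAGAer13/mplug-owl-llama-7b'  # assumption: the checkpoint used above

t0 = time.perf_counter()
AutoTokenizer.from_pretrained(pretrained_ckpt)
print('AutoTokenizer load: {:.1f}s'.format(time.perf_counter() - t0))

t0 = time.perf_counter()
LlamaTokenizer.from_pretrained(pretrained_ckpt)
print('LlamaTokenizer load: {:.1f}s'.format(time.perf_counter() - t0))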

If this works for you, you can click on 👍

srivivtcs commented 1 year ago

Thanks, the suggested change is not giving any significant speedup in inference. Is there anything else I can try to speed up the inference?

edchengg commented 1 year ago

I'm facing the same issue; the proposed fix is not working for me.

Zhoues commented 1 year ago

@srivivtcs @edchengg Hello! Since I do not have any details about your setups, I would like to share my inference code, which needs only about 1 minute to load the model and about 10 seconds to generate text.

model_worker.py

# Imports assumed by this snippet (module paths follow the mPLUG-Owl repo layout;
# adjust them if your checkout differs).
import torch
from peft import LoraConfig, get_peft_model
from transformers import LlamaTokenizer
from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration
from mplug_owl.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor


class mPLUG_Owl_Server:
    def __init__(
        self, 
        base_model='MAGAer13/mplug-owl-llama-7b',
        log_dir='./',
        load_in_8bit=False,
        bf16=True,
        device="cuda",
        io=None,
        use_lora=False,
        inference_mode=False,
        lora_r=8,
        lora_alpha=32,
        lora_dropout=0.05,
        delta_path='none',
    ):
        self.log_dir = log_dir
        self.image_processor = MplugOwlImageProcessor.from_pretrained(base_model)
        self.tokenizer = LlamaTokenizer.from_pretrained(base_model)
        self.processor = MplugOwlProcessor(self.image_processor, self.tokenizer)
        self.model = MplugOwlForConditionalGeneration.from_pretrained(
            base_model,
            load_in_8bit=load_in_8bit,
            torch_dtype=torch.bfloat16 if bf16 else torch.half,
            device_map="auto",
        )

        # LoRA part
        if use_lora:
            assert delta_path!='none', "If you use lora, please input lora path"
            # Load LoRA Config and Param
            peft_config = LoraConfig(
                target_modules=r'.*language_model.*\.(q_proj|v_proj)', 
                inference_mode=inference_mode, 
                r=lora_r, 
                lora_alpha=lora_alpha, 
                lora_dropout=lora_dropout
            )
            self.model = get_peft_model(self.model, peft_config)

            print('load lora from {}'.format(delta_path))
            prefix_state_dict = torch.load(delta_path, map_location='cpu')
            self.model.load_state_dict(prefix_state_dict, strict=False)

        self.tokenizer = self.processor.tokenizer
        self.bf16 = bf16
        self.load_in_8bit = load_in_8bit

        if not load_in_8bit:
            if bf16:
                self.model.bfloat16()
            else:
                self.model.half()
        self.model.eval()
        ......

inference_2d.py (a script I created)

print(">> begin loading pretrained model")
t1 = datetime.datetime.now()
model = mPLUG_Owl_Server(
        base_model=args.pretrained_ckpt,
        load_in_8bit=not args.bf16,
        bf16=args.bf16,
        use_lora=args.use_lora,
        inference_mode=args.inference_mode,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        delta_path=args.delta_path,
    )
  t2 = datetime.datetime.now()
  print('load pretrained model time: ',t2 - t1)

  prompts = '''The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
  Human: <image>
  Human: ....
  AI: '''

  image_list =[...]

  data_format = {
      "text_input": prompts,
      "images": image_list if len(image_list) > 0 else [],
      "generation_config": {
      "top_k": int(3),
      "top_p": float(0.9),
      "num_beams": int(3),
      "no_repeat_ngram_size": int(2),
      "length_penalty": float(1.0),
      "do_sample": bool(True),
      "temperature": float(1.0),
      "max_new_tokens": 512,
      "early_stopping":bool(False)}}

   t3 = datetime.datetime.now()
  sentence = model.predict_parallel(data_format)
  t4 = datetime.datetime.now()
  print('Forward time: ',t4 - t3)

  print(sentence)
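
For reference, the script above reads its settings from an args object; a minimal argparse sketch that provides the attributes used there (the flag names simply mirror those attributes, and the defaults are assumptions chosen to match the mPLUG_Owl_Server defaults) could look like:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--pretrained_ckpt', default='MAGAer13/mplug-owl-llama-7b')
parser.add_argument('--bf16', action='store_true')
parser.add_argument('--use_lora', action='store_true')
parser.add_argument('--inference_mode', action='store_true')
parser.add_argument('--lora_r', type=int, default=8)
parser.add_argument('--lora_alpha', type=int, default=32)
parser.add_argument('--lora_dropout', type=float, default=0.05)
parser.add_argument('--delta_path', default='none')
args = parser.parse_args()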

If this works for you, you can click on 👍

edchengg commented 1 year ago

Thanks a lot, Zhoues! Switching to worker_server is working well for me.

srivivtcs commented 1 year ago

@Zhoues, I am getting the following error while running the code snippet you shared. AttributeError: 'mPLUG_Owl_Server' object has no attribute 'predict_parallel'

Zhoues commented 1 year ago

@Zhoues, I am getting the following error while running the code snippet you shared. AttributeError: 'mPLUG_Owl_Server' object has no attribute 'predict_parallel'

@srivivtcs Hi, you can just change predict_parallel to predict, since the predict method in model_worker.py only handles a single example rather than a batch. I can also share my predict_parallel code with you.

# This is meant to be a method of the mPLUG_Owl_Server class above; it also needs
# `from PIL import Image` and the post_process_output helper it calls below.
def predict_parallel(self, data):
    prompt = [data['text_input']]
    images = data['images'] if len(data['images']) > 0 else None

    if images:
        images = [Image.open(image) for image in images]
    inputs = self.processor(text=prompt, images=images, return_tensors='pt')
    input_ids = inputs['input_ids'].to(self.model.device)
    # Cast pixel values to match the dtype the model was loaded with.
    if 'pixel_values' in inputs:
        if self.load_in_8bit:
            pixel_values = inputs['pixel_values'].half().to(self.model.device)
        elif self.bf16:
            pixel_values = inputs['pixel_values'].bfloat16().to(self.model.device)
        else:
            pixel_values = inputs['pixel_values'].half().to(self.model.device)
    else:
        pixel_values = None

    with torch.no_grad():
        generation_output = self.model.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            return_dict_in_generate=True,
            output_scores=True,
            **data['generation_config']
        )
    s = generation_output.sequences[0].cpu()
    output = self.tokenizer.decode(s)
    output = post_process_output(output)
    return output
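
One thing to note: the method above decodes only generation_output.sequences[0]. If you adapt it to build prompt from a list of several text inputs, its final lines would also need to decode every sequence, roughly like this (a sketch, not part of the original code):

    # Decode every generated sequence instead of only the first one.
    sequences = generation_output.sequences.cpu()
    outputs = [post_process_output(self.tokenizer.decode(s)) for s in sequences]
    return outputs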

If this works for you, you can click on 👍