haotian-liu / LLaVA

[NeurIPS'23 Oral] Visual Instruction Tuning (LLaVA) built towards GPT-4V level capabilities and beyond.
https://llava.hliu.cc
Apache License 2.0

[Question] Use api #540

Open MartinGuo opened 1 year ago

MartinGuo commented 1 year ago

Question

How can I call the model through API requests from a higher-level application?

Raidus commented 1 year ago

The serve cli module might give you some hints.


        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                do_sample=True,
                temperature=args.temperature,
                max_new_tokens=args.max_new_tokens,
                streamer=streamer,
                use_cache=True,
                stopping_criteria=[stopping_criteria])

        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        conv.messages[-1][-1] = outputs

Reference: https://github.com/haotian-liu/LLaVA/blob/main/llava/serve/cli.py

userbox020 commented 1 year ago

I'm trying to modify the client so the user can upload new images, but after uploading the third image I get an OOM error. I'm running the garbage collector and clearing the VRAM cache, but maybe I'm doing it wrong.

import argparse
import torch

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

from PIL import Image

import requests
import traceback
from io import BytesIO
from transformers import TextStreamer

# Garbage Collector for PyTorch GPU Memory
import gc

def pytorch_gpu_garbage_collector():
    """Explicitly trigger Python's garbage collector and then clear PyTorch's GPU cache for all GPUs."""
    print("Triggering Python's garbage collector...")
    # Collect Python's garbage
    gc.collect()

    # Clear PyTorch's GPU cache for all GPUs
    if torch.cuda.is_available():
        for device_idx in range(torch.cuda.device_count()):
            torch.cuda.set_device(device_idx)
            print(f"Clearing PyTorch GPU cache for GPU {device_idx}...")
            torch.cuda.empty_cache()
    print("Garbage collection completed!")

def load_image(image_file):
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image

def main(args):
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

    if 'llama-2' in model_name.lower():
        conv_mode = "llava_llama_2"
    elif "v1" in model_name.lower():
        conv_mode = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    if args.conv_mode is not None and conv_mode != args.conv_mode:
        print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
    else:
        args.conv_mode = conv_mode

    conv = conv_templates[args.conv_mode].copy()
    if "mpt" in model_name.lower():
        roles = ('user', 'assistant')
    else:
        roles = conv.roles

    while True:
        try:

            image_file = input('Image path:')
            #image = load_image(args.image_file)

            image = load_image(image_file)

            # Similar operation in model_worker.py
            image_tensor = process_images([image], image_processor, args)
            if type(image_tensor) is list:
                image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
            else:
                image_tensor = image_tensor.to(model.device, dtype=torch.float16)            

            inp = input(f"{roles[0]}: ")
        except EOFError:
            inp = ""
        if not inp:
            print("exit...")
            break

        try:
            print(f"{roles[1]}: ", end="")

            if image is not None:
                # first message
                if model.config.mm_use_im_start_end:
                    inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
                else:
                    inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
                conv.append_message(conv.roles[0], inp)
                image = None
            else:
                # later messages
                conv.append_message(conv.roles[0], inp)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
            keywords = [stop_str]
            stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
            streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

            print('------------------------')
            print('input_ids:', input_ids)
            print('image_tensor:', image_tensor)
            print('args.temperature:', args.temperature)
            print('max_new_tokens:', args.max_new_tokens)
            print('streamer:', streamer)
            print('stopping_criteria:', [stopping_criteria])

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor,
                    do_sample=True,
                    temperature=args.temperature,
                    max_new_tokens=args.max_new_tokens,
                    streamer=streamer,
                    use_cache=True,
                    stopping_criteria=[stopping_criteria])

            input('Debug...')

            outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
            conv.messages[-1][-1] = outputs
            print('------------------------')
            if args.debug:
                print("\n", {"prompt": prompt, "outputs": outputs}, "\n")

            pytorch_gpu_garbage_collector()

        except Exception as e:
            print('Error:',str(e))
            print('Traceback:', traceback.format_exc())

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    #parser.add_argument("--image-file", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--image-aspect-ratio", type=str, default='pad')
    args = parser.parse_args()
    main(args)

TonyUSTC commented 11 months ago

Has this been solved? I'm running into the same problem.

vicgarfield commented 11 months ago

The reason is that each request's prompt is stacked on top of the previous images and text prompt tokens, so the input keeps growing. Just re-initialize the "conv" object every time you make a request.
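A toy sketch of the effect described above, assuming a simplified stand-in for LLaVA's conversation template. ToyConversation and run_turn are hypothetical: they only mimic the copy / append_message / get_prompt pattern used in the script and are not llava.conversation. Reusing one object makes every earlier turn, including its image placeholder, part of the next prompt, while copying the template per request keeps the prompt size constant:

```python
class ToyConversation:
    """Hypothetical, simplified stand-in for a chat template (not llava.conversation)."""

    def __init__(self, roles=("USER", "ASSISTANT"), sep=" "):
        self.roles = roles
        self.sep = sep
        self.messages = []

    def copy(self):
        # Copying the empty template yields a conversation with no history,
        # mirroring how the script copies conv_templates[args.conv_mode].
        clone = ToyConversation(self.roles, self.sep)
        clone.messages = [list(m) for m in self.messages]
        return clone

    def append_message(self, role, text):
        self.messages.append([role, text])

    def get_prompt(self):
        # An empty assistant slot (None) renders as "ROLE:" so the model continues it.
        return self.sep.join(
            f"{role}: {text}" if text is not None else f"{role}:"
            for role, text in self.messages
        )


def run_turn(conv, user_text):
    """One request: image placeholder + question, build the prompt, store a fake answer."""
    conv.append_message(conv.roles[0], "<image>\n" + user_text)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    conv.messages[-1][-1] = "(model answer)"  # the script does this after generate()
    return prompt


TEMPLATE = ToyConversation()

# Buggy pattern: one conversation shared across requests, so history piles up.
shared = TEMPLATE.copy()
reused_lengths = [len(run_turn(shared, "Describe the image.")) for _ in range(3)]

# Fixed pattern: a fresh copy of the template for every request.
fresh_lengths = [len(run_turn(TEMPLATE.copy(), "Describe the image.")) for _ in range(3)]

print("reused:", reused_lengths)
print("fresh: ", fresh_lengths)
```

In the real script a longer prompt means longer input_ids passed to model.generate, which plausibly explains the growing VRAM use; this toy only measures prompt length in characters.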

userbox020 commented 11 months ago

@vicgarfield thanks bro, it's working beautifully and memory is now being cleaned up like it should be. I'm on loop 9 and it's not accumulating any extra VRAM. I'm running it with only 12 GB of VRAM. Here is the code:

import argparse
import torch

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

from PIL import Image

import requests
import traceback
from io import BytesIO
from transformers import TextStreamer

# Garbage Collector for PyTorch GPU Memory
import gc

def pytorch_gpu_garbage_collector():
    """Explicitly trigger Python's garbage collector and then clear PyTorch's GPU cache for all GPUs."""
    print("Triggering Python's garbage collector...")
    # Collect Python's garbage
    gc.collect()

    # Clear PyTorch's GPU cache for all GPUs
    if torch.cuda.is_available():
        for device_idx in range(torch.cuda.device_count()):
            torch.cuda.set_device(device_idx)
            print(f"Clearing PyTorch GPU cache for GPU {device_idx}...")
            torch.cuda.empty_cache()
    print("Garbage collection completed!")

def load_image(image_file):
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image

def main(args):
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

    if 'llama-2' in model_name.lower():
        conv_mode = "llava_llama_2"
    elif "v1" in model_name.lower():
        conv_mode = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    if args.conv_mode is not None and conv_mode != args.conv_mode:
        print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
    else:
        args.conv_mode = conv_mode

    while True:
        try:

            conv = conv_templates[args.conv_mode].copy()
            if "mpt" in model_name.lower():
                roles = ('user', 'assistant')
            else:
                roles = conv.roles

            image_file = input('Image path:')
            #image = load_image(args.image_file)

            image = load_image(image_file)

            # Similar operation in model_worker.py
            image_tensor = process_images([image], image_processor, args)
            if type(image_tensor) is list:
                image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
            else:
                image_tensor = image_tensor.to(model.device, dtype=torch.float16)            

            inp = input(f"{roles[0]}: ")
        except EOFError:
            inp = ""
        if not inp:
            print("exit...")
            break

        try:
            print(f"{roles[1]}: ", end="")

            if image is not None:
                # first message
                if model.config.mm_use_im_start_end:
                    inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
                else:
                    inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
                conv.append_message(conv.roles[0], inp)
                image = None
            else:
                # later messages
                conv.append_message(conv.roles[0], inp)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
            keywords = [stop_str]
            stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
            streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

            print('------------------------')
            print('input_ids:', input_ids)
            print('image_tensor:', image_tensor)
            print('args.temperature:', args.temperature)
            print('max_new_tokens:', args.max_new_tokens)
            print('streamer:', streamer)
            print('stopping_criteria:', [stopping_criteria])

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor,
                    do_sample=True,
                    temperature=args.temperature,
                    max_new_tokens=args.max_new_tokens,
                    streamer=streamer,
                    use_cache=True,
                    stopping_criteria=[stopping_criteria])

            input('Debug...')

            outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
            conv.messages[-1][-1] = outputs
            print('------------------------')
            if args.debug:
                print("\n", {"prompt": prompt, "outputs": outputs}, "\n")

            # Assuming you're done with output_ids and other tensors, start cleanup
            print("Deleting output_ids")
            del output_ids
            print("Deleting input_ids")
            del input_ids
            print("Deleting image_tensor")
            del image_tensor
            print("Deleting inp")
            del inp
            print("Deleting stop_str")
            del stop_str
            #print('Deleting conv')
            #del conv
            #print('Deleting model')
            #del model

            # --- MEMORY CLEANER            
            pytorch_gpu_garbage_collector()

        except Exception as e:
            print('Error:',str(e))
            print('Traceback:', traceback.format_exc())

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    #parser.add_argument("--image-file", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--image-aspect-ratio", type=str, default='pad')
    args = parser.parse_args()
    main(args)

To run the model:

CUDA_VISIBLE_DEVICES=0,1 python -m llava.serve.tcli --model-path /media/10TB_HHD/_LLaVA/_models/liuhaotian_llava-v1.5-13b/ --load-4bit

vicgarfield commented 11 months ago

@userbox020 Perfect!

pbenaim commented 11 months ago

And tcli.py must be placed in ../LLaVA/llava/serve :+1:

Thanks

yinjiaoyuan commented 11 months ago

VERY GOOD!!!

hugsbrugs commented 10 months ago

Hey, I'm also trying to call the model through an API, but I can't figure out how you make your curl (POST?) request once the model is loaded. You don't attach it to a controller or a Gradio web server, so how do you call it publicly? I've tried appending --share, but tcli.py doesn't handle that option. Thanks for your help!

Xingxiangrui commented 8 months ago

+1 for this

userbox020 commented 8 months ago

Sorry bro, my post was confusing: it wasn't an API, just some tricks to avoid OOM errors on 8 GB GPUs.

But adding an API would be easy enough: you can create a separate script that uses the subprocess module to run the model and save or stream its output, then use any Python HTTP server to publish or stream that output.

ChatGPT should be able to produce a good code template for that.
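A minimal sketch of the subprocess + HTTP server approach suggested above, using only the Python standard library and resting on several assumptions: my_llava_once.py is a hypothetical script (not part of LLaVA) that loads the model, answers one prompt for one image and prints the answer to stdout; the JSON fields "image" and "prompt", the "output" response field and port 8000 are illustrative choices, not an existing LLaVA API.

```python
import json
import subprocess
import sys
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer

# ASSUMPTION: hypothetical one-shot inference script that prints its answer to stdout.
# "{image}" and "{prompt}" are filled in per request.
INFERENCE_CMD = [sys.executable, "my_llava_once.py", "--image", "{image}", "--prompt", "{prompt}"]


def run_inference(image, prompt, cmd=INFERENCE_CMD, timeout=600):
    """Run the inference command once in a subprocess and return its stdout."""
    # Arguments are passed as a list (no shell), so the prompt is not shell-interpreted.
    args = [part.format(image=image, prompt=prompt) for part in cmd]
    result = subprocess.run(args, capture_output=True, text=True, timeout=timeout, check=True)
    return result.stdout.strip()


def make_handler(cmd=INFERENCE_CMD):
    class InferenceHandler(BaseHTTPRequestHandler):
        def do_POST(self):
            try:
                length = int(self.headers.get("Content-Length", 0))
                body = json.loads(self.rfile.read(length) or b"{}")
                answer = run_inference(str(body["image"]), str(body["prompt"]), cmd=cmd)
                status, payload = 200, {"output": answer}
            except (KeyError, TypeError, ValueError) as exc:
                # Missing fields, non-object JSON or malformed JSON.
                status, payload = 400, {"error": f"bad request: {exc!r}"}
            except (subprocess.SubprocessError, OSError) as exc:
                # Non-zero exit, timeout, or the script could not be started.
                status, payload = 500, {"error": str(exc)}
            data = json.dumps(payload).encode("utf-8")
            self.send_response(status)
            self.send_header("Content-Type", "application/json")
            self.send_header("Content-Length", str(len(data)))
            self.end_headers()
            self.wfile.write(data)

    return InferenceHandler


def serve(host="127.0.0.1", port=8000, cmd=INFERENCE_CMD):
    """Serve POST requests until interrupted."""
    server = ThreadingHTTPServer((host, port), make_handler(cmd))
    server.serve_forever()


if __name__ == "__main__":
    serve()
```

A client could then POST JSON such as {"image": "/path/to/img.jpg", "prompt": "Describe the image."} to http://127.0.0.1:8000/ with curl or any HTTP library and read the "output" field of the response. Because every request starts a new process, the model is reloaded on each call, which is slow; keeping the model loaded in a long-running worker (as LLaVA's controller/model worker setup does) avoids that cost.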