ml-research / LlavaGuard

Questions about model inference using LLaVA implementation. #3

Closed zqzqz closed 1 month ago

zqzqz commented 1 month ago

The README says

This version is not compatible with the HF transformer implementation and must be used with SGLang or LLaVA implementation.

I would like to load and run the model with https://github.com/haotian-liu/LLaVA, so I tried something like the following:

import requests
import torch
from PIL import Image

from transformers import AutoProcessor
from llava.model.language_model.llava_llama import LlavaLlamaForCausalLM

model_id = "AIML-TUDA/LlavaGuard-7B"
tokenizer_id = "llava-hf/llava-1.5-7b-hf" 

model = LlavaLlamaForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16, 
    low_cpu_mem_usage=True,
    local_files_only=True,
).to(0)

processor = AutoProcessor.from_pretrained(
    tokenizer_id,
)

prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"

raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)

batch_size = inputs['input_ids'].shape[0]
input_ids = inputs['input_ids'].to(0, torch.float16)
pixel_values = inputs['pixel_values']
image_sizes = [list(p.shape) for p in pixel_values]
output = model.generate(input_ids, images=pixel_values, image_sizes=image_sizes, do_sample=True)
print(processor.decode(output[0][2:], skip_special_tokens=True))

It seems the checkpoint https://huggingface.co/AIML-TUDA/LlavaGuard-7B is not compatible with LLaVA's LlavaLlamaForCausalLM implementation; loading and running it produces:

You are using a model of type llava to instantiate a model of type clip_vision_model. This is not supported for all configurations of models and can yield errors.
Some weights of the model checkpoint at AIML-TUDA/LlavaGuard-7B were not used when initializing LlavaLlamaForCausalLM: ['model.vision_tower.vision_tower.vision_model.encoder.layers.20.self_attn.q_proj.weight', ... (a lot of weights under vision_tower)
- This IS expected if you are initializing LlavaLlamaForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LlavaLlamaForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Traceback (most recent call last):
  File "test_llavaguard2.py", line 40, in <module>
    output = model.generate(input_ids, images=pixel_values, image_sizes=image_sizes, do_sample=True)
  File "/home/qzzhang/miniconda3/envs/llm-code/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/qzzhang/llm-dos-attack/third_party/LLaVA/llava/model/language_model/llava_llama.py", line 125, in generate
    ) = self.prepare_inputs_labels_for_multimodal(
  File "/home/qzzhang/llm-dos-attack/third_party/LLaVA/llava/model/llava_arch.py", line 204, in prepare_inputs_labels_for_multimodal
    image_features = self.encode_images(images)
  File "/home/qzzhang/llm-dos-attack/third_party/LLaVA/llava/model/llava_arch.py", line 143, in encode_images
    image_features = self.get_model().get_vision_tower()(images)
  File "/home/qzzhang/miniconda3/envs/llm-code/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/qzzhang/miniconda3/envs/llm-code/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/qzzhang/llm-dos-attack/third_party/LLaVA/llava/model/multimodal_encoder/clip_encoder.py", line 50, in forward
    image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
  File "/home/qzzhang/miniconda3/envs/llm-code/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'CLIPVisionTower' object has no attribute 'vision_tower'

Could you explain how I can use the checkpoint with the LLaVA implementation?

zqzqz commented 1 month ago

Somehow I got the following script working:

import os
import requests
from PIL import Image
from io import BytesIO

import torch
from transformers import AutoTokenizer
from llava.model.language_model.llava_llama import LlavaLlamaForCausalLM
from llava.constants import IMAGE_TOKEN_INDEX
from llava.mm_utils import tokenizer_image_token

model_path = os.environ["HF_MODEL_PATH"]
model_id = "AIML-TUDA/LlavaGuard-7B"
tokenizer_id = "llava-hf/llava-1.5-7b-hf"

model = LlavaLlamaForCausalLM.from_pretrained(
    os.path.join(model_path, model_id),
    torch_dtype=torch.float16, 
    low_cpu_mem_usage=True,
    local_files_only=True,
).to(0) 

tokenizer = AutoTokenizer.from_pretrained(
    os.path.join(model_path, tokenizer_id)
)

image_processor = model.get_model().vision_tower.image_processor

prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"

response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
image_sizes = [image.size]

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(0)
pixel_values = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].unsqueeze(0).to(0, torch.float16)

output = model.generate(input_ids, images=pixel_values, image_sizes=image_sizes, do_sample=True)
print(tokenizer.decode(output[0], skip_special_tokens=True))

However, I had to hack the LLaVA code by changing delay_load to False (roughly as sketched below) so that the vision_tower weights are loaded directly from the checkpoint. I am not familiar with all these codebases, so I would appreciate it if others could validate my script.
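
A minimal sketch of that change, assuming the current layout of haotian-liu/LLaVA where LlavaMetaModel builds the vision tower in llava/model/llava_arch.py (the exact file and line may differ between LLaVA versions):

# llava/model/llava_arch.py, inside LlavaMetaModel.__init__
# Build the vision tower eagerly so its weights are restored from the LlavaGuard
# checkpoint instead of being re-initialized from the CLIP repo later on.
if hasattr(config, "mm_vision_tower"):
    self.vision_tower = build_vision_tower(config, delay_load=False)  # upstream LLaVA uses delay_load=True here
    self.mm_projector = build_vision_projector(config)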

lukashelff commented 1 month ago

Hey, sure, you can do this as well. I have usually used the load_pretrained_model function from llava.model.builder that is provided in the repo; you can have a look at the code below. To load the model, you also have to add the tokenizer, tokenizer_config, preprocessor_config, and special_tokens_map.json files to the local folder where the model is located. These files can be found in liuhaotian's repos, e.g. the 7B model; a sketch of fetching them follows.
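
A minimal sketch of that file-copying step using huggingface_hub, assuming liuhaotian/llava-v1.5-7b as the source repo (the file names and source repo are assumptions and may differ for the 13B/34B models):

from huggingface_hub import hf_hub_download

local_model_dir = "/path/to/local/LlavaGuard-7B"   # folder containing the LlavaGuard weights
source_repo = "liuhaotian/llava-v1.5-7b"           # assumed source for tokenizer/processor files

# Download the tokenizer and processor configs next to the model weights.
for filename in [
    "tokenizer.model",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "preprocessor_config.json",  # assumed name of the image-processor config
]:
    hf_hub_download(repo_id=source_repo, filename=filename, local_dir=local_model_dir)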

import warnings
from llava.model.builder import load_pretrained_model
import requests
from PIL import Image
from io import BytesIO
import torch
from llava.constants import IMAGE_TOKEN_INDEX
from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria

def run_llava(prompt, pil_image, temperature=.2, top_p=.95, max_new_tokens=512):
    image_size = pil_image.size
    image_tensor = image_processor.preprocess(pil_image, return_tensors='pt')['pixel_values'].half().cuda()
    # image_tensor = image_tensor.to(model.device, dtype=torch.float16)
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
    input_ids = input_ids.unsqueeze(0).cuda()
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            image_sizes=[image_size],
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=50,
            num_beams=2,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            stopping_criteria=[KeywordsStoppingCriteria(['}'], tokenizer, input_ids)]
        )
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    return outputs[0].strip()

models = [
    '/path/to/local/LlavaGuard-7B',
    '/path/to/local/LlavaGuard-13B-hf',
    '/path/to/local/LlavaGuard-34B-hf', ]
model_path = models[0]
model_name = model_path.split("/")[-1]
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)
for warning in w:
    if "vision" not in str(warning.message).lower():
        print(warning.message)
model.config.tokenizer_model_max_length = 2048 * 2
# set the output directory

image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"

# `prompt` is the LlavaGuard assessment prompt; the full template is shown in the next comment.
answer = run_llava(prompt, Image.open(BytesIO(requests.get(image_file).content)), 0.7, 0.9, 100)
print(answer)

lukashelff commented 1 month ago

I have also uploaded the individual tokenizers to the Hugging Face repositories, which simplifies their use with the llava repo. You can now load the model out of the box with no modifications needed. Please note that the prompt provided above needs to be wrapped in the llava conversation template before feeding it into the model's generate function. Below is an example of how to do this:

from io import BytesIO
import requests
from PIL import Image
import torch
from llava.constants import IMAGE_TOKEN_INDEX
from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
from llava.model.builder import load_pretrained_model
from llava.conversation import conv_templates

prompt = "Provide a safety assessment for the content provided by the user. Determine compliance or violation against our safety policy by reviewing the following policy categories:
O1: Hate, Humiliation, Harassment: 
Should not:
- Contain content that expresses hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.
- Contain content that encourages discrimination or harassment against people based on those characteristics.
Can:
- Contain educational content that provides information on discrimination, humiliation, and harassment.
- Discuss historical events involving hate and harassment.

O2: Violence, Harm, or Cruelty: 
Should not:
- Contain content of people planning or engaging in violence.
- Contain content that encourages people to engage in violence or cruelty, including harm to others.
Can:
- Contain educational content that provides information on violence, harm, and cruelty.
- Discuss historical events involving violence and cruelty.

O3: Sexual Content: 
Should not:
- Contain sexually explicit (i.e., erotic) content.
Can:
- Contain content that is related to sexuality and sexual education.
- Provide general information regarding sexual activities.
- Provide guidance on topics related to sexual health.
- Contain content of people in a loving and affectionate manner if it is not sexually explicit content.

O4: Nudity Content: 
Should not:
- Contain visible genitalia (e.g., exposed breast, nipple, vagina, vulva, or penis) or explicit nudity content. 
Can:
- Contain artistic nudity if the content does not include visible genitalia.
- Contain educational content and general information on the human body.
- Provide guidance on topics related to sexual health.

O5: Criminal Planning: 
Should not:
- Contain content of people planning or engaging in criminal activities like theft, kidnapping, or financial crimes.
- Contain content that encourages people to engage in criminal activities or condone criminal behavior.
Can:
- Contain content that helps people understand how crimes are committed, provided that the content is not so detailed that it could be used to provide a roadmap for future crimes.
- Contain content that educates on crime prevention and understanding criminal behavior without providing guidance for criminal activities.

O6: Weapons or Substance Abuse: 
Should not:
- Contain content of illegal weapons with the intent to promote violence or harm.
- Contain content of trafficking, consumption, or use of illegal drugs.
- Promote the use of tobacco, alcohol, and cannabis.
- Contain content about the creation or production of illegal substances.
Can:
- Contain weapons used in a legal and safe context without any indication of violence or intent to harm.
- Discussions on firearm ownership and educational content on regulated substances without promoting illegal activities.
- Contain content that provides factual information about legal drugs, including their effects and chemical composition, without promoting their use.
- Contain content that offers educational insights into the manufacturing processes of substances, ensuring that such information is general and does not include detailed instructions that could facilitate their production.

O7: Self-Harm: 
Should not:
- Contain content that could encourage people to harm themselves.
- Contain content that romanticizes or condones self-harm or suicide.
Can:
- Contain educational content that informs about the prevalence, causes, and effects of self-harm and suicide, such as statistical data and discussions on mental health challenges.
- Contain resources and information that guide individuals toward professional help and support services for mental health, self-harm, and suicide prevention.

O8: Animal Cruelty: 
Should not:
- Contain content that could encourage people to harm animals.
- Contain content that romanticizes or condones harming animals.
- Contain explicit animal cruelty.
Can:
- Contain content that provides general information regarding animal cruelty.

O9: Disasters or Emergencies: 
Should not:
- Contain content about individuals who are deceased or have suffered harm as a result of disasters or emergencies.
- Contain content about suffering, cruelty, or exploitation of individuals during such events.
Can:
- Contain respectful news coverage and reports on emergencies, prioritizing informational value over sensationalism.
- Contain educational content that provides factual information on historical and current disasters or emergencies.

Assessment Steps:
1. Select a safety "rating". This should be "Unsafe" if the content violates the safety policy; otherwise, it should be "Safe".
2. Assign a policy "category". If none of the policy categories are applicable, select "NA: None applying".
3. Provide a "rationale". Explain the reason why the content is considered safe or unsafe, referring to the description of the assigned policy category.
To provide your assessment use the following json template:
{
    "rating": str<"Unsafe"|"Safe">,
    "category": str<"O1: Hate, Humiliation, Harassment"|"O2: Violence, Harm, or Cruelty"|"O3: Sexual Content"|"O4: Nudity Content"|"O5: Criminal Planning"|"O6: Weapons or Substance Abuse"|"O7: Self-Harm"|"O8: Animal Cruelty"|"O9: Disasters or Emergencies"|"NA: None applying">,
    "rationale": str,
}
<image>

"

def load_image(image_processor, image_file):
    if image_file.startswith('http') or image_file.startswith('https'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()
    return image, image_tensor

def clear_conv(conv):
    conv.messages = []
    return conv

def run_llava(model, tokenizer, image_processor, prompt, im_path, conv):
    conv = clear_conv(conv)
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    image_, image_tensor_ = load_image(image_processor, im_path)
    image_sizes = [image_.size]
    images_tensor = image_tensor_.unsqueeze(0).cuda()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
    input_ids = input_ids.unsqueeze(0).cuda()

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=images_tensor,
            image_sizes=image_sizes,
            do_sample=True,
            temperature=0.2,
            top_p=0.95,
            top_k=50,
            num_beams=2,
            max_new_tokens=1024,
            use_cache=True,
            stopping_criteria=[KeywordsStoppingCriteria(['}'], tokenizer, input_ids)]

        )
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return outputs[0].strip()

models = [
    'AIML-TUDA/LlavaGuard-7B',
    'AIML-TUDA/LlavaGuard-13B',
    'AIML-TUDA/LlavaGuard-34B',
  ]
model_path = models[1]
model_name = model_path.split("/")[-1]
conv_mode = "chatml_direct" if "34b" in model_name.lower() else "v1"
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)
model.config.tokenizer_model_max_length = 2048 * 2
conv = conv_templates[conv_mode].copy()

image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
answer = run_llava(model, tokenizer, image_processor, prompt, image_file, conv)
print(answer)