Meituan-AutoML / MobileVLM

Strong and Open Vision Language Assistant for Mobile Devices
Apache License 2.0
969 stars 65 forks source link

Model Misclassification (All Images Classified as NSFW) #59

Open mehranfvs opened 4 weeks ago

mehranfvs commented 4 weeks ago

Description :

I've been using the "MobileVLMv2-1.7B" model for a task where I need to classify images as either "NSFW" or "SFW". However, the model consistently classifies all images as NSFW, regardless of their actual content. This behavior persists even with various prompt modifications and different image inputs.

Reproduce: prompt = "Is this picture sfw or nsfw?\nAnswer the question using a single word of nsfw or sfw"

Expected Behavior: The model should classify images accurately as SFW or NSFW based on the content.

Actual Behavior: All images are classified as NSFW, even those that are clearly SFW.

I made some changes to the original script (https://github.com/Meituan-AutoML/MobileVLM/blob/main/scripts/inference.py), including encapsulating its functionality within a class, among other minor edits. However, the overall workflow and logic of the code remain consistent with the original version.

my code :

import argparse
import os
import sys
import warnings
from pathlib import Path
from time import time

import torch
import transformers
from PIL import Image

  # Restrict transformers logging to errors and suppress every Python warning
  # so that only the script's own prints reach the console.
  transformers.logging.set_verbosity_error()
  warnings.filterwarnings('ignore')

  # Add the parent of this file's directory to the import path. This must run
  # before the `mobilevlm` imports below (presumably that directory is the
  # repository root when the script lives in `scripts/` — confirm against the
  # checkout layout).
  sys.path.append(str(Path(__file__).parent.parent.resolve()))
  from mobilevlm.model.mobilevlm import load_pretrained_model
  from mobilevlm.conversation import conv_templates, SeparatorStyle
  from mobilevlm.utils import disable_torch_init, process_images, tokenizer_image_token, KeywordsStoppingCriteria
  from mobilevlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN

  class inference_once():
      """Reusable single-image MobileVLM inference helper.

      Intended call order: ``load_model`` once, then ``load_image`` followed
      by ``infer`` for every image to be processed.
      """

      def __init__(self):
          # Set by load_model().
          self.tokenizer = None
          self.model = None
          self.image_processor = None
          # Set by load_image().
          self.images_tensor = None
          # Set by infer(); kept on the instance for callers that inspect them.
          self.input_ids = None
          self.stopping_criteria = None

      def load_model(self, model_path, load_8bit=False, load_4bit=False):
          """Load tokenizer, model and image processor from ``model_path``.

          Args:
              model_path: Path or identifier forwarded to
                  ``load_pretrained_model``.
              load_8bit: Forwarded to ``load_pretrained_model``.
              load_4bit: Forwarded to ``load_pretrained_model``.
          """
          # Run before the model is built: calling it after loading (as the
          # previous code did) cannot influence how the model is created.
          # NOTE(review): upstream scripts/inference.py appears to use this
          # same order — confirm.
          disable_torch_init()
          # The fourth return value (context length) is not used here.
          (self.tokenizer,
           self.model,
           self.image_processor,
           _) = load_pretrained_model(model_path, load_8bit, load_4bit)

      def load_image(self, image_file):
          """Read ``image_file`` and store it as a model-ready tensor.

          The result is stored in ``self.images_tensor`` on the model's
          device as float16.

          Raises:
              RuntimeError: If ``load_model`` has not been called yet.
          """
          if self.model is None or self.image_processor is None:
              raise RuntimeError("load_model() must be called before load_image()")

          # Close the underlying file handle promptly; convert() returns an
          # independent in-memory copy.
          with Image.open(image_file) as raw:
              rgb = raw.convert("RGB")

          # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
          # NOTE(review): this fixed 512x512 resize ignores the aspect ratio —
          # worth re-checking when debugging classification quality.
          images = [rgb.resize((512, 512), Image.LANCZOS)]
          self.images_tensor = process_images(
              images,
              self.image_processor,
              self.model.config,
          ).to(self.model.device, dtype=torch.float16)

      def infer(self, args):
          """Generate the model's answer for the currently loaded image.

          Args:
              args: Object exposing ``conv_mode``, ``prompt``, ``temperature``,
                  ``top_p``, ``num_beams`` and ``max_new_tokens`` attributes.

          Returns:
              The decoded answer with surrounding whitespace and one trailing
              stop string removed.

          Raises:
              RuntimeError: If ``load_model`` or ``load_image`` has not been
                  called yet.
          """
          if self.model is None or self.tokenizer is None:
              raise RuntimeError("load_model() must be called before infer()")
          if self.images_tensor is None:
              raise RuntimeError("load_image() must be called before infer()")

          conv = conv_templates[args.conv_mode].copy()
          conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + args.prompt)
          conv.append_message(conv.roles[1], None)

          prompt = conv.get_prompt()
          stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

          # Input: use the model's own device (as load_image does) instead of
          # a hard-coded .cuda().
          self.input_ids = tokenizer_image_token(
              prompt,
              self.tokenizer,
              IMAGE_TOKEN_INDEX,
              return_tensors="pt",
          ).unsqueeze(0).to(self.model.device)

          self.stopping_criteria = KeywordsStoppingCriteria([stop_str],
                                                            self.tokenizer,
                                                            self.input_ids)

          # Inference
          with torch.inference_mode():
              output_ids = self.model.generate(
                  self.input_ids,
                  images=self.images_tensor,
                  do_sample=args.temperature > 0,
                  temperature=args.temperature,
                  top_p=args.top_p,
                  num_beams=args.num_beams,
                  max_new_tokens=args.max_new_tokens,
                  use_cache=True,
                  stopping_criteria=[self.stopping_criteria])

          # Result-Decode: the generated sequence is expected to start with
          # the prompt tokens; warn when it does not.
          input_token_len = self.input_ids.shape[1]
          n_diff_input_output = (self.input_ids != output_ids[:, :input_token_len]).sum().item()

          if n_diff_input_output > 0:
              print(f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids")

          decoded = self.tokenizer.batch_decode(output_ids[:, input_token_len:],
                                                skip_special_tokens=True)[0]
          return self._strip_stop_str(decoded.strip(), stop_str)

      @staticmethod
      def _strip_stop_str(text, stop_str):
          """Return ``text`` without one trailing ``stop_str``, if present.

          An empty or ``None`` stop string leaves ``text`` untouched; the naive
          ``text[:-len(stop_str)]`` would return ``""`` for an empty stop
          string.
          """
          if stop_str and text.endswith(stop_str):
              return text[:-len(stop_str)]
          return text

  if __name__ == "__main__":
      # Script entry point: load the model once, then classify every file in
      # `image_folder` with the same prompt and print the answer and timing.
      model_path = "/MobileVLM_V2-1.7B"
      image_folder = "/samples"
      prompt = "Is this picture sfw or nsfw?\nAnswer the question using a single word of nsfw or sfw"

      # Generation settings are the same for every image, so build them once
      # instead of on each loop iteration.
      args = argparse.Namespace(
          prompt=prompt,
          conv_mode="v1",
          temperature=0,
          top_p=None,
          num_beams=1,
          max_new_tokens=512,
      )

      mvlm2 = inference_once()
      s_l = time()
      mvlm2.load_model(model_path)
      e_l = time() - s_l
      print(f"load time : {e_l:.2f} seconds")

      # os.listdir returns entries in arbitrary order; sort for reproducible runs.
      for img in sorted(os.listdir(image_folder)):
          img_path = os.path.join(image_folder, img)
          if not os.path.isfile(img_path):
              # Sub-directories cannot be opened as images; skip them.
              continue
          print(img)

          s = time()
          mvlm2.load_image(img_path)
          output = mvlm2.infer(args)
          e = time() - s

          print("\033[92m", 50 * "-")
          print(f"time : {e:.2f} seconds")
          print(f"šŸš€ output {img} : {output.strip()}")
          print(50 * "-", "\033[0m")

output :

Screenshot from 2024-08-18 12-18-29