import time
from io import BytesIO
from pathlib import Path

import numpy as np
import requests
import torch
from PIL import Image

from cobra import load
def load_image(image_file):
    """Load an RGB PIL image from a local path or an http(s) URL.

    Parameters
    ----------
    image_file : str
        Local file path, or an ``http://`` / ``https://`` URL.

    Returns
    -------
    PIL.Image.Image
        The loaded image, converted to RGB mode.

    Raises
    ------
    requests.HTTPError
        If the URL download returns an error status.
    """
    # str.startswith accepts a tuple; checking full schemes avoids
    # misclassifying local paths that merely begin with "http".
    if image_file.startswith(("http://", "https://")):
        # Bug fix: the original printed the undefined name `args.video_file`,
        # which raised NameError on every URL input.
        print("downloading image from url", image_file)
        response = requests.get(image_file)
        response.raise_for_status()  # fail loudly instead of feeding an error page to PIL
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    return image
def load_images(image_transform, image_files):
    """Load every image in *image_files* and apply *image_transform* to each.

    Parameters
    ----------
    image_transform : callable
        Transform applied to each loaded PIL image.
    image_files : iterable of str
        Paths or URLs accepted by ``load_image``.

    Returns
    -------
    list
        Transformed images, in the same order as *image_files*.
    """
    return [image_transform(load_image(path)) for path in image_files]
Can you tell me the shape of the input the model expects?
This is my code:
import time
from io import BytesIO
from pathlib import Path

import numpy as np
import requests
import torch
from PIL import Image

from cobra import load


def load_image(image_file):
    """Load an RGB PIL image from a local path or an http(s) URL."""
    if image_file.startswith(("http://", "https://")):
        # Bug fix: the original printed the undefined name `args.video_file`.
        print("downloading image from url", image_file)
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    return image
def load_images(image_transform, image_files):
    """Apply *image_transform* to every image referenced in *image_files*.

    Returns the transformed images as a list, preserving input order.
    """
    out = []
    for image_file in image_files:
        image = load_image(image_file)
        out.append(image_transform(image))
    return out
# --- Model setup -----------------------------------------------------------
# Prefer GPU when available; use bfloat16 only on hardware that supports it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

model_id = "cobra+3b"
vlm = load(model_id)
vlm.to(device, dtype=dtype)

# The backbone's transform produces the tensor the VLM expects per image.
# NOTE(review): the exact input shape is determined by the vision backbone's
# transform, not visible here — inspect vlm.vision_backbone (e.g. its
# default image size) to confirm the expected (C, H, W).
image_transform = vlm.vision_backbone.image_transform
# --- Batched inference -----------------------------------------------------
# Renamed locals for clarity: the original bound a *list* of paths to the
# singular name `image_file`.
image_files = [
    "test_0.jpg",
    "test_1.jpg",
    "test_2.jpg",
    "test_3.jpg",
    "test_4.jpg",
    "test_5.jpg",
]
images = load_images(image_transform, image_files)

user_prompt = "is there fire or smoke(yes/no)?"
prompt_builder = vlm.get_prompt_builder()
prompt_builder.add_turn(role="human", message=user_prompt)
prompt_text = prompt_builder.get_prompt()
# One copy of the prompt per image so the batch dimensions line up.
prompt_texts = [prompt_text] * len(image_files)

generated_text = vlm.generate_batch(
    images,
    prompt_texts,
    use_cache=True,
    do_sample=False,  # greedy decoding; temperature is effectively inert here
    temperature=0.1,
    max_new_tokens=1,  # single-token yes/no answer
)
print("answer:", generated_text)