TRI-ML / prismatic-vlms

A flexible and efficient codebase for training visually-conditioned language models (VLMs)
MIT License

How to extract the last hidden outputs from the models? #1

Closed swj0419 closed 4 months ago

swj0419 commented 4 months ago

Thank you so much for open sourcing the code and models. It is very helpful! I was wondering how to extract the last hidden outputs from the models.

siddk commented 4 months ago

Great question! Just as with a HF Transformers LLM, you can pass output_hidden_states=True to forward() for our VLMs, as shown here.

This will return a CausalLMOutputWithPast object; its hidden_states attribute holds the hidden states from every transformer block of the LLM backbone (including the embeddings)!
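
For example, a minimal sketch (assuming vlm has already been loaded and input_ids / pixel_values have been preprocessed in the usual way):

# Single forward pass, asking for hidden states from every transformer block
outputs = vlm.forward(input_ids=input_ids, pixel_values=pixel_values, output_hidden_states=True)

# outputs.hidden_states is a tuple: the embedding outputs plus one tensor per
# transformer block, each of shape (batch, sequence_length, hidden_dim)
print(len(outputs.hidden_states), outputs.hidden_states[-1].shape)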

swj0419 commented 4 months ago

Thank you for the prompt reply! Can I try something like:


import requests
from PIL import Image

# vlm (the loaded VLM) and device are assumed to come from the model-loading step
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
user_prompt = "What is going on in this image?"

# Build prompt
prompt_builder = vlm.get_prompt_builder()
prompt_builder.add_turn(role="human", message=user_prompt)
prompt_text = prompt_builder.get_prompt()

# Preprocess prompt + image
image_transform, tokenizer = vlm.vision_backbone.image_transform, vlm.llm_backbone.tokenizer
input_ids = tokenizer(prompt_text, truncation=True, return_tensors="pt").input_ids.to(device)
pixel_values = image_transform(image)
pixel_values = {k: v[None, ...].to(device) for k, v in pixel_values.items()}

# Single forward pass, requesting hidden states
feature = vlm.forward(input_ids=input_ids, pixel_values=pixel_values, output_hidden_states=True)

siddk commented 4 months ago

If you just want the visual features, this will work! However, it won't generate a full text response, since this only runs a single forward pass (i.e., it predicts just the next token).
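
To pull out the last hidden outputs from that single forward pass, a minimal sketch using the feature object from the snippet above (hidden_states follows the usual HF layout of embeddings plus one tensor per block):

last_hidden = feature.hidden_states[-1]   # final transformer block, shape (batch, seq_len, hidden_dim)
final_token = last_hidden[:, -1, :]       # representation at the last position, i.e. the next-token prediction step
print(last_hidden.shape, final_token.shape)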