paperswithlove / papers-we-read


When Do We Not Need Larger Vision Models? (from Hyunjun) #6

Open runhani opened 3 months ago

runhani commented 3 months ago


code: https://github.com/bfshi/scaling_on_scales
paper: https://arxiv.org/abs/2403.13043

A paper that starts from a good question:

Is a larger model always necessary for better visual understanding?

Is this the mainstream approach now? Take the original image, split it into five inputs, and feed them all in!

Quickstart: the code is very simple too (works with any pre-trained vision model)

Step 1. Clone this repo and install s2wrapper through pip.

# go to the directory of this repo, and run 
pip install .

Step 2. Extract multi-scale features from any vision model with one line of code.

Assume you have a function (could be model, model.forward, etc.) that takes in BxCxHxW images and outputs BxNxC features.

For example, suppose you have a model (e.g., ViT-B) that extracts features by

feature = model(x)   # e.g., x: 32*3*224*224, feature: 32*196*768
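To try the one-liner without downloading a checkpoint, any callable that honors the BxCxHxW to BxNxC contract will do. Below is a minimal sketch of such a stand-in, written with NumPy instead of PyTorch; `toy_vit_features` is a hypothetical name, not part of s2wrapper, and it has no attention layers. It cuts the image into non-overlapping 16x16 patches and applies one fixed random projection, so a 224x224 input yields 14 x 14 = 196 tokens of width 768, the ViT-B/16 shapes in the comment above. Because it works on NumPy arrays, it only illustrates the shape contract; s2wrapper itself expects torch tensors.

```python
import numpy as np

def toy_vit_features(x, patch=16, dim=768, seed=0):
    """Hypothetical stand-in for a ViT: (B, C, H, W) -> (B, N, dim).

    Splits the image into non-overlapping patch x patch squares and applies a
    fixed random linear projection to each flattened patch (no attention).
    """
    b, c, h, w = x.shape
    assert h % patch == 0 and w % patch == 0, "H and W must be multiples of the patch size"
    gh, gw = h // patch, w // patch
    # (B, C, gh, p, gw, p) -> (B, gh, gw, C, p, p) -> (B, gh*gw, C*p*p)
    patches = x.reshape(b, c, gh, patch, gw, patch).transpose(0, 2, 4, 1, 3, 5)
    patches = patches.reshape(b, gh * gw, c * patch * patch)
    rng = np.random.default_rng(seed)
    weight = rng.standard_normal((c * patch * patch, dim)) / np.sqrt(c * patch * patch)
    return patches @ weight  # (B, N, dim), tokens in row-major patch order

x = np.zeros((2, 3, 224, 224), dtype=np.float32)
feature = toy_vit_features(x)
print(feature.shape)  # (2, 196, 768)
```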

Then extract multi-scale features (e.g., scales of 1 and 2) by

from s2wrapper import forward as multiscale_forward
multiscale_feature = multiscale_forward(model, x, scales=[1, 2])   # x: 32*3*224*224, multiscale_feature: 32*196*1536

Full signature:

s2wrapper.forward(
    model,
    input,
    scales=None,
    img_sizes=None,
    max_split_size=None,
    resize_output_to_idx=0,
    num_prefix_token=0,
    output_shape='bnc',
)

model: Your vision model or any function that takes in BxCxHxW image tensor and outputs BxNxC feature tensor.

input: Input image tensor with shape BxCxHxW.

scales: A list of scales to extract features on. For example, scales=[1, 2] will extract features at the 224x224 and 448x448 scales if the default size is 224x224.

img_sizes: Alternatively, instead of assigning scales, you can assign the image size for each scale. For example, img_sizes=[224, 448] will yield the same results as scales=[1, 2] for a default size of 224x224.

max_split_size: The maximum size of the sub-images split from the large image. For each scale, the image will be split into ceil(img_size_of_that_scale / max_split_size)**2 sub-images. If None, it defaults to the size of the input.

resize_output_to_idx: Which scale to resize the final feature map to. Default is the first scale in scales or img_sizes.

num_prefix_token: Number of prefix tokens in the feature map. For example, if the feature map returned by the model contains 1 [CLS] token and other spatial tokens, set num_prefix_token=1. Default is 0.

output_shape: Shape of the output features. Must be either bnc (e.g., ViT) or bchw (e.g., ConvNet). Default is bnc.
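The interplay of scales, img_sizes, and max_split_size is easiest to see as plain arithmetic. The sketch below mirrors the bookkeeping described in the parameter list for a square input; `s2_plan`, `patch`, and `dim` are hypothetical (they are not s2wrapper arguments), and the defaults `patch=16, dim=768` assume a ViT-B/16 encoder. It also assumes each scale's size divides evenly into its splits.

```python
import math

def s2_plan(input_size, scales=None, img_sizes=None, max_split_size=None,
            resize_output_to_idx=0, patch=16, dim=768):
    """Hypothetical helper: resolve S^2 arguments into per-scale bookkeeping.

    `patch` and `dim` describe the encoder (ViT-B/16 defaults assumed); they are
    not s2wrapper arguments. Assumes every size divides evenly into its splits.
    """
    assert scales is not None or img_sizes is not None, "assign either scales or img_sizes"
    img_sizes = img_sizes or [int(input_size * s) for s in scales]
    max_split_size = max_split_size or input_size        # default: one tile = input size
    num_splits = [math.ceil(size / max_split_size) for size in img_sizes]
    # every scale is pooled to the token grid of img_sizes[resize_output_to_idx]
    side = img_sizes[resize_output_to_idx] // patch
    return {
        "img_sizes": img_sizes,
        "num_splits": num_splits,                         # tiles per side
        "sub_images": [n * n for n in num_splits],        # forward passes per image
        "tokens": side * side,                            # N in the B x N x C output
        "channels": dim * len(img_sizes),                 # features concatenated per scale
    }

plan = s2_plan(224, scales=[1, 2])
print(plan["num_splits"], plan["tokens"], plan["channels"])  # [1, 2] 196 1536
print(s2_plan(224, img_sizes=[224, 448]) == plan)              # True
```

Channels grow with the number of scales (768 per scale here) while the token count stays at the first scale's 196, which is why the quickstart comment shows 32*196*1536.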

Better performance at the same compute?


Hyunjun's summary

This study neatly organizes an approach that recent MLLMs such as GPT-4V and LLaVA-1.6 are increasingly adopting: cropping a large-scale image and feeding the crops into a visual encoder originally trained at 224. The authors report that with their proposed method, Scaling on Scales (S^2), small models (ViT-B/L) can outperform larger models (ViT-H/G) across a range of vision tasks (classification, segmentation, depth estimation, MLLM benchmarks, robotic manipulation). On the V* benchmark, for instance, LLaVA-1.5 with S^2 applied outperforms GPT-4V and Gemini Pro.

Paper conclusions and my thoughts

hjeun commented 3 months ago
#  ------------------------------------------------------------------------------------------
#  Copyright (c) 2024 Baifeng Shi.
#  All rights reserved.
#
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------

import math
import torch
import torch.nn.functional as F
from einops import rearrange
from .utils import split_chessboard, merge_chessboard

# Example: applying a ViT-L trained at 448 to a 1344 image size
# By default, the input is assumed to arrive at 448 (the small size).
# scales=None
# img_sizes = [448, 1344]
# max_split_size = 448
# resize_output_to_idx=0
# num_prefix_token=0 (the model used here presumably has no [CLS] token?)
# output_shape='bnc'
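For the configuration in the comments above (448 ViT-L applied to a 1344 image), the bookkeeping can be worked out by hand. The patch size of 14 is an assumption, since the comments do not state it; the hidden width of 1024 is the standard ViT-L width.

```latex
\begin{aligned}
\text{num\_splits} &= \left[\lceil 448/448 \rceil,\ \lceil 1344/448 \rceil\right] = [1,\ 3] \\
\text{forward passes per image} &= 1^2 + 3^2 = 10 \quad (\text{each on a } 448 \times 448 \text{ crop}) \\
\text{tokens per pass} &= (448/14)^2 = 32^2 = 1024 \\
\text{merged grid at 1344} &= 3 \cdot 32 = 96 \;\Rightarrow\; 96 \times 96 \xrightarrow{\text{area pool}} 32 \times 32 \\
\text{output shape} &= B \times 1024 \times (2 \cdot 1024) = B \times 1024 \times 2048
\end{aligned}
```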

1. Before feeding images to the ViT: resize and split the image

def forward(model, input, scales=None, img_sizes=None, max_split_size=None, resize_output_to_idx=0, num_prefix_token=0, output_shape='bnc'):

    assert input.dim() == 4, "Input image must be in the shape of BxCxHxW."
    assert input.shape[2] == input.shape[3], "Currently only square images are supported."
    assert output_shape in ['bnc', 'bchw'], "Output shape should be either BxNxC (e.g., ViT) or BxCxHxW (e.g., ConvNet)."
    assert output_shape == 'bnc' or num_prefix_token == 0, "For ConvNet there shouldn't be any prefix token."

    # only square inputs are supported, so start from h (= w = input_size) alone
    b, c, input_size, _ = input.shape

    # image size for each scale 
    # img_sizes = [448, 1344]
    assert scales is not None or img_sizes is not None, "Please assign either scales or img_sizes."
    img_sizes = img_sizes or [int(input_size * scale) for scale in scales]

    # prepare multiscale inputs
    # num_splits = [1, 3] = [ceil(448 / 448), ceil(1344 (=size) / 448 (=max_split_size))]
    max_split_size = max_split_size or input_size   # The maximum size of each split of image. Set as the input size by default
    num_splits = [math.ceil(size / max_split_size) for size in img_sizes]   # number of splits each scale
    input_multiscale = []
    for size, num_split in zip(img_sizes, num_splits):
        # first resize to 448 and 1344
        x = F.interpolate(input.to(torch.float32), size=size, mode='bicubic').to(input.dtype)
        # split in a chessboard pattern; num_split (1 or 3) is the number of tiles per side, i.e. the square root of the tile count
        x = split_chessboard(x, num_split=num_split)
        # expected: 1 tile at 448 and 9 tiles at 1344
        input_multiscale.append(x)
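The comments above guess that `split_chessboard` yields 1 and 9 tiles. Here is a minimal NumPy sketch of such a split, assuming the tiles are cut row by row and stacked along the batch axis so the ViT can process all of them in one call; `split_chessboard_np` is a hypothetical re-sketch, not the repo's `.utils` implementation.

```python
import numpy as np

def split_chessboard_np(x, num_split):
    """Hypothetical sketch: cut (B, C, H, W) into num_split x num_split tiles.

    Tiles are taken row by row from the top-left and concatenated on the batch
    axis, giving (num_split**2 * B, C, H // num_split, W // num_split).
    Tile (i, j) of image k lands at batch index (i * num_split + j) * B + k.
    """
    _, _, h, w = x.shape
    assert h % num_split == 0 and w % num_split == 0, "H and W must divide evenly"
    th, tw = h // num_split, w // num_split
    tiles = [x[:, :, i * th:(i + 1) * th, j * tw:(j + 1) * tw]
             for i in range(num_split) for j in range(num_split)]
    return np.concatenate(tiles, axis=0)

# scaled-down stand-in for 1344 / 448: a 6x6 "image" cut into 3 x 3 tiles of 2x2
x = np.arange(2 * 1 * 6 * 6, dtype=np.float32).reshape(2, 1, 6, 6)
tiles = split_chessboard_np(x, num_split=3)
print(tiles.shape)  # (18, 1, 2, 2)
```

Stacking on the batch axis keeps every tile at the encoder's native resolution, so no positional-embedding interpolation is needed.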

2. Feed the global image and the split tiles into the ViT to produce features

    # run feedforward on each scale: every tile goes through the ViT via model(x), i.e. model.forward(x)
    outs_multiscale = [model(x) for x in input_multiscale]
    # handle prefix tokens such as [CLS] if present (none here, so this branch is skipped)
    if num_prefix_token > 0:
        outs_prefix_multiscale = [out[:, :num_prefix_token] for out in outs_multiscale]
        outs_multiscale = [out[:, num_prefix_token:] for out in outs_multiscale]
    if output_shape == 'bnc':
        outs_multiscale = [rearrange(out, 'b (h w) c -> b c h w', h=int(out.shape[1] ** 0.5), w=int(out.shape[1] ** 0.5))
                           for out in outs_multiscale]

3. Merge the tile features back according to the chessboard grid of the split

    # merge outputs of different splits for each scale separately
    outs_multiscale = [merge_chessboard(out, num_split=num_split) for num_split, out in zip(num_splits, outs_multiscale)]
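Merging is the inverse of the split. A minimal NumPy sketch, assuming the same row-major tile order on the batch axis as in the split step; `merge_chessboard_np` is hypothetical, not the repo's `merge_chessboard`.

```python
import numpy as np

def merge_chessboard_np(x, num_split):
    """Hypothetical sketch: stitch (num_split**2 * B, C, h, w) tiles into (B, C, num_split*h, num_split*w).

    Assumes tile (i, j) of every image occupies batch rows
    [(i*num_split + j) * B, (i*num_split + j + 1) * B), i.e. row-major tile order.
    """
    n_tiles = num_split ** 2
    assert x.shape[0] % n_tiles == 0, "batch must hold num_split**2 tiles per image"
    b = x.shape[0] // n_tiles
    rows = [np.concatenate([x[(i * num_split + j) * b:(i * num_split + j + 1) * b]
                            for j in range(num_split)], axis=-1)   # tiles left to right
            for i in range(num_split)]
    return np.concatenate(rows, axis=-2)                            # tile rows top to bottom

# round trip on a small feature map: cut a (2, 4, 6, 6) map into 3 x 3 tiles, then merge
full = np.arange(2 * 4 * 6 * 6).reshape(2, 4, 6, 6)
tiles = np.concatenate([full[:, :, i * 2:(i + 1) * 2, j * 2:(j + 1) * 2]
                        for i in range(3) for j in range(3)], axis=0)
merged = merge_chessboard_np(tiles, num_split=3)
print(merged.shape, np.array_equal(merged, full))  # (2, 4, 6, 6) True
```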

4. Pool the merged features down to the global feature size (448), then concatenate

    # interpolate outputs from different scales and concat together
    output_size = outs_multiscale[resize_output_to_idx].shape[-2]
    # pool with F.interpolate(mode='area') down to the global feature size
    out = torch.cat([F.interpolate(outs_multiscale[i].to(torch.float32), size=output_size,
                                   mode='area').to(outs_multiscale[i].dtype)
                     for i in range(len(outs_multiscale))], dim=1)
    if output_shape == 'bnc':
        out = rearrange(out, 'b c h w -> b (h w) c')
    if num_prefix_token > 0:
        # take the mean of prefix tokens from different splits for each scale
        outs_prefix_multiscale = [torch.stack(out.split(b, dim=0), dim=0).mean(dim=0) for out in outs_prefix_multiscale]
        out_prefix_multiscale = torch.cat(outs_prefix_multiscale, dim=-1)
        out = torch.cat([out_prefix_multiscale, out], dim=1)

    return out
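Step 4 can be sketched in NumPy as follows. `area_pool` and `fuse_scales` are hypothetical helpers. `area_pool` reproduces `F.interpolate(..., mode='area')` only when the merged grid is an integer multiple of the target grid (true for 96 to 32 in the 448/1344 example, assuming patch 14), and the prefix-token averaging mirrors the `split(b, dim=0)` / `mean` pattern in the code above.

```python
import numpy as np

def area_pool(fmap, out_size):
    """Average-pool (B, C, H, W) to (B, C, out_size, out_size) with non-overlapping blocks.

    Equals F.interpolate(..., mode='area') when H and W are integer multiples
    of out_size; other ratios are rejected in this sketch.
    """
    b, c, h, w = fmap.shape
    assert h % out_size == 0 and w % out_size == 0, "integer factors only"
    fh, fw = h // out_size, w // out_size
    return fmap.reshape(b, c, out_size, fh, out_size, fw).mean(axis=(3, 5))

def fuse_scales(maps, prefixes, b, resize_output_to_idx=0):
    """Hypothetical step-4 sketch.

    maps[k]: merged (B, C, H_k, W_k) map of scale k.
    prefixes[k]: (num_split_k**2 * B, P, C) prefix tokens of scale k's tiles.
    Returns (B, P + s*s, K*C) with s the grid side of maps[resize_output_to_idx].
    """
    s = maps[resize_output_to_idx].shape[-2]
    out = np.concatenate([area_pool(m, s) for m in maps], axis=1)   # (B, K*C, s, s)
    out = out.reshape(b, out.shape[1], s * s).transpose(0, 2, 1)    # 'b c h w -> b (h w) c'
    # like torch.stack(p.split(b, dim=0)).mean(0): average each image's tile copies
    pref = [p.reshape(-1, b, *p.shape[1:]).mean(axis=0) for p in prefixes]  # each (B, P, C)
    return np.concatenate([np.concatenate(pref, axis=-1), out], axis=1)

# scaled-down 448/1344 analogue: grids 4x4 and 12x12, splits [1, 3], one [CLS] token, C = 8
rng = np.random.default_rng(0)
maps = [rng.standard_normal((2, 8, 4, 4)), rng.standard_normal((2, 8, 12, 12))]
prefixes = [rng.standard_normal((2, 1, 8)), rng.standard_normal((18, 1, 8))]
fused = fuse_scales(maps, prefixes, b=2)
print(fused.shape)  # (2, 17, 16)
```

Pooling every scale to the first scale's grid is what keeps the token count fixed while channels grow with the number of scales.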
hjeun commented 3 months ago