dhkim0225 / 1day_1paper

read 1 paper everyday (only weekday)
54 stars 1 forks source link

[95] CoCa: Contrastive Captioners are Image-Text Foundation Models #125

Open dhkim0225 opened 2 years ago

dhkim0225 commented 2 years ago

paper

imagenet sota 를 찍으며 큰 화제가 되었던 녀석.

한 번의 pretraining 으로 다양한 task를 풀어낸다. contrastive loss 와 captioning loss 를 동시에 사용하는 형태.

image

Pretraining

image image attentional pooling 이라는 attention 을 활용한 pooling 방식을 활용. [batch, seq_len, dim][batch, 1, dim] 형태로 바꿔줘서 contrastive loss 에도 적용하고, (그림의 노란색) [batch, seq_len, dim][batch, N, dim] 형태로 바꿔줘서 caption loss 에도 적용한다. (그림의 파란색)

attentional pooling 은 1개의 multi-head attention layer (n_query learnable param) 형태이다. task 별로 개수를 다르게 학습시킨다. (caption 을 위해서는 n_query==256, contrastive loss 를 위해서는 n_query==1)

uni-modal 에는 text-only data 가 통과하는 형태이며, multi-modal 에서는 image-text 의 feature 가 함께 혼합되는 형태이다. uni-modal, multimodal decoder 레이어 수는 동일하게 설정.

pretraining loss 는 다음과 같다. image image image captioning loss 는 일반적인 generation model loss 와 동일하고, contrastive loss 또한 일반적인 형태처럼, batch 내에서 positive, negative 를 추출하는 형태이다. i, j 는 각각 i번째 j번째 image-text pair 를 의미한다.

video 의 경우에도 attentional pooler 로 대응 가능함. 프레임 별로 encoder 를 통과시키고, pooling 시켜주면 됨. image

Data

JFT-3B (with label names), ALIGN dataset JFT-3B 학습시킬 때는, BASIC 모델 처럼 "a photo of" 접두사를 붙여서 사용함 tokenizer 는 sentencepeice 모델을 pretraining data 에 대해 64K 학습시킴.

ETC

lingvo fromework (TF) + GSPMD batch size 65536 500,000 step (5ep for JFT, 10ep for ALIGN) 2048 TPUv4 로 5일 caption loss lambda == 2.0 contrastive loss lambda == 1.0 CLIP, ALIGN, FLORENCE 처럼 마지막 epoch 은 576x576 으로 학습

Results.

zero-shot 수행 시, CLIP 처럼 최대한 pretraining 에 downstream 데이터를 안 섞었음 (deduplication) frozen 은 encoder 에 대해 수행 후, attentional pooler 부분을 학습시킨 형태

attentional pooler 가 linear probing 보다 practical 한 방식이라고 주장.

image image image image image image image image image

Ablations

image BOLD 체가 coca의 default setting 임

ZS == zero-shot LE == Linear Evaluation AE == Attentional Evaluation (with frozen-encoder) FT == Finetuning

(f) 의 경우 parallel 디자인이 위 모델 설명 형태. cascade 디자인은 256 개의 generative token 들에다가 다시 attention 을 가하는 형태. (최종 형태) n_query == 0 의 의미는, generative pooler 를 사용하지 않았을 때. (즉, 모든 vit token을 cross attention에다 박는 형태)

Hyperparms

image image image