이 논문의 목적은 promptable하게 학습되어서 새로운 이미지나 task에 zero-shot transfer가 가능한 foundation model을 만드는 것임.
확인하며 읽어야 할 점들
What task will enable zero-shot generalization?
What is the corresponding model architecture?
What data can power this task and model?
Segment Anything Task
본 논문에서는 promptable segmentation task를 통해 powerful pretraining objective로 학습하고 넓은 범위의 downstream applications를 가능하게 함.다양하고 큰 스의 소스 데이터가 필요로 하기 때문에 새로 수집함.
본 논문의 task는 어떤 segmentation prompt를 주든 valid segmentation mask를 반환하는 것임.
NLP에서는 next token prediction 방식으로 foundation model을 pre-train 시킴. 이 아이디어에 착안함.
prompt가 어떤 이미지의 어떤 부분을 segment할 것인지를 특정지어주는 역할을 함. (object를 식별할 수 있는 spatial 정보나, text 정보를 prompt가 포함)
prompt는 다양한 형태가 가능한데, foreground/background points, rough box or mask, free-form text, 또는 이미지에서 뭘 segment해야 하는지에 대한 어떤 information이든 다 입력으로 받을 수 있음.
어떤 prompt를 입력받던 (여러 object에 해당할 수 있는ambiguous prompt 등) 최소 하나의 object에 대해 reasonable한 valid segmentation mask를 return한다. (마치 NLP에서 language model이 ambiguous prompt에 대해서 일관성있는 결과를 기대하는 것과 유사함)
valid segmentation mask를 반환해야 함. valid segmentation mask의 의미는 prompt가 모호하거나 object가 여러개 있는 상황에서도 최소 하나의 object에 대해 reasonable mask를 반환해야 함.
Promptable segmentation task는 pre-training objective가 됨과 동시에 general downstream segmentation task를 풀기도 함.
위 fig3이 ambiguous prompt 예시를 보여줌. 같은 포인트여도 셔츠를 의미하거나 사람을 의미하는 등 그 의미가 다를 수 있음 .
Pre-training
promptable segmentation task는 각 train sample에 대해 sequence of prompts(ex. points, boxes, masks, Free-form Text 등)를 시뮬레이트 하는 natural pre-training 알고리즘을 제안함.
모델의 예측을 GT와 비교함. 이 방식은 interactive segmentation이랑 유사하지만 차이가 있다면 interactive segmentation은 충분한 사용자 입력 후에 valid mask를 최종 예측 하지만 본 논문의 경우 프롬포트가 모호한 경우에도 언제든지 valid mask를 예측하는 것이다.
Ambiguous한 prompt에 대해서도 학습을 했기 때문에 어떤 prompt가 와도 valid 한 mask를 생성할 수 있었음..
Zero-shot transfer
pre-training task는 inference시에 어떤 프롬포트든 적절히 응답할 수 있다.
예를 들어 고양이를 위한 bounding box detector가 있다면 이 detector의 box output이 prompt가 되어서 instance segmentation 수행함.
Prompting과 Composition은 단일 모델이 설계 당시에는 알 수 없는 task를 수행할 수 있다는 장점이 있음. (ex. CLIP의 경우 DALL-E 이미지 생성 시스템의 image-text align component임)
고정된 task를 위해 학습된 시스템보다 더 다양한 응용을 가능케 함.
Segment Anything model
Image encoder, Prompt encoder, Mask decoder 영역으로 나누어짐 (이 때 Prompt encoder, Mask decoder를 묶어서 Prompt 모델이라고 부름)
위와 같이 세 가지 모듈로 구성함으로써 같은 이미지 임베딩이더라도 서로 다른 prompt로 재사용 될 수 있다.
Image encoder
MAE pre-trained ViT로 embedding 뽑음.
Prompt encoder
Sparse prompt(points, box, text), dense prompt(mask) 를 입력
Sparse prompt는 256차원 임베딩으로 매핑
Point는 point location과 학습 가능한 foreground/background embedding 중 하나와의 합
Box: Top-left corner, bottom-right corner에 대한 learned embedding의 합
Text: CLIP text embedding
Mask decoder
Image embedding과 prompt embedding, output token embedding을 입력받아서 mask로 매핑하는 역할을 함. self-attention, cross-attention을 통해 mask prediction 함.
Token에 대한 Self-attention
Token을 Query로 하여 Image Embedding에 Cross-Attention
Point-wise MLP로 각 Token을 업데이트
Image Embedding을 Query로 하여 Token에 Cross-Attention
Attention layer에 입력할 때마다 image embedding에 positional encoding을 더해주며, attention layer 입력마다 updated token에 original token prompt 전체를 합산함.
Decoder 수행하고 나서 Transposed convolution으로 이미지 임베딩 크기 4배로 upsampling하고 token을 query로 해서 한번 더 이미지 임베딩에 어텐션 한 다음, output token embedding에 3개의 MLP 적용
Resolving Ambiguity
Mask ambiguity란 같은 점이라도 서로 다른 마스크가 의도될 수 있음을 의미함.
따라서 mask output이 여러개로, multiple output masks를 예측함. 3개가 most common cases를 다루기 충분한 것을 발견함. (하나이게 되면 ambiguous prompt에 대해 multiple valid mask를 평균내게 될 것임)
각 마스크마다 confidence score(ex. estimated IoU)로 랭킹 매기는 식.
그럼 back propagation 어떻게하지? → 마스크 들 중 minimum loss인 것만 수행
loss: focal loss, dice loss의 linear combination으로 학습함.
focal loss: 더 어려운 객체에 가중치를 줌
Dice loss: IoU보다 GT를 얼마나 많이 포함했는가(Recall에 집중)
efficiency: prompt encoder + mask decoder 수행 시간이 50ms 이내여서 real-time interactive prompting 가능한 수준
Segment Anything Data Engine
크고 다양한 셋의 마스크로 학습해야 하는데 기존 foundation model들은 데이터를 online으로 수집했기 때문에 mask가 자연스럽게 풍부하지 않음.
data engine을 통해 model-in-the-loop dataset annotation을 co-develop함.
assisted-manual stage (SA-1B dataset 구축 단계): Annotator + seg tool
Segment Anything Task
본 논문에서는 promptable segmentation task를 통해 powerful pretraining objective로 학습하고 넓은 범위의 downstream applications를 가능하게 함.다양하고 큰 스의 소스 데이터가 필요로 하기 때문에 새로 수집함.
본 논문의 task는 어떤 segmentation prompt를 주든 valid segmentation mask를 반환하는 것임.
위 fig3이 ambiguous prompt 예시를 보여줌. 같은 포인트여도 셔츠를 의미하거나 사람을 의미하는 등 그 의미가 다를 수 있음 .
Pre-training
Zero-shot transfer
Segment Anything model
Image encoder, Prompt encoder, Mask decoder 영역으로 나누어짐 (이 때 Prompt encoder, Mask decoder를 묶어서 Prompt 모델이라고 부름)
위와 같이 세 가지 모듈로 구성함으로써 같은 이미지 임베딩이더라도 서로 다른 prompt로 재사용 될 수 있다.
Mask decoder
Token에 대한 Self-attention
Token을 Query로 하여 Image Embedding에 Cross-Attention
Point-wise MLP로 각 Token을 업데이트
Image Embedding을 Query로 하여 Token에 Cross-Attention
Resolving Ambiguity
loss: focal loss, dice loss의 linear combination으로 학습함.
efficiency: prompt encoder + mask decoder 수행 시간이 50ms 이내여서 real-time interactive prompting 가능한 수준
Segment Anything Data Engine
data engine을 통해 model-in-the-loop dataset annotation을 co-develop함.
2) semi-automatic: 사람이 라벨링한 GT + 모델이 예측한 mask
3) fully automatic : 모델이 예측한 mask로만 학습
Experiments
Single point valid mask evaluation
Zero-Shot Edge Detection
Zero-shot object proposals
Zero-shot instance segmentation
Zero-Shot Text-to-Mask
학습은 CLIP image embedding을 prompt로 주고, inference시에서는 text prompt를 사용하는데 간단한 prompt에 대해 좋은 성능을 보임.
Text만 줬을때 틀렸던 것들은 point를 같이 주면 잘함.