## About
The Segment Anything Model (SAM) has demonstrated remarkable zero-shot capability and flexible geometric prompting in general image segmentation. However, it often struggles in domains that are either sparsely represented or lie outside its training distribution, such as aerial, medical, and non-RGB images. Recent efforts have predominantly focused on adapting SAM to these domains with fully supervised methods, which require large amounts of annotated training data and pose practical challenges for data collection. This paper presents CAT-SAM, a ConditionAl Tuning network that explores _few-shot adaptation_ of SAM toward various challenging downstream domains in a data-efficient manner. The core design is a _prompt bridge_ structure that enables _decoder-conditioned joint tuning_ of the heavyweight image encoder and the lightweight mask decoder. The bridge maps the domain-specific features of the mask decoder to the image encoder, fostering synergistic adaptation of both components with mutual benefits from only a few target samples, ultimately leading to superior segmentation across various downstream tasks. We develop two CAT-SAM variants that adopt two tuning strategies for the image encoder: one injects learnable prompt tokens in the input space and the other inserts lightweight adapter networks. Extensive experiments over 11 downstream tasks show that CAT-SAM consistently achieves superior segmentation even under the very challenging one-shot adaptation setup.
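For a concrete picture of the prompt-bridge idea before reading the code, below is a minimal, hypothetical PyTorch sketch: a small MLP maps a learnable decoder-side embedding into prompt tokens that are prepended to the image encoder's patch tokens. The names (`PromptBridge`, `decoder_embedding`, `patch_tokens`) and dimensions are illustrative only and do not mirror the actual implementation in this repository.
```python
# Minimal, illustrative sketch of a prompt bridge (hypothetical names,
# not the actual CAT-SAM implementation).
import torch
import torch.nn as nn


class PromptBridge(nn.Module):
    """Maps a decoder-side embedding to prompt tokens for the image encoder."""

    def __init__(self, decoder_dim: int, encoder_dim: int, num_prompts: int = 8):
        super().__init__()
        # Learnable decoder-side embedding carrying domain-specific information.
        self.decoder_embedding = nn.Parameter(torch.zeros(1, decoder_dim))
        # Lightweight MLP projecting the embedding into encoder-space prompt tokens.
        self.mlp = nn.Sequential(
            nn.Linear(decoder_dim, encoder_dim),
            nn.GELU(),
            nn.Linear(encoder_dim, num_prompts * encoder_dim),
        )
        self.num_prompts = num_prompts
        self.encoder_dim = encoder_dim

    def forward(self, patch_tokens: torch.Tensor) -> torch.Tensor:
        # patch_tokens: (B, N, encoder_dim) patch embeddings from the image encoder.
        b = patch_tokens.size(0)
        prompts = self.mlp(self.decoder_embedding)  # (1, num_prompts * encoder_dim)
        prompts = prompts.view(1, self.num_prompts, self.encoder_dim).expand(b, -1, -1)
        # Prepend the bridged prompt tokens to the patch tokens.
        return torch.cat([prompts, patch_tokens], dim=1)  # (B, num_prompts + N, encoder_dim)


if __name__ == "__main__":
    bridge = PromptBridge(decoder_dim=256, encoder_dim=768)
    tokens = torch.randn(2, 196, 768)
    print(bridge(tokens).shape)  # torch.Size([2, 204, 768])
```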
## News
- **(2024/9)** We updated the inference code to support customized model inference.
- **(2024/8)** CAT-SAM has been selected as an :sparkles: Oral Paper :sparkles: at ECCV 2024!
- **(2024/7)** We have released the training code. Thanks for your patience!
- **(2024/7)** CAT-SAM has been accepted to ECCV 2024! See you in Milano!
## Method
![overall_pipeline](./figs/CAT-SAM.png "overall_pipeline")
## Results
### 1-Shot Adaptation:
| Methods | WHU | Kvasir | SBU-Shadow | Average |
|:-----------------:|:-----:|:------:|:----------:|:-------:|
| **SAM (baseline)** | 43.5 | 79.0 | 62.4 | 61.6 |
| **VPT-shallow** | 60.8 | 79.8 | 68.7 | 69.8 |
| **VPT-deep** | 57.8 | 80.4 | 76.0 | 71.4 |
| **AdaptFormer** | 83.2 | 76.8 | 77.2 | 79.1 |
| **LoRA** | 86.1 | 77.5 | 74.4 | 79.3 |
| **CAT-SAM-T (Ours)** | 86.8 | 83.4 | 78.0 | 82.7 |
| **CAT-SAM-A (Ours)** | 88.2 | 85.4 | 81.9 | 85.2 |
## Installation
Please clone this project to your local machine and set up the environment with the following commands:
```
$: cd cat-sam
$: conda create -n cat-sam python=3.9
$: conda activate cat-sam
$: python -m pip install -r requirements.txt
(cat-sam) $: python -m pip install -e .
```
The code has been tested on A100/A6000/V100 GPUs with Python 3.9, CUDA 11.7, and PyTorch 1.13.1. Other devices and environments may require code changes for compatibility.
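To quickly verify that the environment is set up correctly, an optional sanity check like the following can be run; it only assumes a standard PyTorch installation:
```python
# Optional sanity check for the installed environment.
import torch

print("PyTorch:", torch.__version__)          # tested with 1.13.1
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))
```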
## Data
Please refer to the README.md in each dataset-specific folder under `./data` for instructions on preparing that dataset.
## Train
Before training, please download the SAM checkpoints to `./pretrained` from [here](https://github.com/facebookresearch/segment-anything#model-checkpoints).
After downloading, the directory layout should look like:
```
cat-sam/
pretrained/
sam_vit_b_01ec64.pth
sam_vit_h_4b8939.pth
sam_vit_l_0b3195.pth
...
```
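If you prefer to fetch the checkpoints from the command line, an optional Python helper like the sketch below can be used. The download URLs are taken from the official Segment Anything repository; the helper itself is not part of this codebase.
```python
# Optional helper to download the official SAM checkpoints into ./pretrained.
# URLs from https://github.com/facebookresearch/segment-anything#model-checkpoints.
import os
import urllib.request

CHECKPOINTS = [
    "sam_vit_b_01ec64.pth",
    "sam_vit_l_0b3195.pth",
    "sam_vit_h_4b8939.pth",
]
BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything/"

os.makedirs("pretrained", exist_ok=True)
for name in CHECKPOINTS:
    target = os.path.join("pretrained", name)
    if os.path.exists(target):
        print(f"{name} already exists, skipping.")
    else:
        print(f"Downloading {name} ...")
        urllib.request.urlretrieve(BASE_URL + name, target)
```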
For one-shot training, please run:
```
$: cd cat-sam
$: pwd
/your_dir/cat-sam
$: conda activate cat-sam
(cat-sam) $: python train.py --dataset