AnjieCheng / SpatialRGPT

[NeurIPS'24] This repository is the implementation of "SpatialRGPT: Grounded Spatial Reasoning in Vision Language Models"
https://www.anjiecheng.me/SpatialRGPT
Apache License 2.0
spatial-reasoning vision-language-model

SpatialRGPT: Grounded Spatial Reasoning in Vision Language Models (NeurIPS'24)

Code License Model License Data License Benchmark License Python 3.10+

arxiv / Huggingface


💡 Introduction

SpatialRGPT: Grounded Spatial Reasoning in Vision-Language Models
An-Chieh Cheng, Hongxu (Danny) Yin, Yang Fu, Qiushan Guo, Ruihan Yang, Jan Kautz, Xiaolong Wang, Sifei Liu

SpatialRGPT is a powerful vision-language model adept at understanding both 2D and 3D spatial arrangements. It can process any region proposal, such as boxes or masks, and provide answers to complex spatial reasoning questions.




Installation

To build the environment for training SpatialRGPT, run the following:

./environment_setup.sh srgpt
conda activate srgpt

Gradio Demo

To run the Gradio demo for SpatialRGPT, follow the steps below. Because of pydantic version conflicts, the demo environment is not compatible with the training environment, so you need to create a separate environment for the demo.

  1. Build the environment.

    ./environment_setup.sh srgpt-demo
    conda activate srgpt-demo
    pip install gradio==4.27 deepspeed==0.13.0 gradio_box_promptable_image segment_anything_hq
    pip install -U 'git+https://github.com/facebookresearch/detectron2.git@ff53992b1985b63bd3262b5a36167098e3dada02'

    If you run into an error with the detectron2 installation, it could be because CUDA_HOME is not set. To fix this, export CUDA_HOME to your local CUDA path. See details in this issue.

  2. Clone the Depth-Anything repository and download the necessary checkpoint:

    git clone https://github.com/LiheYoung/Depth-Anything.git
    wget https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth

    Place depth_anything_vitl14.pth under Depth-Anything/checkpoints, and set the path to the environment variable. For example:

    export DEPTH_ANYTHING_PATH=/YOUR_OWN_PATH/Depth-Anything
  3. Download the SAM-HQ checkpoint from here, and set its path in an environment variable. For example:

    export SAM_CKPT_PATH=/YOUR_OWN_PATH/sam_hq_vit_h.pth
  4. Launch the Gradio server. You can use your own checkpoint or a8cheng/SpatialRGPT-VILA1.5-8B.

    cd demo
    python gradio_web_server_multi.py --model-path PATH_TO_CHECKPOINT
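Before launching the server, it can help to confirm that the environment variables from steps 2 and 3 point at real files. The following is a minimal sketch and not part of the repository: the helper name `check_demo_env` is hypothetical, and it assumes only the layout described above (depth_anything_vitl14.pth under Depth-Anything/checkpoints, and SAM_CKPT_PATH pointing directly at the SAM-HQ checkpoint file).

```python
import os
from pathlib import Path


def check_demo_env(env=None):
    """Return a list of problems found; an empty list means the setup looks fine.

    Hypothetical helper (not part of SpatialRGPT). It only checks the paths
    described in the demo setup steps.
    """
    env = os.environ if env is None else env
    problems = []

    depth_root = env.get("DEPTH_ANYTHING_PATH")
    if not depth_root:
        problems.append("DEPTH_ANYTHING_PATH is not set")
    else:
        # The checkpoint is expected under <repo>/checkpoints, as in step 2.
        ckpt = Path(depth_root) / "checkpoints" / "depth_anything_vitl14.pth"
        if not ckpt.is_file():
            problems.append(f"missing Depth-Anything checkpoint: {ckpt}")

    sam_ckpt = env.get("SAM_CKPT_PATH")
    if not sam_ckpt:
        problems.append("SAM_CKPT_PATH is not set")
    elif not Path(sam_ckpt).is_file():
        problems.append(f"SAM-HQ checkpoint not found: {sam_ckpt}")

    return problems


if __name__ == "__main__":
    for problem in check_demo_env():
        print("WARNING:", problem)
```

Running the file directly prints one warning per problem and prints nothing when both paths resolve.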

Training

SpatialRGPT follows the VILA training recipe, which consists of three stages. We provide training scripts for three LLM backbones: sheared_3b, llama2_7b, and llama3_8b. You can find the training scripts for each stage in the scripts/srgpt folder.


Open Spatial Dataset

Please download the Open Spatial Dataset from huggingface, and modify the path in llava/data/dataset_mixture.py.

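For raw images, please download OpenImages from OpenImagesV7. To convert the RGB images into depth maps, we use DepthAnythingV2 and save the depth with the following function:

```python
import numpy as np
import torch.nn.functional as F
from PIL import Image

def save_raw_16bit(depth, fpath, height, width):
    # Resize the predicted depth map (a 2D torch tensor) to the target resolution.
    depth = F.interpolate(depth[None, None], (height, width), mode='bilinear', align_corners=False)[0, 0]
    # Min-max normalize to [0, 255]; despite the function name, the saved image is 8-bit.
    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    depth = depth.cpu().numpy().astype(np.uint8)
    # Replicate the single channel three times and save it as an RGB image.
    colorized_depth = np.stack([depth, depth, depth], axis=-1)
    depth_image = Image.fromarray(colorized_depth)
    depth_image.save(fpath)
```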
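The normalization step in the function above is a per-image min-max rescale to [0, 255], so the saved values are relative to each image rather than metric depth. A minimal NumPy-only sketch of that arithmetic follows. The name `minmax_to_uint8` and the guard for constant depth maps are additions for illustration and do not appear in the original function, which would divide by zero on a constant map.

```python
import numpy as np


def minmax_to_uint8(depth: np.ndarray) -> np.ndarray:
    """Rescale a depth map to [0, 255] and cast to uint8 (fractions are truncated)."""
    depth = depth.astype(np.float32)
    value_range = depth.max() - depth.min()
    if value_range == 0:
        # Assumption for this sketch: a constant depth map becomes all zeros
        # instead of triggering a division by zero.
        return np.zeros(depth.shape, dtype=np.uint8)
    scaled = (depth - depth.min()) / value_range * 255.0
    return scaled.astype(np.uint8)


depth = np.array([[1.0, 2.0], [3.0, 5.0]])
# The smallest value maps to 0, the largest to 255; here the result holds 0, 63, 127, 255.
print(minmax_to_uint8(depth))
```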

Dataset Synthesis Pipeline

We've also made the dataset synthesis pipeline available. You can find the code and instructions in the dataset_pipeline folder. Please note that some of the packages we use have had version updates, and we've migrated to their latest versions. This may result in some bugs. Feel free to report any issues or unexpected results you encounter.



Evaluations

Our evaluation scripts take the following arguments: PATH_TO_CKPT, CKPT_NAME, and CONV_TYPE.
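As a sketch of how the three positional arguments line up, the snippet below assembles the command line for one of the evaluation scripts described in the subsections that follow. The helper `build_eval_command`, the choice to invoke the script through `bash`, and the example argument values are all assumptions for illustration; only the script path and the argument order come from this README.

```python
import shlex


def build_eval_command(script, path_to_ckpt, ckpt_name, conv_type):
    """Return the argv list: the script followed by PATH_TO_CKPT, CKPT_NAME, CONV_TYPE."""
    # Assumption: the script is run through bash rather than executed directly.
    return ["bash", script, path_to_ckpt, ckpt_name, conv_type]


cmd = build_eval_command(
    "scripts/srgpt/eval/srgpt_bench.sh",
    "/YOUR_OWN_PATH/checkpoints",  # placeholder for PATH_TO_CKPT
    "my_checkpoint",               # placeholder for CKPT_NAME
    "my_conv_type",                # placeholder for CONV_TYPE
)
print(shlex.join(cmd))
# To actually launch the evaluation: subprocess.run(cmd, check=True)
```

Building the command as a list keeps each argument separate, so paths containing spaces are passed through intact.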

Region Classification

First, prepare the evaluation annotations following RegionCLIP. Then run scripts/srgpt/eval/coco_cls.sh PATH_TO_CKPT CKPT_NAME CONV_TYPE.

SpatialRGPT-Bench Evaluation

First, download the images from Omni3D, following their instructions. Then download the annotations from https://huggingface.co/datasets/a8cheng/SpatialRGPT-Bench. Update the paths in scripts/srgpt/eval/srgpt_bench.sh to point to the downloaded images and annotations.

Note that for SpatialRGPT-Bench, you need to clone the Depth-Anything repository and download the necessary checkpoint:

git clone https://github.com/LiheYoung/Depth-Anything.git
wget https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth

Place depth_anything_vitl14.pth under Depth-Anything/checkpoints, and set the path to the environment variable.

export DEPTH_ANYTHING_PATH="PATH_TO_DEPTHANYTHING"

Then run scripts/srgpt/eval/srgpt_bench.sh PATH_TO_CKPT CKPT_NAME CONV_TYPE.

General VLM Benchmarks

Our code is compatible with VILA's evaluation scripts. See VILA/evaluations for details.


📜 Citation

  @inproceedings{cheng2024spatialrgpt,
          title={SpatialRGPT: Grounded Spatial Reasoning in Vision-Language Models},
          author={Cheng, An-Chieh and Yin, Hongxu and Fu, Yang and Guo, Qiushan and Yang, Ruihan and Kautz, Jan and Wang, Xiaolong and Liu, Sifei},
          booktitle={NeurIPS},
          year={2024}
  }

πŸ™ Acknowledgement

We have used code snippets from different repositories, especially from: VILA, Omni3D, GLaMM, VQASynth, and ConceptGraphs. We would like to acknowledge and thank the authors of these repositories for their excellent work.