facebookresearch / segment-anything-2

The repository provides code for running inference with the Meta Segment Anything Model 2 (SAM 2), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0
10.76k stars 874 forks source link

Unable to initialize SAM2 model due to Hydra configuration error #304

Closed seanbetts closed 5 days ago

seanbetts commented 5 days ago

Description

I'm trying to set up a FastAPI microservice to use the SAM2 model. However, I'm encountering persistent errors when trying to initialize the model, specifically related to Hydra configuration loading.

Environment

Steps to Reproduce

  1. Clone the SAM2 repository
  2. Set up a FastAPI application (code attached below)
  3. Create a Dockerfile (attached below)
  4. Build and deploy the Docker image to Google Cloud Run

Error Message

ERROR:sam2_microservice:Failed to initialize SAM2 models: 'dict' object has no attribute 'endswith' ERROR:sam2_microservice:Traceback: Traceback (most recent call last): ... File "/opt/conda/lib/python3.10/site-packages/hydra/plugins/config_source.py", line 128, in if not any(filename.endswith(ext) for ext in supported_extensions): AttributeError: 'dict' object has no attribute 'endswith'

Code

sam2_microservice.py:

import os
import yaml
import sys
import io
import json
import logging
from fastapi import FastAPI, File, UploadFile, BackgroundTasks
from fastapi.responses import JSONResponse
from google.cloud import storage
from google.oauth2 import service_account
from PIL import Image
import torch
import tempfile
import numpy as np
import cv2
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf

from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

SAM2_REPO_ROOT = "/app/segment-anything-2"  # adjust if necessary
sys.path.append(SAM2_REPO_ROOT)
os.environ["PYTHONPATH"] = f"{SAM2_REPO_ROOT}:{os.environ.get('PYTHONPATH', '')}"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

logger.info("Starting FastAPI server...")

app = FastAPI()

try:
    credentials = service_account.Credentials.from_service_account_file('keyfile.json')
    storage_client = storage.Client(credentials=credentials)
except Exception as e:
    logger.error(f"Failed to initialize Google Cloud Storage client: {str(e)}")
    raise

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = "/app/checkpoints/sam2_hiera_large.pt"
config = "sam2_hiera_l.yaml"

sam2_model = None
image_predictor = None
mask_generator = None
video_predictor = None

@app.on_event("startup")
async def startup_event():
    global sam2_model, image_predictor, mask_generator, video_predictor
    try:
        config_dir = "sam2_configs"
        config_name = "sam2_hiera_l.yaml"

        logger.info(f"Current working directory: {os.getcwd()}")
        logger.info(f"Config directory: {config_dir}")
        logger.info(f"Config name: {config_name}")
        logger.info(f"Config file exists: {os.path.exists(os.path.join(config_dir, config_name))}")
        logger.info(f"Contents of {config_dir}: {os.listdir(config_dir)}")

        logger.info("Clearing existing Hydra instance...")
        GlobalHydra.instance().clear()

        logger.info("Reinitializing Hydra with the correct config path...")
        with initialize(version_base="1.2", config_path=config_dir):
            logger.info("Composing Hydra configuration...")
            with open("/app/sam2_configs/sam2_hiera_l.yaml", 'r') as file:
                cfg = yaml.safe_load(file)

            logger.info(f"Hydra config search path: {GlobalHydra.instance().config_loader().get_sources()}")
            logger.info(f"Config (cfg) type: {type(cfg)}")
            logger.info(f"Loaded config: {OmegaConf.to_yaml(cfg)}")

            if cfg is None:
                logger.warning("Hydra failed to load config. Attempting direct YAML load.")
                with open(os.path.join(config_dir, config_name), 'r') as f:
                    cfg = yaml.safe_load(f)
                cfg = OmegaConf.create(cfg)

            logger.info("Initializing SAM2 model...")
            sam2_model = build_sam2(cfg, checkpoint, device=device, apply_postprocessing=False)

        logger.info("Initializing SAM2 image predictor...")
        image_predictor = SAM2ImagePredictor(sam2_model)

        logger.info("Initializing SAM2 automatic mask generator...")
        mask_generator = SAM2AutomaticMaskGenerator(sam2_model)

        logger.info("Initializing SAM2 video predictor...")
        video_predictor = build_sam2_video_predictor(cfg, checkpoint)

        logger.info("SAM2 models initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize SAM2 models: {str(e)}")
        logger.exception("Traceback:")
        raise

@app.get("/health")
async def health_check():
    logger.debug("Health check called")
    return {"status": "sam2 backend ok"}

@app.post("/segment")
async def segment_image(file: UploadFile = File(...)):
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
        image_rgb = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)

        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            image_predictor.set_image(image_rgb)
            masks, _, _ = image_predictor.predict()

        masks_list = masks.tolist()

        return JSONResponse(content={"masks": masks_list})
    except Exception as e:
        logger.error(f"Error in segment_image: {str(e)}")
        return JSONResponse(status_code=500, content={"error": str(e)})

@app.post("/segment_video")
async def segment_video(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
            temp_video.write(await file.read())
            temp_video_path = temp_video.name

        background_tasks.add_task(process_video, temp_video_path)

        return JSONResponse(content={"message": "Video processing started"})
    except Exception as e:
        logger.error(f"Error in segment_video: {str(e)}")
        return JSONResponse(status_code=500, content={"error": str(e)})

async def process_video(video_path: str):
    try:
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            state = video_predictor.init_state(video_path)

            results = []
            for frame_idx, object_ids, mask_logits in video_predictor.propagate_in_video(state):
                results.append({
                    "frame_idx": frame_idx,
                    "object_ids": object_ids.tolist(),
                    "masks": mask_logits.cpu().numpy().tolist()
                })

        with open(f"{os.path.splitext(video_path)[0]}_sam2_result.json", "w") as f:
            json.dump(results, f)

    except Exception as e:
        logger.error(f"Error processing video: {str(e)}")
    finally:
        os.unlink(video_path)

if __name__ == "__main__":
    import uvicorn
    port = int(os.environ.get("PORT", 8080))
    uvicorn.run(app, host="0.0.0.0", port=port)

Dockerfile:

# Use an official PyTorch image as the base image
FROM pytorch/pytorch:2.3.1-cuda11.8-cudnn8-runtime

# Set the working directory in the container
WORKDIR /app

# Install git and other necessary tools
RUN apt-get update && apt-get install -y \
    git \
    wget \
    && rm -rf /var/lib/apt/lists/*

RUN pip install --upgrade pip

# Copy the configuration file and microservice code into the container
COPY sam2_configs/sam2_hiera_l.yaml /app/sam2_configs/
RUN touch /app/sam2_configs/__init__.py
COPY sam2_microservice.py /app/

# Clone the SAM 2 repository
RUN git clone https://github.com/facebookresearch/segment-anything-2.git /app/segment-anything-2

# Install SAM 2
RUN cd /app/segment-anything-2 && pip install -e .

# Install additional dependencies
COPY requirements.txt .
COPY keyfile.json .
RUN pip install --no-cache-dir -r requirements.txt

# Download the SAM 2 model checkpoint
RUN mkdir -p /app/checkpoints
RUN wget -q https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt -P /app/checkpoints/

# Set the PYTHONPATH to include the SAM2 directory
ENV PYTHONPATH="${PYTHONPATH}:/app/segment-anything-2"

# Expose the port the app runs on
EXPOSE 8080

# Command to run the application
CMD ["uvicorn", "sam2_microservice:app", "--host", "0.0.0.0", "--port", "8080"]

Additional Context

Questions

Any assistance or guidance would be greatly appreciated.

Thank you!

seanbetts commented 5 days ago

FIXED! For anyone else struggling with these issues make sure you are doing the following: