fabio-sim / Depth-Anything-ONNX

ONNX-compatible Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data
Apache License 2.0

Metric depth inference #4

Closed · KevinCain closed this 1 month ago

KevinCain commented 7 months ago

I'd like to try inference with a model trained on metric data. I've exported the Depth Anything NYU pretrained model (trained on metric indoor data) to ONNX and run inference, as shown below, but the results indicate I'm doing something wrong.

I adapted the provided export script (see below) and invoke it for 'base' model export via python .\export_metric.py --model_type "metric" --model b.

The export process seemed to go fine, and checking the model didn't turn up any errors:

import onnx
from onnx import checker
model = onnx.load("depth_anything_vitb14_metric.onnx")
checker.check_model(model)
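
As an aside, checker.check_model only validates the graph structure, not whether the weights actually made it into the export. A minimal sketch of a numerical parity check, assuming the depth_anything module and the traced image tensor from export_metric.py are still in scope, might look like:

import numpy as np
import onnxruntime as ort
import torch

# Run the original PyTorch model and the exported ONNX graph on the same input
with torch.no_grad():
    torch_out = depth_anything(image).cpu().numpy()

session = ort.InferenceSession("depth_anything_vitb14_metric.onnx", providers=["CPUExecutionProvider"])
onnx_out = session.run(None, {"image": image.cpu().numpy()})[0]

# A large discrepancy here points at a weight-loading problem rather than an export problem
print("max abs diff:", np.abs(np.squeeze(torch_out) - np.squeeze(onnx_out)).max())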

I can launch inference from a variant of the supplied eval script (also below). Using one of the provided .onnx models works as expected, e.g.: python .\batch_infer.py --img_dir .\input --model .\weights\depth_anything_vitb14.onnx --output_dir .\output

Below is an input image and output relative depth map:

[images: input photo and output relative depth map (IMG_20240111_175231)]

I invoke inference with my exported model as follows: python .\batch_infer.py --img_dir .\input --model .\weights\depth_anything_vitb14_metric.onnx --output_dir .\output_metric

However, the depth results from the metric export show tell-tale repeating block artifacts:

[image: metric depth output with repeating block artifacts]

'export_metric.py':

import argparse
import torch
from onnx import load_model, save_model
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

from depth_anything.dpt import DPT_DINOv2
from depth_anything.util.transform import load_image

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        choices=["s", "b", "l"],
        required=True,
        help="Model size variant. Available options: 's', 'b', 'l'.",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        choices=["metric", "relative"],
        required=True,
        help="Model type. Available options: 'metric', 'relative'.",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        required=False,
        help="Path to save the ONNX model.",
    )

    return parser.parse_args()

def export_onnx(model: str, model_type: str, output: str = None):
    # Handle args
    if output is None:
        output = f"weights/depth_anything_vit{model}14_{model_type}.onnx"

    # Device for tracing
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Sample image for tracing
    #image, _ = load_image("assets/sacre_coeur1.jpg")
    image, _ = load_image("assets/grace.jpg")
    image = torch.from_numpy(image).to(device)

    # Initialize model instance based on model size
    if model == "s":
        depth_anything = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
    elif model == "b":
        depth_anything = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
    else:  # model == "l"
        depth_anything = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])

    # Load checkpoint
    if model_type == "metric":
        checkpoint = torch.load("./metric_source/depth_anything_metric_depth_indoor.pt", map_location="cpu")
    else:  # model_type == "relative"
        checkpoint = torch.hub.load_state_dict_from_url(
            f"https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vit{model}14.pth",
            map_location="cpu"
        )

    # Extract model weights from checkpoint
    if "model" in checkpoint:
        state_dict = checkpoint["model"]
    else:
        state_dict = checkpoint

    # Load state dict into model
    depth_anything.to(device).load_state_dict(state_dict, strict=False)  # strict=False silently skips checkpoint keys that don't match the model
    depth_anything.eval()

    # Proceed with ONNX export as before
    torch.onnx.export(
        depth_anything,
        image,
        output,
        input_names=["image"],
        output_names=["depth"],
        opset_version=17,
        dynamic_axes={
            "image": {2: "height", 3: "width"},
            "depth": {2: "height", 3: "width"},
        },
    )

    # Shape inference for ONNX model
    save_model(
        SymbolicShapeInference.infer_shapes(load_model(output), auto_merge=True),
        output,
    )

if __name__ == "__main__":
    args = parse_args()
    export_onnx(**vars(args))
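
One caveat on the strict=False load above: load_state_dict silently drops any checkpoint keys that don't line up with the DPT_DINOv2 architecture, so whole layers can keep their random initialization without any error being raised. A minimal sketch for surfacing that, assuming the same metric checkpoint path as in the script:

import torch
from depth_anything.dpt import DPT_DINOv2

model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
checkpoint = torch.load("./metric_source/depth_anything_metric_depth_indoor.pt", map_location="cpu")
state_dict = checkpoint.get("model", checkpoint)

# load_state_dict returns the keys that didn't match; long lists here mean
# large parts of the network were never initialized from the checkpoint
result = model.load_state_dict(state_dict, strict=False)
print("missing keys:", len(result.missing_keys))
print("unexpected keys:", len(result.unexpected_keys))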

'batch_infer.py':

import argparse
import os
import time
import cv2
import numpy as np
import onnxruntime as ort

from depth_anything.util.transform import load_image

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--img_dir",
        type=str,
        required=True,
        help="Path to input image directory.",
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Path to ONNX model.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Directory to save output depth images.",
    )
    parser.add_argument(
        "--viz", action="store_true", help="Whether to visualize the results."
    )
    return parser.parse_args()

def infer(image_path: str, session, output_dir: str, viz: bool = False):
    start_time = time.time()

    image, (orig_h, orig_w) = load_image(image_path)

    depth = session.run(None, {"image": image})[0]

    depth = cv2.resize(depth[0, 0], (orig_w, orig_h))
    depth_scaled = np.clip(depth * (65535 / depth.max()), 0, 65535).astype(np.uint16)

    # Example for PNG format
    output_path = os.path.join(output_dir, os.path.splitext(os.path.basename(image_path))[0] + ".png")

    # Save grayscale depth image (without color map)
    cv2.imwrite(output_path, depth_scaled)

    end_time = time.time()
    processing_time = end_time - start_time
    print(f"Processed {image_path} in {processing_time:.2f} seconds")

    return processing_time

def main():
    args = parse_args()

    # Ensure output directory exists
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Start model loading time measurement
    start_model_loading = time.time()

    # Load model
    session = ort.InferenceSession(
        args.model, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
    )

    # End model loading time measurement
    end_model_loading = time.time()
    model_loading_time = end_model_loading - start_model_loading
    print(f"Model loaded in {model_loading_time:.2f} seconds")

    # Initialize variables for timing
    total_time_excluding_model_loading = 0
    image_count = 0

    # Process each image in the directory
    for filename in os.listdir(args.img_dir):
        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            image_path = os.path.join(args.img_dir, filename)
            processing_time = infer(image_path, session, args.output_dir, args.viz)
            total_time_excluding_model_loading += processing_time
            image_count += 1

    if image_count > 0:
        average_time_excluding_model = total_time_excluding_model_loading / image_count
        average_time_including_model = (total_time_excluding_model_loading + model_loading_time) / image_count
        print(f"Average processing time per image (excluding model loading): {average_time_excluding_model:.2f} seconds")
        print(f"Average processing time per image (including model loading): {average_time_including_model:.2f} seconds")
    else:
        print("No images were processed.")

if __name__ == "__main__":
    main()
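
A side note on the output path: the per-image scaling to the full 0-65535 range in infer() discards any absolute scale, which is the main point of a metric model. If the exported model does emit depth in meters, a sketch for saving it as 16-bit millimeters instead (a common convention for indoor depth maps, capping at roughly 65.5 m) might look like:

import numpy as np
import cv2

def save_metric_depth_mm(depth_m: np.ndarray, output_path: str):
    # depth_m: depth map in meters; store as 16-bit millimeters
    depth_mm = np.clip(depth_m * 1000.0, 0, 65535).astype(np.uint16)
    cv2.imwrite(output_path, depth_mm)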
timmh commented 5 months ago

For anyone still interested in this, I've implemented working ONNX export for metric depth models here: https://github.com/timmh/Depth-Anything