LiheYoung / Depth-Anything

[CVPR 2024] Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data. Foundation Model for Monocular Depth Estimation
https://depth-anything.github.io
Apache License 2.0

Export a fine-tuned model to onnx #111

Open Choi-YeongJoon opened 6 months ago

Choi-YeongJoon commented 6 months ago

I fine-tuned the pretrained model (depth_anything_vits14.pth) on the KITTI dataset and obtained a fine-tuned checkpoint (depth_anything_finetune/ZoeDepthv1_05-Mar_14-45-9b59bd15407a_best.pt), and I confirmed that it infers metric depth quite accurately! How can I export this model to an ONNX file? I exported the pretrained model (depth_anything_vits14.pth) before, but the same approach doesn't work for the fine-tuned model.

import os
import torch
import torch.onnx

from depth_anything.dpt import DPT_DINOv2
from depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet

encoder = 'vits'
load_from = './checkpoints/fine_tune_small_test.pt'
# encoder = 'vitl'
# load_from = './checkpoints/ZoeDepthv1_05-Mar_14-45-9b59bd15407a_best.pt'
image_shape = (3, 518, 518)

# Initializing model
assert encoder in ['vits', 'vitb', 'vitl']
if encoder == 'vits':
    depth_anything = DPT_DINOv2(encoder='vits', features=64, out_channels=[48, 96, 192, 384], localhub='localhub')
elif encoder == 'vitb':
    depth_anything = DPT_DINOv2(encoder='vitb', features=128, out_channels=[96, 192, 384, 768], localhub='localhub')
else:
    depth_anything = DPT_DINOv2(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024], localhub='localhub')

total_params = sum(param.numel() for param in depth_anything.parameters())
print('Total parameters: {:.2f}M'.format(total_params / 1e6))

# Loading model weight
depth_anything.load_state_dict(torch.load(load_from, map_location='cpu'), strict=True)

depth_anything.eval()

# Define dummy input data
dummy_input = torch.ones(image_shape).unsqueeze(0)

# Run the dummy input through the model as a sanity check before exporting to ONNX
example_output = depth_anything(dummy_input)

onnx_path = os.path.splitext(os.path.basename(load_from))[0] + '.onnx'

# Export the PyTorch model to ONNX format
torch.onnx.export(depth_anything, dummy_input, onnx_path, opset_version=11, input_names=["input"], output_names=["output"], verbose=True)

print(f"Model exported to {onnx_path}")

The error message is below:

RuntimeError: Error(s) in loading state_dict for DPT_DINOv2: Missing key(s) in state_dict: "pretrained.cls_token", "pretrained.pos_embed", "pretrained.mask_token", "pretrained.patch_embed.proj.weight", "pretrained.patch_embed.proj.bias", "pretrained.blocks.0.norm1.weight", "pretrained.blocks.0.norm1.bias", "pretrained.blocks.0.attn.qkv.weight", "pretrained.blocks.0.attn.qkv.bias", "pretrained.blocks.0.attn.proj.weight", "pretrained.blocks.0.attn.proj.bias", "pretrained.blocks.0.ls1.gamma", "pretrained.blocks.0.norm2.weight", "pretrained.blocks.0.norm2.bias", "pretrained.blocks.0.mlp.fc1.weight", "pretrained.blocks.0.mlp.fc1.bias", "pretrained.blocks.0.mlp.fc2.weight", "pretrained.blocks.0.mlp.fc2.bias", "pretrained.blocks.0.ls2.gamma", "pretrained.blocks.1.norm1.weight", "pretrained.blocks.1.norm1.bias", "pretrained.blocks.1.attn.qkv.weight", "pretrained.blocks.1.attn.qkv.bias", "pretrained.blocks.1.attn.proj.weight", "pretrained.blocks.1.attn.proj.bias", "pretrained.blocks.1.ls1.gamma", "pretrained.blocks.1.norm2.weight", "pretrained.blocks.1.norm2.bias", "pretrained.blocks.1.mlp.fc1.weight", "pretrained.blocks.1.mlp.fc1.bias", "pretrained.blocks.1.mlp.fc2.weight", "pretrained.blocks.1.mlp.fc2.bias", "pretrained.blocks.1.ls2.gamma", "pretrained.blocks.2.norm1.weight", "pretrained.blocks.2.norm1.bias", "pretrained.blocks.2.attn.qkv.weight", "pretrained.blocks.2.attn.qkv.bias", "pretrained.blocks.2.attn.proj.weight", "pretrained.blocks.2.attn.proj.bias", "pretrained.blocks.2.ls1.gamma", "pretrained.blocks.2.norm2.weight", "pretrained.blocks.2.norm2.bias", "pretrained.blocks.2.mlp.fc1.weight", "pretrained.blocks.2.mlp.fc1.bias", "pretrained.blocks.2.mlp.fc2.weight", "pretrained.blocks.2.mlp.fc2.bias", "pretrained.blocks.2.ls2.gamma", "pretrained.blocks.3.norm1.weight", "pretrained.blocks.3.norm1.bias", "pretrained.blocks.3.attn.qkv.weight", "pretrained.blocks.3.attn.qkv.bias", "pretrained.blocks.3.attn.proj.weight", "pretrained.blocks.3.attn.proj.bias", "pretrained.blocks.3.ls1.gamma", "pretrained.blocks.3.norm2.weight", "pretrained.blocks.3.norm2.bias", "pretrained.blocks.3.mlp.fc1.weight", "pretrained.blocks.3.mlp.fc1.bias", "pretrained.blocks.3.mlp.fc2.weight", "pretrained.blocks.3.mlp.fc2.bias", "pretrained.blocks.3.ls2.gamma", "pretrained.blocks.4.norm1.weight", "pretrained.blocks.4.norm1.bias", "pretrained.blocks.4.attn.qkv.weight", "pretrained.blocks.4.attn.qkv.bias", "pretrained.blocks.4.attn.proj.weight", "pretrained.blocks.4.attn.proj.bias", "pretrained.blocks.4.ls1.gamma", "pretrained.blocks.4.norm2.weight", "pretrained.blocks.4.norm2.bias", "pretrained.blocks.4.mlp.fc1.weight", "pretrained.blocks.4.mlp.fc1.bias", "pretrained.blocks.4.mlp.fc2.weight", "pretrained.blocks.4.mlp.fc2.bias", "pretrained.blocks.4.ls2.gamma", "pretrained.blocks.5.norm1.weight", "pretrained.blocks.5.norm1.bias", "pretrained.blocks.5.attn.qkv.weight", "pretrained.blocks.5.attn.qkv.bias", "pretrained.blocks.5.attn.proj.weight", "pretrained.blocks.5.attn.proj.bias", "pretrained.blocks.5.ls1.gamma", "pretrained.blocks.5.norm2.weight", "pretrained.blocks.5.norm2.bias", "pretrained.blocks.5.mlp.fc1.weight", "pretrained.blocks.5.mlp.fc1.bias", "pretrained.blocks.5.mlp.fc2.weight", "pretrained.blocks.5.mlp.fc2.bias", "pretrained.blocks.5.ls2.gamma", "pretrained.blocks.6.norm1.weight", "pretrained.blocks.6.norm1.bias", "pretrained.blocks.6.attn.qkv.weight", "pretrained.blocks.6.attn.qkv.bias", "pretrained.blocks.6.attn.proj.weight", "pretrained.blocks.6.attn.proj.bias", "pretrained.blocks.6.ls1.gamma", 
"pretrained.blocks.6.norm2.weight", "pretrained.blocks.6.norm2.bias", "pretrained.blocks.6.mlp.fc1.weight", "pretrained.blocks.6.mlp.fc1.bias", "pretrained.blocks.6.mlp.fc2.weight", "pretrained.blocks.6.mlp.fc2.bias", "pretrained.blocks.6.ls2.gamma", "pretrained.blocks.7.norm1.weight", "pretrained.blocks.7.norm1.bias", "pretrained.blocks.7.attn.qkv.weight", "pretrained.blocks.7.attn.qkv.bias", "pretrained.blocks.7.attn.proj.weight", "pretrained.blocks.7.attn.proj.bias", "pretrained.blocks.7.ls1.gamma", "pretrained.blocks.7.norm2.weight", "pretrained.blocks.7.norm2.bias", "pretrained.blocks.7.mlp.fc1.weight", "pretrained.blocks.7.mlp.fc1.bias", "pretrained.blocks.7.mlp.fc2.weight", "pretrained.blocks.7.mlp.fc2.bias", "pretrained.blocks.7.ls2.gamma", "pretrained.blocks.8.norm1.weight", "pretrained.blocks.8.norm1.bias", "pretrained.blocks.8.attn.qkv.weight", "pretrained.blocks.8.attn.qkv.bias", "pretrained.blocks.8.attn.proj.weight", "pretrained.blocks.8.attn.proj.bias", "pretrained.blocks.8.ls1.gamma", "pretrained.blocks.8.norm2.weight", "pretrained.blocks.8.norm2.bias", "pretrained.blocks.8.mlp.fc1.weight", "pretrained.blocks.8.mlp.fc1.bias", "pretrained.blocks.8.mlp.fc2.weight", "pretrained.blocks.8.mlp.fc2.bias", "pretrained.blocks.8.ls2.gamma", "pretrained.blocks.9.norm1.weight", "pretrained.blocks.9.norm1.bias", "pretrained.blocks.9.attn.qkv.weight", "pretrained.blocks.9.attn.qkv.bias", "pretrained.blocks.9.attn.proj.weight", "pretrained.blocks.9.attn.proj.bias", "pretrained.blocks.9.ls1.gamma", "pretrained.blocks.9.norm2.weight", "pretrained.blocks.9.norm2.bias", "pretrained.blocks.9.mlp.fc1.weight", "pretrained.blocks.9.mlp.fc1.bias", "pretrained.blocks.9.mlp.fc2.weight", "pretrained.blocks.9.mlp.fc2.bias", "pretrained.blocks.9.ls2.gamma", "pretrained.blocks.10.norm1.weight", "pretrained.blocks.10.norm1.bias", "pretrained.blocks.10.attn.qkv.weight", "pretrained.blocks.10.attn.qkv.bias", "pretrained.blocks.10.attn.proj.weight", "pretrained.blocks.10.attn.proj.bias", "pretrained.blocks.10.ls1.gamma", "pretrained.blocks.10.norm2.weight", "pretrained.blocks.10.norm2.bias", "pretrained.blocks.10.mlp.fc1.weight", "pretrained.blocks.10.mlp.fc1.bias", "pretrained.blocks.10.mlp.fc2.weight", "pretrained.blocks.10.mlp.fc2.bias", "pretrained.blocks.10.ls2.gamma", "pretrained.blocks.11.norm1.weight", "pretrained.blocks.11.norm1.bias", "pretrained.blocks.11.attn.qkv.weight", "pretrained.blocks.11.attn.qkv.bias", "pretrained.blocks.11.attn.proj.weight", "pretrained.blocks.11.attn.proj.bias", "pretrained.blocks.11.ls1.gamma", "pretrained.blocks.11.norm2.weight", "pretrained.blocks.11.norm2.bias", "pretrained.blocks.11.mlp.fc1.weight", "pretrained.blocks.11.mlp.fc1.bias", "pretrained.blocks.11.mlp.fc2.weight", "pretrained.blocks.11.mlp.fc2.bias", "pretrained.blocks.11.ls2.gamma", "pretrained.norm.weight", "pretrained.norm.bias", "depth_head.projects.0.weight", "depth_head.projects.0.bias", "depth_head.projects.1.weight", "depth_head.projects.1.bias", "depth_head.projects.2.weight", "depth_head.projects.2.bias", "depth_head.projects.3.weight", "depth_head.projects.3.bias", "depth_head.resize_layers.0.weight", "depth_head.resize_layers.0.bias", "depth_head.resize_layers.1.weight", "depth_head.resize_layers.1.bias", "depth_head.resize_layers.3.weight", "depth_head.resize_layers.3.bias", "depth_head.scratch.layer1_rn.weight", "depth_head.scratch.layer2_rn.weight", "depth_head.scratch.layer3_rn.weight", "depth_head.scratch.layer4_rn.weight", "depth_head.scratch.refinenet1.out_conv.weight", 
"depth_head.scratch.refinenet1.out_conv.bias", "depth_head.scratch.refinenet1.resConfUnit1.conv1.weight", "depth_head.scratch.refinenet1.resConfUnit1.conv1.bias", "depth_head.scratch.refinenet1.resConfUnit1.conv2.weight", "depth_head.scratch.refinenet1.resConfUnit1.conv2.bias", "depth_head.scratch.refinenet1.resConfUnit2.conv1.weight", "depth_head.scratch.refinenet1.resConfUnit2.conv1.bias", "depth_head.scratch.refinenet1.resConfUnit2.conv2.weight", "depth_head.scratch.refinenet1.resConfUnit2.conv2.bias", "depth_head.scratch.refinenet2.out_conv.weight", "depth_head.scratch.refinenet2.out_conv.bias", "depth_head.scratch.refinenet2.resConfUnit1.conv1.weight", "depth_head.scratch.refinenet2.resConfUnit1.conv1.bias", "depth_head.scratch.refinenet2.resConfUnit1.conv2.weight", "depth_head.scratch.refinenet2.resConfUnit1.conv2.bias", "depth_head.scratch.refinenet2.resConfUnit2.conv1.weight", "depth_head.scratch.refinenet2.resConfUnit2.conv1.bias", "depth_head.scratch.refinenet2.resConfUnit2.conv2.weight", "depth_head.scratch.refinenet2.resConfUnit2.conv2.bias", "depth_head.scratch.refinenet3.out_conv.weight", "depth_head.scratch.refinenet3.out_conv.bias", "depth_head.scratch.refinenet3.resConfUnit1.conv1.weight", "depth_head.scratch.refinenet3.resConfUnit1.conv1.bias", "depth_head.scratch.refinenet3.resConfUnit1.conv2.weight", "depth_head.scratch.refinenet3.resConfUnit1.conv2.bias", "depth_head.scratch.refinenet3.resConfUnit2.conv1.weight", "depth_head.scratch.refinenet3.resConfUnit2.conv1.bias", "depth_head.scratch.refinenet3.resConfUnit2.conv2.weight", "depth_head.scratch.refinenet3.resConfUnit2.conv2.bias", "depth_head.scratch.refinenet4.out_conv.weight", "depth_head.scratch.refinenet4.out_conv.bias", "depth_head.scratch.refinenet4.resConfUnit1.conv1.weight", "depth_head.scratch.refinenet4.resConfUnit1.conv1.bias", "depth_head.scratch.refinenet4.resConfUnit1.conv2.weight", "depth_head.scratch.refinenet4.resConfUnit1.conv2.bias", "depth_head.scratch.refinenet4.resConfUnit2.conv1.weight", "depth_head.scratch.refinenet4.resConfUnit2.conv1.bias", "depth_head.scratch.refinenet4.resConfUnit2.conv2.weight", "depth_head.scratch.refinenet4.resConfUnit2.conv2.bias", "depth_head.scratch.output_conv1.weight", "depth_head.scratch.output_conv1.bias", "depth_head.scratch.output_conv2.0.weight", "depth_head.scratch.output_conv2.0.bias", "depth_head.scratch.output_conv2.2.weight", "depth_head.scratch.output_conv2.2.bias". Unexpected key(s) in state_dict: "model", "optimizer", "epoch".

I expected fine-tuning to produce a Depth Anything model, but upon checking, it appears that a ZoeDepth model is being trained. Is it possible to get absolute (metric) depth output from Depth Anything itself? And is there a script that can run inference with a fine-tuned model?

If anyone has succeeded with this, please help me.
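For reference, the unexpected keys in the traceback ("model", "optimizer", "epoch") suggest the .pt file is a full training checkpoint rather than a bare state dict. Below is a minimal sketch of extracting the nested weights, assuming a ZoeDepth-style checkpoint layout; the prefix handling is illustrative, not the repository's official loading code.

import torch

# Assumption: the fine-tuned .pt file is a training checkpoint with the weights
# nested under the "model" key (alongside "optimizer" and "epoch").
ckpt = torch.load(load_from, map_location='cpu')
state_dict = ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt

# If training used DataParallel/DDP, keys may carry a "module." prefix.
state_dict = {k[len('module.'):] if k.startswith('module.') else k: v
              for k, v in state_dict.items()}

# A ZoeDepth fine-tune wraps the Depth Anything backbone, so these keys will
# generally not match DPT_DINOv2 one-to-one; they belong to the ZoeDepth model
# used for fine-tuning rather than the bare DPT_DINOv2.
missing, unexpected = depth_anything.load_state_dict(state_dict, strict=False)
print(f'missing: {len(missing)}, unexpected: {len(unexpected)}')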

1ssb commented 6 months ago

The Depth Anything metric-depth model is different from the regular (relative-depth) Depth Anything model. Make sure you are downloading and fine-tuning the correct version. I suggest going through the ZoeDepth training instructions.

Choi-YeongJoon commented 6 months ago

Thank you for your advice! I followed it, went through the ZoeDepth guide, and finally succeeded in exporting to ONNX. I am now trying to parse it with TensorRT. I'll let you know how it goes!

1ssb commented 6 months ago

TensorRT has some pitfalls, including numerical precision. Be careful about how the ONNX model carries over the precision of your PyTorch model, which I am guessing was trained in float32 or float64.
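For context, a minimal sketch of building a TensorRT engine from the exported ONNX file with the TensorRT Python API (TRT 8.x style; the file names are placeholders, and FP16 is left disabled until accuracy has been checked against the PyTorch model):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
# ONNX parsing requires an explicit-batch network definition.
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)

with open('fine_tune_small_test.onnx', 'rb') as f:  # placeholder path
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError('Failed to parse the ONNX model')

config = builder.create_builder_config()
config.max_workspace_size = 1 << 30  # 1 GiB; workspace API as used in TRT 8.2
# config.set_flag(trt.BuilderFlag.FP16)  # enable only after verifying accuracy

engine = builder.build_serialized_network(network, config)
with open('fine_tune_small_test.engine', 'wb') as f:
    f.write(engine)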

Choi-YeongJoon commented 6 months ago

Thank you for the thoughtful response. Given my requirements, I think I can use sufficiently high precision. However, there are environmental constraints: I am limited to TensorRT 8.2, and some operations appear to be unsupported in that version. I'm going to implement a custom operation and parse the model with it. Do you have any experience with this?

1ssb commented 6 months ago

No, I'm afraid I have no expertise with this. But you should check out running the ONNX model on TensorRT through ONNX Runtime's TensorRT execution provider.
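A minimal sketch of that route with ONNX Runtime (the model path is a placeholder; ONNX Runtime falls back to the next provider in the list if TensorRT is unavailable):

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(
    'fine_tune_small_test.onnx',  # placeholder path
    providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'],
)

# Same input layout as the export script: 1 x 3 x 518 x 518, float32.
dummy = np.ones((1, 3, 518, 518), dtype=np.float32)
depth = session.run(None, {'input': dummy})[0]
print(depth.shape)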

Choi-YeongJoon commented 6 months ago

The only unsupported operation is cubic interpolation. It is supported in TensorRT 8.6 or higher, but not in earlier versions. If I resolve it, I will share the solution.
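One possible workaround, if bilinear resizing is acceptable for the application, is to monkeypatch torch.nn.functional.interpolate during export so the resulting Resize ops avoid cubic mode. This is only a sketch: it assumes the model reaches interpolate through the module attribute (not a direct import) and passes mode as a keyword, and it can change the output slightly.

import torch.nn.functional as F

_orig_interpolate = F.interpolate

def _no_bicubic_interpolate(*args, **kwargs):
    # Downgrade bicubic to bilinear so the exported Resize ops run on TRT < 8.6.
    if kwargs.get('mode') == 'bicubic':
        kwargs['mode'] = 'bilinear'
    return _orig_interpolate(*args, **kwargs)

F.interpolate = _no_bicubic_interpolate
# ... call torch.onnx.export(...) here ...
F.interpolate = _orig_interpolate  # restore the original afterwards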

1ssb commented 6 months ago

That actually is good enough.

alifoladi3333 commented 5 months ago

Hi, can you tell me how to download the dataset and how to fine-tune Depth Anything for metric depth estimation? The instructions in this repository aren't clear to me :(

HaosenZ commented 5 months ago

Hi, I'm glad you solved the problem. I also trained on my own dataset and obtained fine-tuned model files, but I don't know how to use them to predict depth maps for images. Do you know how to do this? Thank you.