facebookresearch / dinov2

PyTorch code and models for the DINOv2 self-supervised learning method.
Apache License 2.0

[request] Additional model exports (ONNX, CoreML, ...) #167

Open patricklabatut opened 1 year ago

patricklabatut commented 1 year ago

Related:

barbolo commented 6 months ago

This is how I've used transformers to export dinov2 outputs with class token + patch tokens for ONNX and OpenVINO.

import torch
from transformers import Dinov2Model

image_width = 224
image_height = 224
model_size = 'small' # small, base, large, giant

class Wrapper(torch.nn.Module):
    def __init__(self, dinov2_model):
        super().__init__()
        self.dinov2_model = dinov2_model
    def forward(self, tensor):
        return self.dinov2_model(tensor).last_hidden_state

dummy_input = torch.rand([1, 3, image_height, image_width]).to('cpu')

dinov2_model = Dinov2Model.from_pretrained(f'facebook/dinov2-{model_size}')
model = Wrapper(dinov2_model).to('cpu')

torch.onnx.export(model, dummy_input, f'dinov2-{model_size}.onnx')

Once you have the ONNX model (e.g. dinov2-small.onnx), you can convert it to OpenVINO with fp16 weights using the ovc CLI:

ovc dinov2-small.onnx --output_model openvino/dinov2-small --compress_to_fp16=True
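
The same conversion can probably be done from Python with `ov.convert_model` and `ov.save_model`, the two calls used in the conversion script later in this thread. This is a minimal sketch, assuming a recent `openvino` package; the function name `onnx_to_openvino` and its default paths are illustrative only and not part of any library.

```python
from pathlib import Path


def onnx_to_openvino(onnx_path="dinov2-small.onnx",
                     output_xml="openvino/dinov2-small.xml"):
    """Convert an ONNX file to OpenVINO IR with fp16-compressed weights.

    Hypothetical helper: the name and default paths are assumptions.
    """
    # Imported lazily so this module can be loaded without OpenVINO installed.
    import openvino as ov

    # Create the output directory up front rather than relying on save_model.
    Path(output_xml).parent.mkdir(parents=True, exist_ok=True)

    ov_model = ov.convert_model(onnx_path)
    ov.save_model(ov_model, output_xml, compress_to_fp16=True)
    return output_xml
```

Calling `onnx_to_openvino()` should write `openvino/dinov2-small.xml` and its `.bin` weights file, which is the path the OpenVINO usage snippet reads.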

This is how to use it after conversion:

# onnx
import onnxruntime as ort
import numpy as np
session = ort.InferenceSession('dinov2-small.onnx')
model_inputs = session.get_inputs()
input_shape = model_inputs[0].shape
input_height = input_shape[2]
input_width = input_shape[3]
dummy_input = np.random.rand(1, 3, input_height, input_width).astype(np.float32)
outputs = session.run(None, { model_inputs[0].name: dummy_input })
classtoken = outputs[0][0][0]
patchtokens = outputs[0][0][1:]

# openvino
from openvino.runtime import Core
import numpy as np
core = Core()
model = core.read_model(model='openvino/dinov2-small.xml')
compiled_model = core.compile_model(model=model, device_name="CPU")
input_height = compiled_model.input(0).shape[2]
input_width = compiled_model.input(0).shape[3]
dummy_input = np.random.rand(1, 3, input_height, input_width).astype(np.float32)
outputs = compiled_model(dummy_input)
classtoken = outputs[0][0][0]
patchtokens = outputs[0][0][1:]
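
The patch tokens come back as a flat sequence, one token per 14x14 patch in row-major order. For dense tasks it is often handy to put them back on the spatial grid. This is a minimal sketch, assuming a model without register tokens, a patch size of 14 (the "14" in names such as dinov2_vits14) and an embedding width of 384 for the small model; `split_tokens` is a hypothetical helper, and the random array stands in for `outputs[0][0]`.

```python
import numpy as np

PATCH_SIZE = 14  # assumption: the "14" in model names such as dinov2_vits14


def split_tokens(tokens, input_height, input_width, patch_size=PATCH_SIZE):
    """Split (1 + N, D) tokens into a class token and an (H/p, W/p, D) patch grid."""
    grid_h = input_height // patch_size
    grid_w = input_width // patch_size
    expected = 1 + grid_h * grid_w
    if tokens.shape[0] != expected:
        raise ValueError(f"expected {expected} tokens, got {tokens.shape[0]}")
    class_token = tokens[0]
    # Patches are laid out row by row, so a plain reshape restores the grid.
    patch_grid = tokens[1:].reshape(grid_h, grid_w, -1)
    return class_token, patch_grid


# Stand-in for outputs[0][0] of a 224x224 dinov2-small run: 257 tokens of width 384.
fake_tokens = np.random.rand(257, 384).astype(np.float32)
class_token, patch_grid = split_tokens(fake_tokens, 224, 224)
print(class_token.shape, patch_grid.shape)
```

With a 224x224 input this gives a 16x16 grid, so `patch_grid[row, col]` is the feature of the patch at that grid position.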

barbolo commented 6 months ago

Before feeding an image to the model, you should preprocess it. I wrote this function to do that:

import cv2 as cv
import numpy as np

# https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/data/transforms.py#L75-L91
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
MEAN = np.array(IMAGENET_DEFAULT_MEAN)
STD = np.array(IMAGENET_DEFAULT_STD)
def preprocess(img):
    # input_width and input_height are the globals read from the model in the usage snippet above.
    if isinstance(img, np.ndarray):
        # OpenCV image (BGR channel order)
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
    else:
        # PIL Image (assumed to be RGB)
        img = np.array(img)
    img = cv.resize(img, (input_width, input_height), interpolation=cv.INTER_CUBIC)
    img = img / 255.0 # from [0, 255] to [0.0, 1.0]
    img = np.transpose(img, (2, 0, 1)) # from shape (H x W x C) to (C x H x W)
    img = (img - MEAN[:, None, None]) / STD[:, None, None] # transforms.Normalize(mean=MEAN, std=STD)
    img = img.astype(np.float32)
    return img # shape (C x H x W); add a batch axis (img[None]) before running the model

ahmed1996said commented 1 month ago

Thanks @barbolo! Have you had a chance to convert a model you've trained to ONNX? I'm having issues aligning state dict keys with the HF model, as the architecture seems to differ a bit.

barbolo commented 1 month ago

@ahmed1996said I didn't train any DINOv2 model. I've converted pretrained transformers/torch models to ONNX/OpenVINO.

I'm currently using this script to convert DINOv2 models (the variants without extra registers) from torch to OpenVINO:

import openvino as ov
import os
import torch

OPENVINO_SIZES = [112, 140, 392]

# xformers graph failing in OpenVINO.
os.environ['XFORMERS_DISABLED'] = '1'
torch.set_default_device('cpu')

dinov2_dir = os.path.join(os.path.dirname(__file__), '../downloads/dinov2')
os.makedirs(dinov2_dir, exist_ok=True)
torch.hub.set_dir(dinov2_dir)

# Wrapper for extracting all features from a torch dinov2 model
class DinoV2Features(torch.nn.Module):
    def __init__(self, dinov2_model):
        super().__init__()
        self.dinov2_model = dinov2_model
    def forward(self, tensor_input):
        features = self.dinov2_model.forward_features(tensor_input)
        return torch.cat([
            features['x_norm_clstoken'].unsqueeze(1),
            features['x_norm_patchtokens'],
        ], dim=1)

dinov2_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
model = DinoV2Features(dinov2_model)

for size in OPENVINO_SIZES:
    dummy_input = torch.rand([1, 3, size, size]).to('cpu')
    ov_model = ov.convert_model(model, example_input=dummy_input, input=dummy_input.shape)
    ov.save_model(ov_model, f"{dinov2_dir}/openvino/{size}.xml", compress_to_fp16=True)
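
Each exported file has a fixed input size, so its output length is fixed too: the wrapper returns one class token followed by one token per 14x14 patch. A quick way to sanity-check a converted model is to compare its output shape against the expected token count. This is a minimal sketch, assuming a patch size of 14 (from the name dinov2_vits14) and input sizes that are multiples of it; `expected_tokens` is a hypothetical helper.

```python
PATCH_SIZE = 14  # assumption: taken from the model name dinov2_vits14
OPENVINO_SIZES = [112, 140, 392]


def expected_tokens(size, patch_size=PATCH_SIZE):
    """Number of output tokens (1 class token + patches) for a square input."""
    if size % patch_size != 0:
        raise ValueError(f"{size} is not a multiple of {patch_size}")
    per_side = size // patch_size
    return 1 + per_side * per_side


for size in OPENVINO_SIZES:
    # 112 -> 65 tokens, 140 -> 101 tokens, 392 -> 785 tokens
    print(size, expected_tokens(size))
```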

ahmed1996said commented 1 month ago

@barbolo Thanks for your prompt response, I'm now able to load my training weights directly into the torch-hub model! How did you disable the flash attention though? I'm running into errors when I try to convert it to ONNX.

barbolo commented 1 month ago

@ahmed1996said I'm pretty sure this was the fix for the flash attention:

# xformers graph failing in OpenVINO.
os.environ['XFORMERS_DISABLED'] = '1'
torch.set_default_device('cpu')

XFORMERS_DISABLED=1 needs to be set before loading the dinov2 code/model from torch hub (I set this variable inside the conversion script, but you can set it in your execution environment if you prefer). It changes the model graph, including the type of attention layers.

https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/layers/attention.py#L21-L33
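
The ordering is the part that is easy to get wrong, so here is a minimal sketch of it. It assumes the flag is read once, when the dinov2 code is first imported in the process, which is how the linked attention.py appears to work; `load_dinov2` is a hypothetical helper name.

```python
import os

# Set the flag at the very top of the script, before any dinov2 code is imported.
# Assumption: the linked attention.py reads it once at import time.
os.environ["XFORMERS_DISABLED"] = "1"


def load_dinov2(name="dinov2_vits14"):
    """Load a hub model once the environment variable is in place (hypothetical helper)."""
    # Deferred import keeps this module importable on machines without torch.
    import torch

    torch.set_default_device("cpu")
    return torch.hub.load("facebookresearch/dinov2", name)
```

If the dinov2 modules were already imported earlier in the same process, setting the variable afterwards is unlikely to have any effect, so a fresh interpreter is the safer choice.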

ahmed1996said commented 1 month ago

Ah perfect, thanks a lot @barbolo. It's working fine now :)