# GitHub issue (status: open), opened by kagami-tsukimura 2 weeks ago.
import torch
from transformers import DetrForObjectDetection

# Export a fine-tuned DETR checkpoint to ONNX.
# Build the architecture from the pretrained hub config, then overwrite its
# weights with the fine-tuned state dict saved earlier as detr_model.pth.
model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-50')
# map_location="cpu" lets a checkpoint saved on a GPU machine load on a
# CPU-only machine (export itself runs on CPU here).
model.load_state_dict(torch.load("detr_model.pth", map_location="cpu"))
model.eval()

# The DETR preprocessor resizes images before inference (for this hub config,
# presumably to a shortest edge of 800 px; confirm in preprocessor_config.json).
# Real inputs are therefore larger than 224x224 and vary per image.
# Trace with a representative size and mark batch/height/width as dynamic.
# Otherwise the exported graph only accepts (1, 3, 224, 224), and ONNX Runtime
# rejects real inputs with an "invalid dimensions" error.
dummy_input = torch.randn(1, 3, 800, 800)
with torch.no_grad():
    torch.onnx.export(
        model,
        dummy_input,
        "detr_model.onnx",
        # transformers' own DETR ONNX config uses opset 12 as its default.
        opset_version=12,
        do_constant_folding=True,
        input_names=['input'],
        # The model's first two outputs are the class logits and the predicted
        # boxes. Name both so consumers don't have to guess which unnamed
        # output is which. Any further outputs keep auto-generated names.
        output_names=['logits', 'pred_boxes'],
        dynamic_axes={
            'input': {0: 'batch', 2: 'height', 3: 'width'},
            'logits': {0: 'batch'},
            'pred_boxes': {0: 'batch'},
        },
    )
import onnxruntime as ort
import numpy as np
from PIL import Image
from transformers import DetrFeatureExtractor

# Run the exported DETR graph with ONNX Runtime on a single image.
ort_session = ort.InferenceSession("detr_model.onnx")
# Ask the session for its input name instead of hard-coding it, so this stays
# in sync with whatever input_names the export step used.
input_name = ort_session.get_inputs()[0].name

# Open the image in a context manager so the file handle is released.
# Force 3-channel RGB, because grayscale, RGBA or palette images would
# otherwise give the model the wrong number of channels.
with Image.open("path_to_image.jpg") as img:
    image = img.convert("RGB")
feature_extractor = DetrFeatureExtractor.from_pretrained('facebook/detr-resnet-50')
inputs = feature_extractor(images=image, return_tensors="np")

# NOTE: the feature extractor resizes the image, so the ONNX model must have
# been exported with dynamic height/width axes. A graph fixed at 224x224 will
# be rejected by ONNX Runtime for these inputs.
# ONNX Runtime is strict about dtypes, so hand it float32 explicitly.
pixel_values = np.asarray(inputs['pixel_values'], dtype=np.float32)
outputs = ort_session.run(None, {input_name: pixel_values})

# First graph output: per-query class scores.
logits = outputs[0]
# Second graph output: per-query boxes. Presumably normalized
# (center_x, center_y, width, height); confirm before post-processing.
pred_boxes = outputs[1]
import torch

# Example of a trained model.
# Replace the placeholder below with your trained DETR model instance.
model = ...  # obtain the trained DETR model here

# Save the model weights. This writes the detr_model.pth checkpoint that the
# ONNX export step at the top of this file loads, so run this step first.
# Fail with a clear message instead of "'ellipsis' object has no attribute
# 'state_dict'" if the placeholder was never replaced.
if model is ...:
    raise NotImplementedError(
        "Replace `model = ...` with your trained DETR model before saving."
    )
torch.save(model.state_dict(), "detr_model.pth")