kagami-tsukimura / pytorch-deeplearning

0 stars 0 forks source link

onnx #77

Open kagami-tsukimura opened 2 weeks ago

kagami-tsukimura commented 2 weeks ago

import torch

学習済みモデルの例

model = ... # 学習済みのDETRモデルを取得

モデルの保存

torch.save(model.state_dict(), "detr_model.pth")

kagami-tsukimura commented 2 weeks ago

import torch from transformers import DetrForObjectDetection

モデルの定義

model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-50')

学習済みの重みをロード

model.load_state_dict(torch.load("detr_model.pth"))

推論モードに切り替え

model.eval()

ダミー入力(バッチサイズ1、チャネル数3、高さ224、幅224)

dummy_input = torch.randn(1, 3, 224, 224)

モデルをONNX形式でエクスポート

torch.onnx.export(model, dummy_input, "detr_model.onnx", opset_version=11, do_constant_folding=True, input_names=['input'], output_names=['output'])

kagami-tsukimura commented 2 weeks ago

import onnxruntime as ort import numpy as np from PIL import Image from transformers import DetrFeatureExtractor

ONNXモデルのロード

ort_session = ort.InferenceSession("detr_model.onnx")

画像の前処理

image = Image.open("path_to_image.jpg") feature_extractor = DetrFeatureExtractor.from_pretrained('facebook/detr-resnet-50') inputs = feature_extractor(images=image, return_tensors="np")

推論

outputs = ort_session.run(None, {'input': inputs['pixel_values']}) logits = outputs[0]

結果の処理(例えば、ボックスやラベルの抽出)