Status: closed (closed by kyakuno 1 month ago)
class Exp(nn.Module):
    """Wrap a model so that its forward pass returns sigmoid-activated logits.

    The wrapped ``model`` is called with ``input_ids``, ``attention_mask``
    and ``token_type_ids`` as keyword arguments. Its return value must expose
    a ``.logits`` attribute; presumably this is a Hugging Face
    sequence-classification model (TODO confirm against the caller).
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.activation = Sigmoid()

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Run the wrapped model and squash its logits through a sigmoid.

        Returns:
            ``sigmoid(model(...).logits)``, with the same shape as the logits.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return self.activation(outputs.logits)
# Manual switch for the one-off ONNX export below (replaces a bare `if 1:`).
EXPORT_ONNX = True
ONNX_PATH = "xxx.onnx"

if EXPORT_ONNX:
    # Export only traces a forward pass, so autograd bookkeeping is not needed.
    with torch.no_grad():
        print("------>")
        # Bind the wrapper to a new name instead of rebinding `model`.
        # Re-running this block would otherwise compute Exp(Exp(model)),
        # whose outer forward fails on `.logits` of a plain tensor.
        onnx_model = Exp(model)
        # Plain tensors are passed directly. torch.autograd.Variable has been
        # a deprecated no-op wrapper since PyTorch 0.4.
        example_args = (
            inputs["input_ids"],
            inputs["attention_mask"],
            inputs["token_type_ids"],
        )
        torch.onnx.export(
            onnx_model,
            example_args,
            ONNX_PATH,
            input_names=["input_ids", "attention_mask", "token_type_ids"],
            output_names=["scores"],
            # Inputs: axes 0 and 1 are dynamic (presumably batch and sequence
            # length; TODO confirm). Output "scores": only axis 0 is dynamic.
            dynamic_axes={
                "input_ids": [0, 1],
                "attention_mask": [0, 1],
                "token_type_ids": [0, 1],
                "scores": [0],
            },
            verbose=False,
            opset_version=17,
        )
        print("<------")
    # Intentionally stop after exporting, as the original `1 / 0` did.
    # SystemExit with a message still exits with status 1, but it does not
    # print a misleading ZeroDivisionError traceback.
    raise SystemExit(f"ONNX export finished: {ONNX_PATH}")
モデル: https://huggingface.co/hotchpotch/japanese-reranker-cross-encoder-large-v1
Rerankerの使い方: https://note.com/npaka/n/n906b23636ac8?sub_rt=share_h
Rerankerの概要: https://secon.dev/entry/2024/04/02/070000-japanese-reranker-release/
CrossEncoderについて: https://qiita.com/warper/items/fd84e740e62ad1a67703
ライセンス: MIT