axinc-ai / ailia-models

The collection of pre-trained, state-of-the-art AI models for ailia SDK

ADD japanese-reranker-cross-encoder-large-v1 #1443

Closed: kyakuno closed this issue 1 month ago

kyakuno commented 5 months ago

Model: https://huggingface.co/hotchpotch/japanese-reranker-cross-encoder-large-v1
How to use a reranker: https://note.com/npaka/n/n906b23636ac8?sub_rt=share_h
Reranker overview: https://secon.dev/entry/2024/04/02/070000-japanese-reranker-release/
About CrossEncoder: https://qiita.com/warper/items/fd84e740e62ad1a67703
License: MIT
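The links above describe a cross-encoder reranker. It reads the query and each candidate passage together and produces one relevance logit per pair. The candidates are then sorted by that score, and the export code later in this thread applies a sigmoid to the logit. Below is a minimal, model-free sketch of that flow.

It assumes a hypothetical `score_pair` callable standing in for the real model. `rerank`, `sigmoid`, and the toy scorer are illustrative names only, not part of ailia-models or the Hugging Face model card.

```python
import math
from typing import Callable, List, Tuple


def sigmoid(x: float) -> float:
    """Map a raw logit to a relevance score in [0, 1] without overflowing exp()."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def rerank(
    query: str,
    passages: List[str],
    score_pair: Callable[[str, str], float],
) -> List[Tuple[str, float]]:
    """Score each (query, passage) pair jointly and return passages sorted best-first.

    `score_pair` is a hypothetical stand-in for the cross-encoder: it must
    return the model's raw logit for one (query, passage) pair.
    """
    scored = [(p, sigmoid(score_pair(query, p))) for p in passages]
    # Sigmoid is monotonic, so sorting by score gives the same order as sorting by logit.
    return sorted(scored, key=lambda item: item[1], reverse=True)


if __name__ == "__main__":
    # Toy scorer for illustration only: counts distinct query characters found in the passage.
    def toy_logit(q: str, p: str) -> float:
        return float(sum(ch in p for ch in set(q))) - 2.0

    for passage, score in rerank("reranker", ["ranking models", "cooking recipes", "a reranker"], toy_logit):
        print(f"{score:.3f}  {passage}")
```

The sigmoid does not change the ranking. It only puts scores on a fixed 0 to 1 scale, which makes thresholding easier.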

ooe1123 commented 4 months ago
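import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL_NAME = "hotchpotch/japanese-reranker-cross-encoder-large-v1"


class Exp(nn.Module):
    """Wraps the cross-encoder so the exported graph returns sigmoid scores instead of raw logits."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.activation = nn.Sigmoid()

    def forward(self, input_ids, attention_mask, token_type_ids):
        inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }
        logits = self.model(**inputs).logits  # shape: (batch, 1)
        scores = self.activation(logits)  # relevance score in [0, 1]
        return scores


# Load the tokenizer and model, then tokenize a placeholder (query, passage) pair
# used only to trace the graph; any text works here.
# ("What is a reranker?", "A reranker is a model that reorders search results.")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
inputs = tokenizer(
    [("リランカーとは何ですか", "リランカーは検索結果を並べ替えるモデルです。")],
    padding=True,
    truncation=True,
    return_tensors="pt",
)

with torch.no_grad():
    print("------>")
    wrapper = Exp(model)
    # torch.autograd.Variable is deprecated; plain tensors can be passed directly.
    xx = (
        inputs["input_ids"],
        inputs["attention_mask"],
        inputs["token_type_ids"],
    )
    torch.onnx.export(
        wrapper,
        xx,
        "xxx.onnx",
        input_names=["input_ids", "attention_mask", "token_type_ids"],
        output_names=["scores"],
        # Batch size (dim 0) and sequence length (dim 1) may vary at inference time.
        dynamic_axes={
            "input_ids": [0, 1],
            "attention_mask": [0, 1],
            "token_type_ids": [0, 1],
            "scores": [0],
        },
        verbose=False,
        opset_version=17,
    )
    print("<------")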
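Once `xxx.onnx` has been exported, one way to consume it is through ONNX Runtime. The export marks dims 0 and 1 of the inputs as dynamic, so a whole batch of (query, passage) pairs can be scored in one call and then sorted.

This sketch makes the following assumptions:
- An `onnxruntime.InferenceSession`-like object whose `run(output_names, input_feed)` returns a list of arrays.
- The three inputs come from the same Hugging Face tokenizer, padded to a common length.
- `rank_passages` is a hypothetical helper, not an ailia-models API.

```python
from typing import Dict, List, Sequence, Tuple

import numpy as np

INPUT_NAMES = ("input_ids", "attention_mask", "token_type_ids")


def rank_passages(
    session,
    feeds: Dict[str, np.ndarray],
    passages: Sequence[str],
) -> List[Tuple[str, float]]:
    """Run the exported reranker on a batch of (query, passage) pairs and sort best-first.

    `session` is assumed to behave like onnxruntime.InferenceSession.
    `feeds` maps each name in INPUT_NAMES to an integer array of shape
    (batch, seq_len), one row per passage, in the same order as `passages`.
    """
    input_feed = {name: np.asarray(feeds[name]).astype(np.int64) for name in INPUT_NAMES}
    (scores,) = session.run(["scores"], input_feed)
    scores = np.asarray(scores).reshape(-1)  # (batch, 1) -> (batch,)
    order = np.argsort(-scores)  # highest relevance first
    return [(passages[i], float(scores[i])) for i in order]


# Possible usage (assumes onnxruntime and transformers are installed):
# import onnxruntime as ort
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("hotchpotch/japanese-reranker-cross-encoder-large-v1")
# session = ort.InferenceSession("xxx.onnx", providers=["CPUExecutionProvider"])
# feeds = tokenizer([(query, p) for p in passages], padding=True, truncation=True, return_tensors="np")
# ranked = rank_passages(session, feeds, passages)
```

Taking the session as a parameter keeps the ranking logic independent of the runtime. The same helper should also work with any object that exposes a compatible `run` method.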