naver / splade

SPLADE: sparse neural search (SIGIR21, SIGIR22)

Tutorial to export a SPLADE model to ONNX #47

Open ntnq4 opened 10 months ago

ntnq4 commented 10 months ago

Hello,

I recently trained my own SPLADE model. To reduce inference time, I tried to export it to ONNX with torch.onnx.export(), but I ran into a few errors.

Is there a tutorial somewhere for this conversion?

thibault-formal commented 9 months ago

Hi @ntnq4

Not that I am aware of. I am not super familiar with ONNX - did you manage to make it work?

ntnq4 commented 9 months ago

Hi @thibault-formal

I didn't manage to make it work unfortunately... I tried this tutorial but it didn't work for my SPLADE model.

I also found this recent paper that mentioned this conversion.

risan-raja commented 8 months ago

Hi @ntnq4, I have managed to convert the SPLADE models to ONNX, although I used the pretrained checkpoint rather than a custom-trained model. I know that is not exactly your situation, but I hope it helps anyway. To reproduce:

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

MODEL_NAME = "naver/splade-cocondenser-ensembledistil"

class TransformerRep(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # torchscript=True makes the model return plain tuples, which tracing needs
        self.model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, torchscript=True)
        self.model.eval()
        self.fp16 = True

    def encode(self, input_ids, token_type_ids, attention_mask):
        # Pass by keyword: the BERT forward() takes attention_mask before
        # token_type_ids, so positional arguments in this order would be swapped.
        out = self.model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        return out[0]  # MLM logits, shape (batch, seq_len, vocab)

class SpladeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = TransformerRep()
        self.agg = "max"
        self.model.eval()

    def forward(self, input_ids, token_type_ids, attention_mask):
        with torch.cuda.amp.autocast():  # mixed precision when running on a CUDA device
            with torch.no_grad():
                lm_logits = self.model.encode(input_ids, token_type_ids, attention_mask)
                # SPLADE max pooling over the sequence: (batch, vocab)
                vec, _ = torch.max(torch.log(1 + torch.relu(lm_logits)) * attention_mask.unsqueeze(-1), dim=1)
                vec = vec[0]  # this wrapper handles a single text at a time
                indices = vec.nonzero().squeeze(-1)  # vocabulary ids with non-zero weight
                weights = vec[indices]
        return indices, weights

# Convert the model to TorchScript
model = SpladeModel()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
sample = "the capital of france is paris"
inputs = tokenizer(sample, return_tensors="pt")
traced_model = torch.jit.trace(model, (inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"]))

requirements:

Hope this helps! :)
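
The snippet above stops at torch.jit.trace and does not show the ONNX step itself. Below is a minimal sketch of how that step might look, not a tested recipe for any particular checkpoint. It assumes the wrapped masked-LM returns a tuple whose first item is the logits (as with torchscript=True). The names DenseSplade, export_to_onnx, sparsify, the output name "sparse_rep", and the opset version are illustrative choices, not something from this thread. The wrapper returns the dense (batch, vocab) vector so the exported graph has a fixed output width, and the nonzero() step is done outside the model.

```python
import torch


class DenseSplade(torch.nn.Module):
    """Wraps a masked-LM whose output is a tuple with the logits first (assumption)."""

    def __init__(self, masked_lm: torch.nn.Module):
        super().__init__()
        self.masked_lm = masked_lm.eval()

    def forward(self, input_ids, token_type_ids, attention_mask):
        logits = self.masked_lm(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )[0]  # (batch, seq_len, vocab)
        mask = attention_mask.unsqueeze(-1).to(logits.dtype)
        scores = torch.log1p(torch.relu(logits)) * mask
        # Max pooling over the sequence gives a dense (batch, vocab) vector.
        return scores.max(dim=1).values


def export_to_onnx(model, inputs, path, opset_version=14):
    """Hypothetical helper: exports `model` using tokenizer output `inputs`."""
    names = ["input_ids", "token_type_ids", "attention_mask"]
    dynamic_axes = {name: {0: "batch", 1: "seq"} for name in names}
    dynamic_axes["sparse_rep"] = {0: "batch"}
    torch.onnx.export(
        model,
        tuple(inputs[name] for name in names),
        path,
        input_names=names,
        output_names=["sparse_rep"],
        dynamic_axes=dynamic_axes,
        opset_version=opset_version,
    )


def sparsify(dense_row):
    """Turns one dense vocabulary vector into (indices, weights)."""
    indices = dense_row.nonzero().squeeze(-1)
    return indices, dense_row[indices]
```

With the model and tokenizer from the comment above, usage would be along the lines of export_to_onnx(DenseSplade(masked_lm), inputs, "splade.onnx"), followed by sparsify() on each row of the ONNX output. Whether the export succeeds may depend on the installed torch and opset versions.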

ntnq4 commented 8 months ago

Hi @risan-raja,

Thank you for your help : ) I will try your solution on my side.

sroussey commented 8 months ago

If an ONNX conversion were added to the model's HuggingFace repository in a folder called onnx, it would automatically become available to HuggingFace Transformers.js and be usable locally on the web.

sroussey commented 8 months ago

Example: https://huggingface.co/Xenova/t5-small-awesome-text-to-sql/tree/main/