UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
https://www.sbert.net
Apache License 2.0

CrossEncoder RuntimeError during torch.jit.trace: Cannot insert a Tensor that requires grad as a constant #3012

Open Temchaz opened 2 weeks ago

Temchaz commented 2 weeks ago

I encountered an issue when attempting to trace a CrossEncoder model using torch.jit.trace. The error occurs during the tracing process when calling the forward method. Below is a minimal reproducible code example and the error message.

import torch
import onnx 
import torch.nn.functional as F
from torch import nn

from sentence_transformers import CrossEncoder

class CategoryMatcher(nn.Module):
    def __init__(self, path_to_weights, encoder_max_lengths):
        super().__init__()
        self.model = CrossEncoder(path_to_weights, max_length=encoder_max_lengths)

    def forward(self, input_ids, attention_mask):
        with torch.no_grad():
            output = self.model.model(input_ids=input_ids, attention_mask=attention_mask)

        return torch.sigmoid(output.logits)

category_matcher = CategoryMatcher("/Users/a.mitrokhin/work/baga/category-matcher/1/category-matcher", 512)
category_matcher.eval()

category_name = "Winter"
good_name = "Winter is coming"
texts = [(category_name, good_name)]

encoded_input = category_matcher.model.tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
input_ids = encoded_input["input_ids"].detach()
attention_mask = encoded_input["attention_mask"].detach()

scripted_model = torch.jit.trace(category_matcher, example_inputs=(input_ids, attention_mask))
torch.jit.save(scripted_model, "./category-matcher.pt")

Traceback (most recent call last):
  File "/Users/a.mitrokhin/work/baga/toTorch/exper.py", line 38, in <module>
    scripted_model = torch.jit.trace(category_matcher, example_inputs=(input_ids, attention_mask))
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/torch/jit/_trace.py", line 1002, in trace
    traced_func = _trace_impl(
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/torch/jit/_trace.py", line 698, in _trace_impl
    return trace_module(
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/torch/jit/_trace.py", line 1278, in trace_module
    module._c._create_method_from_trace(
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/Users/a.mitrokhin/work/baga/toTorch/exper.py", line 15, in forward
    output = self.model.model(input_ids=input_ids, attention_mask=attention_mask)
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/transformers/models/xlm_roberta/modeling_xlm_roberta.py", line 1348, in forward
    outputs = self.roberta(
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/transformers/models/xlm_roberta/modeling_xlm_roberta.py", line 913, in forward
    embedding_output = self.embeddings(
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/transformers/models/xlm_roberta/modeling_xlm_roberta.py", line 123, in forward
    inputs_embeds = self.word_embeddings(input_ids)
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 190, in forward
    return F.embedding(
  File "/Users/a.mitrokhin/work/.venv/lib/python3.10/site-packages/torch/nn/functional.py", line 2529, in embedding
    assert padding_idx < weight.size(
RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient

tomaarsen commented 1 week ago

Hello!

I'm able to reproduce this, but I'm not familiar with JIT trace at all, so I'm not sure how to resolve it.
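
One guess, untested against the real model: this error usually means the tracer met a weight that requires grad but is not a registered parameter of the module being traced. If the installed CrossEncoder is a plain Python object rather than an nn.Module (an assumption, not checked here), then the Hugging Face model at self.model.model never becomes a submodule of CategoryMatcher, and its weights look like constants to torch.jit.trace. Below is a minimal sketch of the idea with a toy network. TinyClassifier and TraceableMatcher are hypothetical names, and TinyClassifier only stands in for the real transformer.

```python
import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for the transformer held by the CrossEncoder."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 8)
        self.head = nn.Linear(8, 1)

    def forward(self, input_ids, attention_mask):
        hidden = self.embed(input_ids)
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        # Mean-pool over the non-padded positions only.
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1)
        return self.head(pooled)


class TraceableMatcher(nn.Module):
    def __init__(self, inner: nn.Module):
        super().__init__()
        # Assign the nn.Module itself so it is registered as a submodule
        # and the tracer can see its parameters.
        self.inner = inner
        # Inference only: freeze the weights instead of using torch.no_grad().
        for param in self.inner.parameters():
            param.requires_grad_(False)

    def forward(self, input_ids, attention_mask):
        return torch.sigmoid(self.inner(input_ids, attention_mask))


matcher = TraceableMatcher(TinyClassifier()).eval()
input_ids = torch.tensor([[1, 2, 3, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0]])
traced = torch.jit.trace(matcher, (input_ids, attention_mask))
```

With the real model, the analogous step would presumably be to pass CrossEncoder(path_to_weights).model as inner and return torch.sigmoid(...) of its .logits, but I have not run that.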

Temchaz commented 1 week ago

@tomaarsen Hi, could you please tell me how to serve the CrossEncoder model without using the Triton Python backend?

tomaarsen commented 1 hour ago

I'm not sure what exactly you mean by serving it. Do you mean over an API? Triton isn't used by default when you're running the model normally, like so:

from sentence_transformers import CrossEncoder

# 1. Load a pre-trained CrossEncoder model
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# 2. Predict scores for a pair of sentences
scores = model.predict([
    ("How many people live in Berlin?", "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers."),
    ("How many people live in Berlin?", "Berlin is well known for its museums."),
])
# => array([ 8.607138 , -4.3200774], dtype=float32)

# 3. Rank a list of passages for a query
query = "How many people live in Berlin?"
passages = [
    "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.",
    "Berlin is well known for its museums.",
    "In 2014, the city state Berlin had 37,368 live births (+6.6%), a record number since 1991.",
    "The urban area of Berlin comprised about 4.1 million people in 2014, making it the seventh most populous urban area in the European Union.",
    "The city of Paris had a population of 2,165,423 people within its administrative city limits as of January 1, 2019",
    "An estimated 300,000-420,000 Muslims reside in Berlin, making up about 8-11 percent of the population.",
    "Berlin is subdivided into 12 boroughs or districts (Bezirke).",
    "In 2015, the total labour force in Berlin was 1.85 million.",
    "In 2013 around 600,000 Berliners were registered in one of the more than 2,300 sport and fitness clubs.",
    "Berlin has a yearly total of about 135 million day visitors, which puts it in third place among the most-visited city destinations in the European Union.",
]
ranks = model.rank(query, passages)

# Print the scores
print("Query:", query)
for rank in ranks:
    print(f"{rank['score']:.2f}\t{passages[rank['corpus_id']]}")
"""
Query: How many people live in Berlin?
8.92    The urban area of Berlin comprised about 4.1 million people in 2014, making it the seventh most populous urban area in the European Union.
8.61    Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.
8.24    An estimated 300,000-420,000 Muslims reside in Berlin, making up about 8-11 percent of the population.
7.60    In 2014, the city state Berlin had 37,368 live births (+6.6%), a record number since 1991.
6.35    In 2013 around 600,000 Berliners were registered in one of the more than 2,300 sport and fitness clubs.
5.42    Berlin has a yearly total of about 135 million day visitors, which puts it in third place among the most-visited city destinations in the European Union.
3.45    In 2015, the total labour force in Berlin was 1.85 million.
0.33    Berlin is subdivided into 12 boroughs or districts (Bezirke).
-4.24   The city of Paris had a population of 2,165,423 people within its administrative city limits as of January 1, 2019
-4.32   Berlin is well known for its museums.
"""

You can serve this with, for example, FastAPI, infinity, or whatever tooling you prefer.
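
As a rough illustration of the "whatever tooling" option, here is a sketch of a small JSON endpoint built only on the standard library's http.server. The names rank_passages, make_handler, make_server and load_cross_encoder_scorer, as well as the request shape {"query": ..., "passages": [...]}, are made up for this example. Only CrossEncoder(...) and model.predict(...) come from the snippet above.

```python
import json
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Callable, List, Sequence, Tuple

# A scorer maps (query, passage) pairs to one relevance score per pair.
Scorer = Callable[[List[Tuple[str, str]]], Sequence[float]]


def load_cross_encoder_scorer(name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2") -> Scorer:
    """Build a scorer from CrossEncoder.predict (lazy import; loads the model)."""
    from sentence_transformers import CrossEncoder

    model = CrossEncoder(name)
    return lambda pairs: [float(score) for score in model.predict(pairs)]


def rank_passages(scorer: Scorer, query: str, passages: List[str]) -> List[dict]:
    """Score every passage against the query and sort best-first."""
    scores = scorer([(query, passage) for passage in passages])
    ranked = [{"corpus_id": i, "score": float(s)} for i, s in enumerate(scores)]
    return sorted(ranked, key=lambda item: item["score"], reverse=True)


def make_handler(scorer: Scorer):
    """Create a request handler class bound to the given scorer."""

    class RankHandler(BaseHTTPRequestHandler):
        def do_POST(self):
            length = int(self.headers.get("Content-Length", 0))
            payload = json.loads(self.rfile.read(length))
            ranked = rank_passages(scorer, payload["query"], payload["passages"])
            body = json.dumps(ranked).encode("utf-8")
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            self.wfile.write(body)

    return RankHandler


def make_server(scorer: Scorer, host: str = "127.0.0.1", port: int = 8000) -> HTTPServer:
    """Return an HTTP server; call .serve_forever() on it to start serving."""
    return HTTPServer((host, port), make_handler(scorer))
```

Starting it would look like make_server(load_cross_encoder_scorer()).serve_forever(). This sketch has no input validation, batching or error handling, so a real deployment would want a proper framework in front of it.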