rapidsai / crossfit

Metric calculation library
Apache License 2.0

Allow HF and sentence-transformer models #63

Closed: VibhuJawa closed this 3 months ago

VibhuJawa commented 3 months ago

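This PR fixes a bug that prevented crossfit from supporting both Hugging Face (HF) and sentence-transformers models. The model call could fail depending on how the model expects its forward-pass inputs. HF models take the tokenizer outputs as keyword arguments (model(**inputs)), whereas sentence-transformers models take the whole feature dict as a single argument (model(inputs)).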

Examples of the two calling conventions this fix handles:

HF (tokenizer outputs unpacked as keyword arguments):

from transformers import AutoTokenizer, DistilBertModel

model_hf = DistilBertModel.from_pretrained("distilbert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

# Hugging Face model output: each tensor is passed as a keyword argument
outputs_hf = model_hf(**inputs)

ST (tokenizer outputs passed as a single dict):

import sentence_transformers
from transformers import AutoTokenizer

model_st = sentence_transformers.SentenceTransformer("all-MiniLM-L6-v2").to("cpu")
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

inputs = tokenizer(
    ["Hello", "my dog is cute"], return_tensors="pt", padding=True, truncation=True
)
# Sentence Transformers model output: the whole feature dict is one positional argument
expected_output = model_st(inputs)
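To make the difference concrete, here is a minimal sketch of a helper that picks the calling convention per model. It is not crossfit's actual implementation. The names forward_pass and is_sentence_transformer are hypothetical, detecting sentence-transformers models by class name is an assumption (it avoids importing sentence_transformers), and the two model classes at the bottom are stand-ins so the sketch runs without downloading weights.

```python
from typing import Any, Dict, Mapping


def is_sentence_transformer(model: Any) -> bool:
    """Hypothetical check: True if any class in the model's MRO is named
    'SentenceTransformer'. Matching by name avoids importing sentence_transformers."""
    return any(cls.__name__ == "SentenceTransformer" for cls in type(model).__mro__)


def forward_pass(model: Any, inputs: Mapping[str, Any]) -> Any:
    """Hypothetical helper: call `model` with the input convention it expects."""
    if is_sentence_transformer(model):
        # sentence-transformers style: one positional dict. Pass a shallow copy
        # because its modules add keys (e.g. embeddings) to the dict they receive.
        return model(dict(inputs))
    # Plain HF style: each tensor becomes a keyword argument (input_ids=..., ...).
    return model(**inputs)


# --- Stand-ins for demonstration only (NOT the real library classes) ---
class SentenceTransformer:
    def __call__(self, features: Dict[str, Any]) -> Dict[str, Any]:
        # Mimics sentence-transformers by writing its result into the feature dict.
        features["sentence_embedding"] = len(features["input_ids"])
        return features


class DistilBertModel:
    def __call__(self, input_ids: Any, attention_mask: Any) -> int:
        # Mimics an HF model that only accepts keyword tensors.
        return len(input_ids)


if __name__ == "__main__":
    batch = {"input_ids": [[101, 102], [101, 102]], "attention_mask": [[1, 1], [1, 1]]}
    print(forward_pass(DistilBertModel(), batch))  # prints 2
    print(forward_pass(SentenceTransformer(), batch)["sentence_embedding"])  # prints 2
```

Calling either stand-in with the other convention raises a TypeError, which mirrors the failure mode described above. The name-based check also matches user subclasses of SentenceTransformer; an isinstance check against the real class would be stricter but requires sentence_transformers to be installed.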