huggingface / candle

Minimalist ML framework for Rust

How to pass the attention_mask to Bert model in examples? #1552

Status: Closed (lz1998 closed this issue 9 months ago)

lz1998 commented 9 months ago

I am trying to run shibing624/text2vec-base-chinese with candle. The tokenizer returns input_ids, attention_mask, and token_type_ids, but BertModel in the candle example only takes two parameters, so there is nowhere to pass the attention mask (see the Rust sketch after the Python snippet below).

https://github.com/huggingface/candle/blob/main/candle-examples/examples/bert/main.rs#L170

from transformers import BertTokenizer, BertModel
import torch

# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

# Load model from HuggingFace Hub
tokenizer = BertTokenizer.from_pretrained('shibing624/text2vec-base-chinese')
model = BertModel.from_pretrained('shibing624/text2vec-base-chinese')
sentences = ['如何更换花呗绑定银行卡', '花呗更改绑定银行卡']
# Tokenize sentences
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

# Compute token embeddings
with torch.no_grad():
    model_output = model(**encoded_input)
# Perform pooling. In this case, mean pooling.
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
print("Sentence embeddings:")
print(sentence_embeddings)
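
For context on the Rust side, the sketch below illustrates the mismatch: the tokenizer produces all three pieces, but the forward call in the linked example only accepts two tensors, so the attention mask is simply dropped. This is a minimal illustration assuming that two-argument forward; the function name and error handling are not from the example.

use anyhow::{Error as E, Result};
use candle_core::{Device, Tensor};
use candle_transformers::models::bert::BertModel;
use tokenizers::Tokenizer;

// The tokenizer yields input_ids, token_type_ids and an attention mask,
// but the example's forward only accepts the first two tensors.
fn embed(model: &BertModel, tokenizer: &Tokenizer, sentence: &str, device: &Device) -> Result<Tensor> {
    let encoding = tokenizer.encode(sentence, true).map_err(E::msg)?;
    let input_ids = Tensor::new(encoding.get_ids(), device)?.unsqueeze(0)?;
    let token_type_ids = Tensor::new(encoding.get_type_ids(), device)?.unsqueeze(0)?;
    let _attention_mask = encoding.get_attention_mask(); // produced, but there is no parameter to pass it to
    Ok(model.forward(&input_ids, &token_type_ids)?)
}
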
LaurentMazare commented 9 months ago

The bert model in candle doesn't use any attention mask, as it's not set up for autoregressive inference but rather for computing embeddings, so there is no real attention mask to be applied down the line when doing mean-pooling etc. (each sequence position can attend to all the other sequence positions).
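
For anyone who still wants padding-aware pooling when batching sentences of different lengths, the mask can be applied at the pooling step rather than inside the model, mirroring the Python mean_pooling above. A minimal sketch, assuming embeddings of shape (batch, seq_len, hidden) returned by the model and an attention_mask tensor of shape (batch, seq_len) built from the tokenizer output (not part of the candle example):

use candle_core::{DType, Result, Tensor};

// Mean-pool token embeddings while ignoring padded positions (mask value 0).
// embeddings: (batch, seq_len, hidden), attention_mask: (batch, seq_len) of 0/1 values.
fn masked_mean_pooling(embeddings: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
    let mask = attention_mask.to_dtype(DType::F32)?.unsqueeze(2)?; // (batch, seq_len, 1)
    let summed = embeddings.broadcast_mul(&mask)?.sum(1)?;         // (batch, hidden)
    let counts = mask.sum(1)?;                                     // (batch, 1)
    summed.broadcast_div(&counts)                                  // (batch, hidden)
}

Dividing by the per-sentence token count rather than the full sequence length is what keeps shorter, padded sentences from being diluted by pad positions.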

lz1998 commented 9 months ago

> The bert model in candle doesn't use any attention mask, as it's not set up for autoregressive inference but rather for computing embeddings, so there is no real attention mask to be applied down the line when doing mean-pooling etc. (each sequence position can attend to all the other sequence positions).

Is there something like SentenceTransformer?