
[ICLR 2024] DNABERT-2: Efficient Foundation Model and Benchmark for Multi-Species Genome

How to get mean embeddings from model #8

Closed timoast closed 1 year ago

timoast commented 1 year ago

Hi, do you have any examples of how to extract sequence embeddings using this model?

I tried the following code but get an error:

```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M")
model = AutoModelForMaskedLM.from_pretrained("zhihan1996/DNABERT-2-117M")

tok = tokenizer("ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC", return_tensors='pt')
outs = model(tok)
```

This gives the error below:

```python
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ~/mambaforge/envs/predictor/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:254, in BatchEncoding.__getattr__(self, item)
    253 try:
--> 254     return self.data[item]
    255 except KeyError:

KeyError: 'size'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
Cell In[67], line 1
----> 1 outs = model(tok)

File ~/mambaforge/envs/predictor/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/mambaforge/envs/predictor/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py:1358, in BertForMaskedLM.forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, output_attentions, output_hidden_states, return_dict)
   1356 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-> 1358 outputs = self.bert(
   1359     input_ids,
   1360     attention_mask=attention_mask,
   1361     token_type_ids=token_type_ids,
   1362     position_ids=position_ids,
   1363     head_mask=head_mask,
   1364     inputs_embeds=inputs_embeds,
   1365     encoder_hidden_states=encoder_hidden_states,
   1366     encoder_attention_mask=encoder_attention_mask,
   1367     output_attentions=output_attentions,
   1368     output_hidden_states=output_hidden_states,
   1369     return_dict=return_dict,
   1370 )
   1372 sequence_output = outputs[0]
   1373 prediction_scores = self.cls(sequence_output)

File ~/mambaforge/envs/predictor/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
-> 1501     return forward_call(*args, **kwargs)

File ~/mambaforge/envs/predictor/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py:968, in BertModel.forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
    966     raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    967 elif input_ids is not None:
--> 968     input_shape = input_ids.size()
    969 elif inputs_embeds is not None:
    970     input_shape = inputs_embeds.size()[:-1]

File ~/mambaforge/envs/predictor/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:256, in BatchEncoding.__getattr__(self, item)
    254     return self.data[item]
    255 except KeyError:
--> 256     raise AttributeError

AttributeError:
```
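The immediate cause of the error is that the whole `BatchEncoding` is passed positionally, so `forward` treats it as `input_ids` and calls `.size()` on a dict-like object, which raises the `AttributeError`. A minimal sketch of a corrected forward call (assuming the tokenizer and model are loaded with `trust_remote_code=True`, as shown in the reply below):

```python
# Sketch: unpack the tokenizer output so the model receives tensors
# (input_ids, attention_mask, ...) as keyword arguments, not the BatchEncoding object.
tok = tokenizer(
    "ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC",
    return_tensors="pt",
)
outs = model(**tok)  # or: model(tok["input_ids"])
```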
Zhihan1996 commented 1 year ago

Hi,

Thanks for posting the question.

To load the model from huggingface:

```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
model = AutoModel.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
```

To calculate the embedding of a DNA sequence:

```python
dna = "ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC"
inputs = tokenizer(dna, return_tensors='pt')["input_ids"]
hidden_states = model(inputs)[0]  # [1, sequence_length, 768]

# embedding with mean pooling
embedding_mean = torch.mean(hidden_states[0], dim=0)
print(embedding_mean.shape)  # expect to be 768

# embedding with max pooling
embedding_max = torch.max(hidden_states[0], dim=0)[0]
print(embedding_max.shape)  # expect to be 768
```
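When embedding a batch of sequences of different lengths, padded positions can skew a plain `torch.mean`. Below is a hedged sketch of mask-aware mean pooling; it assumes the tokenizer supports `padding=True` and returns an `attention_mask`, and that the model's `forward` accepts an `attention_mask` keyword like a standard BERT:

```python
# Sketch: mean-pool token embeddings over a batch while ignoring padding positions.
sequences = [
    "ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC",
    "GATTACAGATTACA",  # hypothetical shorter sequence to force padding
]
batch = tokenizer(sequences, return_tensors="pt", padding=True)

with torch.no_grad():  # no gradients needed for embedding extraction
    hidden_states = model(batch["input_ids"], attention_mask=batch["attention_mask"])[0]
    # hidden_states: [batch_size, sequence_length, 768]

mask = batch["attention_mask"].unsqueeze(-1).float()  # [batch_size, sequence_length, 1]
embedding_mean = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
print(embedding_mean.shape)  # [batch_size, 768]
```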

I will also update the model card and README with these instructions.