MAGICS-LAB / DNABERT_2

[ICLR 2024] DNABERT-2: Efficient Foundation Model and Benchmark for Multi-Species Genome
Apache License 2.0

Inference fails with output_all_encoded_layers=True #48

Open princethewinner opened 9 months ago

princethewinner commented 9 months ago

I am trying to extract the hidden-layer outputs from all layers of the model. According to the documentation, output_all_encoded_layers is a "boolean which controls the content of the encoded_layers output as described below. Default: True." However, line 586 of bert_layers.py (https://huggingface.co/zhihan1996/DNABERT-2-117M/blob/main/bert_layers.py#L586) sets it to False. That matches what I observed, since only the last layer was returned in the output, but it contradicts the documentation. When I set it to True, however, inference fails. The traceback is as follows:
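Because the docstring and the code disagree, it can help to ask the loaded remote code which default it actually uses. A minimal sketch with the standard library's inspect module follows; the helper keyword_default and the stand-in forward function are hypothetical, and the stand-in only mirrors the parameter names shown in the traceback (its defaults other than output_all_encoded_layers=False are assumed).

```python
import inspect


def keyword_default(fn, name):
    """Return the default value of parameter `name` in the signature of `fn`."""
    param = inspect.signature(fn).parameters[name]
    if param.default is inspect.Parameter.empty:
        raise ValueError(f"{name!r} has no default value")
    return param.default


# Hypothetical stand-in mirroring the BertModel.forward parameters in the traceback.
# Only output_all_encoded_layers=False comes from the issue; other defaults are assumed.
def forward(input_ids, token_type_ids=None, attention_mask=None, position_ids=None,
            output_all_encoded_layers=False, masked_tokens_mask=None, **kwargs):
    ...


print(keyword_default(forward, "output_all_encoded_layers"))
```

On the real model, keyword_default(model.forward, "output_all_encoded_layers") should report the default from the bert_layers.py that was actually downloaded, independent of what the docstring says.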


    RuntimeError                              Traceback (most recent call last)
    Cell In[60], line 1
    ----> 1 output = model(**b, output_all_encoded_layers=True)

    File /lustre/scratch124/casm/team113/users/pg20/venvs/huggingface/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
       1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
       1517 else:
    -> 1518     return self._call_impl(*args, **kwargs)

    File /lustre/scratch124/casm/team113/users/pg20/venvs/huggingface/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
       1522 # If we don't have any hooks, we want to skip the rest of the logic in
       1523 # this function, and just call forward.
       1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
       1525         or _global_backward_pre_hooks or _global_backward_hooks
       1526         or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1527     return forward_call(*args, **kwargs)
       1529 try:
       1530     result = None

    File /lustre/scratch124/casm/team113/users/pg20/data/supporting/huggingface_models/modules/transformers_modules/zhihan1996/DNABERT-2-117M/81ac6a98387cf94bc283553260f3fa6b88cef2fa/bert_layers.py:616, in BertModel.forward(self, input_ids, token_type_ids, attention_mask, position_ids, output_all_encoded_layers, masked_tokens_mask, **kwargs)
        614 if masked_tokens_mask is None:
        615     sequence_output = encoder_outputs[-1]
    --> 616     pooled_output = self.pooler(
        617         sequence_output) if self.pooler is not None else None
        618 else:
        619     # TD [2022-03-01]: the indexing here is very tricky.
        620     attention_mask_bool = attention_mask.bool()

    File /lustre/scratch124/casm/team113/users/pg20/venvs/huggingface/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
       1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
       1517 else:
    -> 1518     return self._call_impl(*args, **kwargs)

    File /lustre/scratch124/casm/team113/users/pg20/venvs/huggingface/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
       1522 # If we don't have any hooks, we want to skip the rest of the logic in
       1523 # this function, and just call forward.
       1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
       1525         or _global_backward_pre_hooks or _global_backward_hooks
       1526         or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1527     return forward_call(*args, **kwargs)
       1529 try:
       1530     result = None

    File /lustre/scratch124/casm/team113/users/pg20/data/supporting/huggingface_models/modules/transformers_modules/zhihan1996/DNABERT-2-117M/81ac6a98387cf94bc283553260f3fa6b88cef2fa/bert_layers.py:501, in BertPooler.forward(self, hidden_states, pool)
        495 def forward(self,
        496             hidden_states: torch.Tensor,
        497             pool: Optional[bool] = True) -> torch.Tensor:
        498     # We "pool" the model by simply taking the hidden state corresponding
        499     # to the first token.
        500     first_token_tensor = hidden_states[:, 0] if pool else hidden_states
    --> 501     pooled_output = self.dense(first_token_tensor)
        502     pooled_output = self.activation(pooled_output)
        503     return pooled_output

    File /lustre/scratch124/casm/team113/users/pg20/venvs/huggingface/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
       1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
       1517 else:
    -> 1518     return self._call_impl(*args, **kwargs)

    File /lustre/scratch124/casm/team113/users/pg20/venvs/huggingface/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
       1522 # If we don't have any hooks, we want to skip the rest of the logic in
       1523 # this function, and just call forward.
       1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
       1525         or _global_backward_pre_hooks or _global_backward_hooks
       1526         or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1527     return forward_call(*args, **kwargs)
       1529 try:
       1530     result = None

    File /lustre/scratch124/casm/team113/users/pg20/venvs/huggingface/lib/python3.10/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
        113 def forward(self, input: Tensor) -> Tensor:
    --> 114     return F.linear(input, self.weight, self.bias)

    RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x5 and 768x768)
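One reading of the final error, offered as a hypothesis rather than a confirmed diagnosis: BertPooler computes hidden_states[:, 0], which for a padded (batch, seq_len, hidden) tensor gives one 768-dim vector per sequence. The 1x5 operand suggests the pooler instead received a 2-D (num_tokens, hidden) tensor, so [:, 0] produced a length-5 vector (5 is assumed to be the token count for 'ATCG' including special tokens). The sketch below reproduces that shape mismatch with plain PyTorch and random tensors; it does not load DNABERT-2, and the nn.Linear only stands in for BertPooler.dense.

```python
import torch
import torch.nn as nn

hidden_size = 768
# Stand-in for BertPooler.dense: a 768 -> 768 linear layer with random weights.
dense = nn.Linear(hidden_size, hidden_size)

# Padded layout (batch, seq_len, hidden): [:, 0] selects the first token of each sequence.
padded = torch.randn(1, 5, hidden_size)
first_token = padded[:, 0]   # shape (1, 768)
pooled = dense(first_token)  # works: shape (1, 768)

# Assumed unpadded layout (num_tokens, hidden): the batch dimension is gone, so
# [:, 0] selects hidden unit 0 of every token instead of the first token.
unpadded = torch.randn(5, hidden_size)
wrong_slice = unpadded[:, 0]  # shape (5,)

message = ""
try:
    dense(wrong_slice)  # a 1-D input of length 5 against a 768x768 weight
except RuntimeError as err:
    message = str(err)
print(message)
```

If the hypothesis holds, the caught message should have the same form as the last line of the traceback above.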

Steps to reproduce the error:

    import torch
    from transformers import AutoTokenizer, AutoModel

    tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
    model = AutoModel.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
    b = tokenizer('ATCG', return_tensors='pt', return_attention_mask=True)
    output = model(**b, output_all_encoded_layers=True)
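One possible workaround while the flag fails, sketched under assumptions: collect each encoder layer's output with standard PyTorch forward hooks (nn.Module.register_forward_hook) instead of passing output_all_encoded_layers=True. The helper names capture_layer_outputs and repad are hypothetical, and the toy nn.ModuleList only stands in for the real encoder. repad assumes the layers operate on unpadded (num_tokens, hidden) tensors whose rows follow the row-major order of the positions where attention_mask is 1; that is an inference from the 1x5 error, not something the issue confirms.

```python
import torch
import torch.nn as nn


def capture_layer_outputs(layers):
    """Register a forward hook on every module in `layers` and record its output.

    Returns (captured, handles): `captured` fills up in call order during the next
    forward pass; call h.remove() on each handle afterwards.
    """
    captured = []

    def hook(module, inputs, output):
        # Some layers return tuples; keep only the first element (the hidden states).
        captured.append(output[0] if isinstance(output, tuple) else output)

    handles = [layer.register_forward_hook(hook) for layer in layers]
    return captured, handles


def repad(unpadded, attention_mask):
    """Scatter an unpadded (num_tokens, hidden) tensor back to (batch, seq_len, hidden).

    Assumes row i of `unpadded` belongs to the i-th position (row-major) where
    attention_mask == 1; padded positions are filled with zeros.
    """
    batch, seq_len = attention_mask.shape
    out = unpadded.new_zeros(batch, seq_len, unpadded.shape[-1])
    out[attention_mask.bool()] = unpadded
    return out


# Toy stand-in for an encoder stack (hypothetical; not DNABERT-2 itself).
toy_layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
captured, handles = capture_layer_outputs(toy_layers)

x = torch.randn(5, 8)  # 5 unpadded tokens, hidden size 8
with torch.no_grad():
    for layer in toy_layers:
        x = layer(x)
for h in handles:
    h.remove()

mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0]])  # one sequence: 5 real tokens + 2 padding
padded = [repad(t, mask) for t in captured]
```

With DNABERT-2 the layers would presumably be model.encoder.layer (an assumed attribute name; print(model) shows the actual structure), followed by a normal model(**b) call with the default output_all_encoded_layers=False and, only if the captured tensors turn out to be 2-D, repad(t, b["attention_mask"]).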

P.S. I am not using Triton, since it was failing at a different step.