huggingface / notebooks

Notebooks using the Hugging Face libraries 🤗
Apache License 2.0

Prediction after loading the fine-tuned model fails: 'BaseModelOutput' object has no attribute 'logits' #106

Open ShrikanthSingh opened 3 years ago

ShrikanthSingh commented 3 years ago

I extended the training and evaluation process from https://huggingface.co/transformers/custom_datasets.html#fine-tuning-with-native-pytorch-tensorflow so that it saves the fine-tuned model, which I then load separately for prediction. Here is the code:

import torch

true_labels, predicted_labels = [], []
model.eval()

for batch in eval_dataloader:

    batch_labels = batch['labels'].numpy()
    true_labels.extend(batch_labels)

    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    batch_predictions = predictions.to('cpu').numpy()
    predicted_labels.extend(batch_predictions)

model.save_pretrained('imdb_custom_dataset')
from transformers import AutoModel
model = AutoModel.from_pretrained("/content/imdb_custom_dataset")

When I try to predict with the loaded model, I get this error: AttributeError: 'BaseModelOutput' object has no attribute 'logits'. Here is the code I used:

model.eval()
for batch in test_dataloader:
    break
test_sample = {k: v for k, v in batch.items() if k != 'labels'}
outputs_sample = model(**test_sample)
logits_sample = outputs_sample.logits

Error details:

AttributeError                            Traceback (most recent call last)
<ipython-input-20-a99e37f72baa> in <module>()
      4 test_sample = {k: v for k, v in batch.items() if k != 'labels'}
      5 outputs_sample = model(**test_sample)
----> 6 logits_sample = outputs_sample.logits

AttributeError: 'BaseModelOutput' object has no attribute 'logits'

Any help with this issue? Thank you.

statspy-ml commented 2 years ago

Wondering how you solved this issue, as I am facing the same problem.

brestok-1 commented 1 year ago

Did you solve this problem?

ericxsun commented 3 weeks ago

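You could try loading the model with a task-specific class instead of AutoModel, for example AutoModelForSequenceClassification (or the concrete class used for fine-tuning). AutoModel builds only the base model without the classification head, so its forward pass returns a BaseModelOutput, which has last_hidden_state but no logits.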
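A minimal sketch of the difference between the two loading classes, assuming the fine-tuned model is a DistilBertForSequenceClassification as in the linked tutorial (the error names BaseModelOutput, which the DistilBERT base model returns). So that it runs offline, it builds a tiny, randomly initialised model from a hypothetical small DistilBertConfig instead of the real fine-tuned weights, saves it with save_pretrained, and reloads it with both AutoModel and AutoModelForSequenceClassification:

```python
import tempfile

import torch
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForSequenceClassification,
    DistilBertConfig,
    DistilBertForSequenceClassification,
)

# Hypothetical tiny config so the sketch runs offline; a real run would use
# the directory written by model.save_pretrained(...) instead.
config = DistilBertConfig(
    vocab_size=100, dim=32, n_layers=1, n_heads=2, hidden_dim=64, num_labels=2
)
fine_tuned = DistilBertForSequenceClassification(config)  # encoder + classification head

with tempfile.TemporaryDirectory() as save_dir:
    fine_tuned.save_pretrained(save_dir)

    # save_pretrained records the saved class in config.json ("architectures").
    saved_architectures = AutoConfig.from_pretrained(save_dir).architectures

    # AutoModel rebuilds only the base encoder and ignores the head weights.
    base = AutoModel.from_pretrained(save_dir)
    # AutoModelForSequenceClassification rebuilds the encoder and the head.
    classifier = AutoModelForSequenceClassification.from_pretrained(save_dir)

input_ids = torch.tensor([[1, 5, 7, 2]])
attention_mask = torch.ones_like(input_ids)

base.eval()
classifier.eval()
with torch.no_grad():
    base_out = base(input_ids=input_ids, attention_mask=attention_mask)
    clf_out = classifier(input_ids=input_ids, attention_mask=attention_mask)

print(saved_architectures)           # ['DistilBertForSequenceClassification']
print(type(base_out).__name__)       # output class of the base model (no head)
print(hasattr(base_out, "logits"))   # False
print(clf_out.logits.shape)          # torch.Size([1, 2])

# One predicted label per input row, as in the original evaluation loop.
predictions = torch.argmax(clf_out.logits, dim=-1)
```

Applied to the original code, this would mean importing AutoModelForSequenceClassification and calling AutoModelForSequenceClassification.from_pretrained("/content/imdb_custom_dataset") in place of the AutoModel call; the prediction code that reads outputs_sample.logits can stay as it is.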