ThilinaRajapakse / simpletransformers

Transformers for Information Retrieval, Text Classification, NER, QA, Language Modelling, Language Generation, T5, Multi-Modal, and Conversational AI
https://simpletransformers.ai/
Apache License 2.0

Binary classification .predict raises a ValueError: could not broadcast input array from shape (2,2) into shape (1,2) #1563

Open gaborfodor720818 opened 7 months ago

gaborfodor720818 commented 7 months ago

Hi,

I have tried the original simpletransformers sample code locally with the latest simpletransformers version, 0.65.1. I trained the model with model type bert and "bert-base-uncased". Training completed, but the prediction always raises an error.

The code:

```python
predictions, raw_outputs = model.predict(to_predict=["Example sentence belonging to class 1"])
```

Error:

```
predictions, raw_outputs = model.predict(to_predict=["Example sentence belonging to class 1"])
  File "C:\Work\PythonProjects\TenderAI\venv\lib\site-packages\simpletransformers\classification\classification_model.py", line 2217, in predict
    preds[start_index:end_index] = logits.detach().cpu().numpy()
ValueError: could not broadcast input array from shape (2,2) into shape (1,2)
```
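The error itself is a plain NumPy broadcast failure: `preds` was pre-allocated for one input sentence (shape `(1, 2)` for a binary classifier), but the batch of logits being written into it has two rows. One plausible reading, given the shapes, is that features for two cached examples are being fed through the model instead of the single `to_predict` sentence. A minimal sketch reproducing just the NumPy side of the failure (array names mirror the traceback; the shapes are taken from the error message):

```python
import numpy as np

# predict() pre-allocates one row per to_predict sentence,
# with one column per class (2 for binary classification).
preds = np.empty((1, 2))

# But the logits batch coming back has 2 rows, as in the reported error.
logits = np.zeros((2, 2))

try:
    # Equivalent to preds[start_index:end_index] = logits... in the traceback.
    preds[0:1] = logits
except ValueError as e:
    # ValueError: could not broadcast input array from shape (2,2) into shape (1,2)
    print(e)
```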

ThilinaRajapakse commented 7 months ago
  1. Does this only happen with the latest version?
  2. Does it only happen with the model that you trained with an earlier version when trying to use it with the current version?

I tried to reproduce it with just bert-base-uncased but it's working fine for me.

gaborfodor720818 commented 7 months ago
> 1. Does this only happen with the latest version?
> 2. Does it only happen with the model that you trained with an earlier version when trying to use it with the current version?
>
> I tried to reproduce it with just bert-base-uncased but it's working fine for me.

Strange, now it works. I tried on a Linux machine (previously it was Windows 11) with a clean setup. Simpletransformers is again the latest version, 0.65.1. (I have never tried an older version.) Thanks for testing.

gaborfodor720818 commented 7 months ago

I am not sure it is a bug, but I have found the cause. If I set "use_cached_eval_features": True, I get the error message. I also have "evaluate_during_training": True, because I evaluate during training: model.train_model(train_df, eval_df=eval_df).
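For reference, a sketch of the args combination reported to trigger the error, expressed as the plain `model_args` dict that simpletransformers accepts. The workaround line is an assumption based on the observation above (disabling the cache so stale eval features cannot be reused at predict time), not a confirmed fix from the maintainer:

```python
# Combination reported to trigger the broadcast error at predict() time.
model_args = {
    "evaluate_during_training": True,   # evaluation runs inside train_model()
    "use_cached_eval_features": True,   # reported trigger: cached eval features
}

# Hypothetical workaround: keep evaluation during training, but do not
# cache/reuse eval features, so predict() builds features from to_predict.
model_args["use_cached_eval_features"] = False
```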