utterworks / fast-bert

Super easy library for BERT-based NLP models
Apache License 2.0
1.86k stars, 341 forks

accuracy function always returns zero and breaks with newer NumPy versions #318

Open · pbouss opened this issue 9 months ago

pbouss commented 9 months ago

The `accuracy` function in `metrics.py` does not work properly:

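```python
# Imports the excerpt relies on
import numpy as np
from torch import Tensor

def accuracy(y_pred: Tensor, y_true: Tensor, **kwargs):
    y_pred = y_pred.cpu()
    # Predicted class index per sample, shape (N,)
    outputs = np.argmax(y_pred, axis=1)
    # Element-wise comparison with the labels, then the mean of the boolean result
    return np.mean(outputs.numpy() == y_true.detach().cpu().numpy())
```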
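To see the failure in isolation, the final comparison can be reproduced with NumPy alone. This is a sketch under assumptions: `y_true` is taken to be one-hot encoded with shape `(N, C)`, the arrays are made-up example data, and `old_style_accuracy` is a hypothetical stand-in for the last line of the function above without the torch conversions.

```python
import warnings
import numpy as np

# Made-up example data: 4 samples, 3 classes.
# outputs mimics np.argmax(y_pred, axis=1): one class index per sample, shape (4,).
outputs = np.array([0, 2, 1, 2])
# y_true is assumed to be one-hot encoded, shape (4, 3).
y_true = np.eye(3, dtype=int)[[0, 2, 1, 0]]

def old_style_accuracy(outputs, y_true):
    # Same comparison and mean as the last line of the reported function.
    return np.mean(outputs == y_true)

with warnings.catch_warnings():
    # Older NumPy only warns ("elementwise comparison failed") instead of raising.
    warnings.simplefilter("ignore")
    try:
        # Older NumPy: the comparison collapses to a scalar False, so this prints 0.0
        print("accuracy:", old_style_accuracy(outputs, y_true))
    except ValueError as exc:
        # Newer NumPy: shapes (4,) and (4, 3) cannot be broadcast together
        print("error:", exc)
```

If the number of samples happens to equal the number of classes, or if `y_true` has shape `(N, 1)`, the two shapes broadcast to `(N, N)` instead, so no warning or error appears but the result is still not a per-sample accuracy.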

This happens because `outputs` has fewer axes than `y_true`, and the two shapes do not match. In NumPy 1.26, the newest version, this comparison raises an error. In older NumPy versions it only emits a `DeprecationWarning` and returns a single `False`; taking the mean of that gives 0. As a result, the function always returns zero. It could probably be fixed with a small change.
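One possible adaptation, offered as a minimal sketch rather than the project's actual fix, is to reduce both predictions and labels to 1-D class-index arrays before comparing them. It assumes single-label classification where `y_true` holds either class indices (shape `(N,)` or `(N, 1)`) or one-hot rows (shape `(N, C)`); multi-label targets would need a different metric. The helper `_to_numpy` is hypothetical: it detects torch tensors by their `.detach()` method so the sketch runs with NumPy alone.

```python
import numpy as np

def _to_numpy(x):
    # Hypothetical helper: torch tensors are recognised by their .detach() method,
    # so this sketch does not need to import torch.
    if hasattr(x, "detach"):
        x = x.detach().cpu().numpy()
    return np.asarray(x)

def accuracy(y_pred, y_true, **kwargs):
    y_pred = _to_numpy(y_pred)
    y_true = _to_numpy(y_true)
    # Predicted class index per sample, shape (N,).
    preds = np.argmax(y_pred, axis=1)
    if y_true.ndim == 2 and y_true.shape[1] > 1:
        # Assumed one-hot labels (N, C): take the column index of each row's maximum.
        labels = np.argmax(y_true, axis=1)
    else:
        # Class indices given as (N,) or (N, 1): flatten to (N,).
        labels = y_true.reshape(-1)
    if labels.shape != preds.shape:
        # Fail loudly instead of silently broadcasting or returning a scalar False.
        raise ValueError(
            f"cannot compare predictions {preds.shape} with labels {labels.shape}"
        )
    return float(np.mean(preds == labels))


logits = np.array([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2]])
print(accuracy(logits, np.array([[1, 0], [0, 1], [0, 1]])))  # 0.6666666666666666
print(accuracy(logits, np.array([0, 1, 1])))                 # 0.6666666666666666
```

The explicit shape check is a design choice: it turns both the old silent zero and accidental `(N, N)` broadcasting into an immediate error.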