huggingface / evaluate

🤗 Evaluate: A library for easily evaluating machine learning models and datasets.
https://huggingface.co/docs/evaluate
Apache License 2.0

Can't calculate combined metric (WER + CER) #516

Open blademoon opened 9 months ago

blademoon commented 9 months ago

Hello. Reproduction code:

import evaluate

asr_metrics = evaluate.combine(["wer","cer"])

predictions = ["this is the prediction", "there is an other sample"]
references = ["this is the reference", "there is another one"]

asr_metrics.compute(predictions=predictions, references=references)

Code output:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[15], line 4
      1 predictions = ["this is the prediction", "there is an other sample"]
      2 references = ["this is the reference", "there is another one"]
----> 4 asr_metrics.compute(predictions=predictions, references=references)

File ~/.local/lib/python3.10/site-packages/evaluate/module.py:976, in CombinedEvaluations.compute(self, predictions, references, **kwargs)
    973     batch = {"predictions": predictions, "references": references, **kwargs}
    974     results.append(evaluation_module.compute(**batch))
--> 976 return self._merge_results(results)

File ~/.local/lib/python3.10/site-packages/evaluate/module.py:980, in CombinedEvaluations._merge_results(self, results)
    978 def _merge_results(self, results):
    979     merged_results = {}
--> 980     results_keys = list(itertools.chain.from_iterable([r.keys() for r in results]))
    981     duplicate_keys = {item for item, count in collections.Counter(results_keys).items() if count > 1}
    983     duplicate_names = [
    984         item for item, count in collections.Counter(self.evaluation_module_names).items() if count > 1
    985     ]

File ~/.local/lib/python3.10/site-packages/evaluate/module.py:980, in <listcomp>(.0)
    978 def _merge_results(self, results):
    979     merged_results = {}
--> 980     results_keys = list(itertools.chain.from_iterable([r.keys() for r in results]))
    981     duplicate_keys = {item for item, count in collections.Counter(results_keys).items() if count > 1}
    983     duplicate_names = [
    984         item for item, count in collections.Counter(self.evaluation_module_names).items() if count > 1
    985     ]

AttributeError: 'float' object has no attribute 'keys'

Python version:

Python 3.10.12

Library versions:

Package                   Version
------------------------- -------------
aiofiles                  23.2.1
aiohttp                   3.9.0
aiosignal                 1.3.1
altair                    5.1.2
annotated-types           0.6.0
anyio                     3.7.1
argon2-cffi               23.1.0
argon2-cffi-bindings      21.2.0
arrow                     1.3.0
asttokens                 2.4.1
async-lru                 2.0.4
async-timeout             4.0.3
attrs                     23.1.0
audioread                 3.0.1
Babel                     2.13.1
beautifulsoup4            4.12.2
bleach                    6.1.0
blinker                   1.4
certifi                   2023.11.17
cffi                      1.16.0
charset-normalizer        3.3.2
click                     8.1.7
colorama                  0.4.6
comm                      0.2.0
command-not-found         0.3
contourpy                 1.2.0
cryptography              3.4.8
cycler                    0.12.1
datasets                  2.15.0
dbus-python               1.2.18
debugpy                   1.8.0
decorator                 5.1.1
defusedxml                0.7.1
dill                      0.3.7
distro                    1.7.0
distro-info               1.1+ubuntu0.1
et-xmlfile                1.1.0
evaluate                  0.4.1
exceptiongroup            1.1.3
executing                 2.0.1
fastapi                   0.104.1
fastjsonschema            2.19.0
ffmpy                     0.3.1
filelock                  3.13.1
fonttools                 4.44.3
fqdn                      1.5.1
frozenlist                1.4.0
fsspec                    2023.10.0
gradio                    4.4.1
gradio_client             0.7.0
h11                       0.14.0
httpcore                  1.0.2
httplib2                  0.20.2
httpx                     0.25.1
huggingface-hub           0.19.4
idna                      3.4
importlib-metadata        4.6.4
importlib-resources       6.1.1
ipykernel                 6.26.0
ipython                   8.17.2
ipywidgets                8.1.1
isoduration               20.11.0
jedi                      0.19.1
jeepney                   0.7.1
Jinja2                    3.1.2
jiwer                     3.0.3
joblib                    1.3.2
json5                     0.9.14
jsonpointer               2.4
jsonschema                4.20.0
jsonschema-specifications 2023.11.1
jupyter_client            8.6.0
jupyter_core              5.5.0
jupyter-events            0.9.0
jupyter-lsp               2.2.0
jupyter_server            2.10.1
jupyter_server_terminals  0.4.4
jupyterlab                4.0.9
jupyterlab-pygments       0.2.2
jupyterlab_server         2.25.2
jupyterlab-widgets        3.0.9
keyring                   23.5.0
kiwisolver                1.4.5
launchpadlib              1.10.16
lazr.restfulclient        0.14.4
lazr.uri                  1.0.6
lazy_loader               0.3
librosa                   0.10.1
llvmlite                  0.41.1
markdown-it-py            3.0.0
MarkupSafe                2.1.3
matplotlib                3.8.2
matplotlib-inline         0.1.6
mdurl                     0.1.2
mistune                   3.0.2
more-itertools            8.10.0
mpmath                    1.3.0
msgpack                   1.0.7
multidict                 6.0.4
multiprocess              0.70.15
nbclient                  0.9.0
nbconvert                 7.11.0
nbformat                  5.9.2
nest-asyncio              1.5.8
netifaces                 0.11.0
networkx                  3.2.1
notebook_shim             0.2.3
numba                     0.58.1
numpy                     1.26.2
nvidia-cublas-cu12        12.1.3.1
nvidia-cuda-cupti-cu12    12.1.105
nvidia-cuda-nvrtc-cu12    12.1.105
nvidia-cuda-runtime-cu12  12.1.105
nvidia-cudnn-cu12         8.9.2.26
nvidia-cufft-cu12         11.0.2.54
nvidia-curand-cu12        10.3.2.106
nvidia-cusolver-cu12      11.4.5.107
nvidia-cusparse-cu12      12.1.0.106
nvidia-nccl-cu12          2.18.1
nvidia-nvjitlink-cu12     12.3.101
nvidia-nvtx-cu12          12.1.105
oauthlib                  3.2.0
openpyxl                  3.1.2
orjson                    3.9.10
overrides                 7.4.0
packaging                 23.2
pandas                    2.1.3
pandocfilters             1.5.0
parso                     0.8.3
pexpect                   4.8.0
Pillow                    10.1.0
pip                       23.3.1
platformdirs              4.0.0
plotly                    5.18.0
pooch                     1.8.0
prometheus-client         0.18.0
prompt-toolkit            3.0.41
psutil                    5.9.6
ptyprocess                0.7.0
pure-eval                 0.2.2
pyarrow                   14.0.1
pyarrow-hotfix            0.5
pycparser                 2.21
pydantic                  2.5.1
pydantic_core             2.14.3
pydub                     0.25.1
Pygments                  2.17.1
PyGObject                 3.42.1
PyJWT                     2.3.0
pyparsing                 2.4.7
python-apt                2.4.0+ubuntu2
python-dateutil           2.8.2
python-json-logger        2.0.7
python-multipart          0.0.6
pytz                      2023.3.post1
PyYAML                    5.4.1
pyzmq                     25.1.1
rapidfuzz                 3.5.2
referencing               0.31.0
regex                     2023.10.3
requests                  2.31.0
responses                 0.18.0
rfc3339-validator         0.1.4
rfc3986-validator         0.1.1
rich                      13.7.0
rpds-py                   0.13.0
safetensors               0.4.0
scikit-learn              1.3.2
scipy                     1.11.4
SecretStorage             3.3.1
semantic-version          2.10.0
Send2Trash                1.8.2
setuptools                59.6.0
shellingham               1.5.4
six                       1.16.0
sniffio                   1.3.0
soundfile                 0.12.1
soupsieve                 2.5
soxr                      0.3.7
stack-data                0.6.3
starlette                 0.27.0
sympy                     1.12
systemd-python            234
tenacity                  8.2.3
terminado                 0.18.0
threadpoolctl             3.2.0
tinycss2                  1.2.1
tokenizers                0.15.0
tomli                     2.0.1
tomlkit                   0.12.0
toolz                     0.12.0
torch                     2.1.1
torchaudio                2.1.1
torchvision               0.16.1
tornado                   6.3.3
tqdm                      4.66.1
traitlets                 5.13.0
transformers              4.35.2
triton                    2.1.0
typer                     0.9.0
types-python-dateutil     2.8.19.14
typing_extensions         4.8.0
tzdata                    2023.3
ubuntu-advantage-tools    8001
ufw                       0.36.1
unattended-upgrades       0.1
uri-template              1.3.0
urllib3                   2.1.0
uvicorn                   0.24.0.post1
wadllib                   1.3.6
wcwidth                   0.2.10
webcolors                 1.13
webencodings              0.5.1
websocket-client          1.6.4
websockets                11.0.3
wheel                     0.37.1
widgetsnbextension        4.0.9
xxhash                    3.4.1
yarl                      1.9.2
zipp                      1.0.0
blademoon commented 9 months ago

My current hypothesis is that the problem lies in the data type returned by the metrics. Metrics that combine without problems return a dictionary:

import evaluate

m_accuracy = evaluate.load("accuracy")
m_recall = evaluate.load("recall")
m_precision = evaluate.load("precision")
m_f1 = evaluate.load("f1")

pred = [1,0,1]
ref = [1,1,1]

result_accuracy = m_accuracy.compute(predictions=pred, references=ref)
result_recall = m_recall.compute(predictions=pred, references=ref)
result_precision = m_precision.compute(predictions=pred, references=ref)
result_f1 = m_f1.compute(predictions=pred, references=ref)

print("Accuracy:", type(result_accuracy), result_accuracy)
print("Recall:", type(result_recall), result_recall)
print("precision:", type(result_precision), result_precision)
print("F1:", type(result_f1), result_f1)

Output:

Accuracy: <class 'dict'> {'accuracy': 0.6666666666666666}
Recall: <class 'dict'> {'recall': 0.6666666666666666}
precision: <class 'dict'> {'precision': 1.0}
F1: <class 'dict'> {'f1': 0.8}

Metrics that cause problems when combined (such as WER and CER) return a bare value instead of a dictionary:

m_wer = evaluate.load("wer")
m_cer = evaluate.load("cer")

pred = ["test", "Test"]
ref = ["TEST", "Test"]

result_wer = m_wer.compute(predictions=pred, references=ref)
result_cer = m_cer.compute(predictions=pred, references=ref)

print("WER:", type(result_wer), result_wer)
print("CER:", type(result_cer), result_cer)

Output:

WER: <class 'float'> 0.5
CER: <class 'float'> 0.5

Naturally, a float has no keys attribute...
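
Until this is fixed, a possible workaround (just a minimal sketch; the helper name compute_asr_metrics is my own, not part of the library) is to compute each module separately and build the combined dictionary manually:

import evaluate

def compute_asr_metrics(predictions, references):
    # WER and CER return bare floats in evaluate 0.4.1, so instead of
    # evaluate.combine() we wrap the results into one dictionary ourselves.
    wer_score = evaluate.load("wer").compute(predictions=predictions, references=references)
    cer_score = evaluate.load("cer").compute(predictions=predictions, references=references)
    return {"wer": wer_score, "cer": cer_score}

compute_asr_metrics(["this is the prediction"], ["this is the reference"])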

blademoon commented 9 months ago

Making changes to the code at these locations:

https://github.com/huggingface/evaluate/blob/main/metrics/wer/wer.py#98 https://github.com/huggingface/evaluate/blob/main/metrics/wer/wer.py#106

https://github.com/blademoon/evaluate/blob/main/metrics/cer/cer.py#140 https://github.com/blademoon/evaluate/blob/main/metrics/cer/cer.py#159

may solve the problem. Hypothetically, you just need to "wrap" the return value in a dictionary with the appropriate key.
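
For illustration, a hypothetical before/after sketch of that idea (the names incorrect and total are assumptions about wer.py's internals and may not match the actual code):

# before: returns a bare float, which breaks evaluate.combine()
return incorrect / total

# after: returns a dict keyed by the metric name, as _merge_results expects
return {"wer": incorrect / total}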