Lightning-AI / torchmetrics

Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0

PESQ No utterances detected #2752

Closed veera-puthiran-14082 closed 1 month ago

veera-puthiran-14082 commented 2 months ago

🐛 Bug

When calculating PESQ for a batch, if an exception is thrown for a single audio clip in that batch, the error aborts the metric computation for the entire batch.

To Reproduce

Steps to reproduce the behavior...

Code sample:

```python
# Minimal code sample to reproduce the described issue.
import torch
from torchmetrics.audio import PerceptualEvaluationSpeechQuality

torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

if __name__ == "__main__":
    preds = torch.randn(32, 16000 * 10)
    target = torch.randn(32, 16000 * 10)
    wb_pesq = PerceptualEvaluationSpeechQuality(16000, 'wb', 2)
    print(wb_pesq(preds, target))
```

Error trace:

```
Traceback (most recent call last):
  File "/data/veera/zspeech/temp/sample.py", line 11, in <module>
    print(wb_pesq(preds, target))
  File "/home/veera/anaconda3/envs/zspeech/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/veera/anaconda3/envs/zspeech/lib/python3.10/site-packages/torchmetrics/metric.py", line 311, in forward
    self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
  File "/home/veera/anaconda3/envs/zspeech/lib/python3.10/site-packages/torchmetrics/metric.py", line 380, in _forward_reduce_state_update
    self.update(*args, **kwargs)
  File "/home/veera/anaconda3/envs/zspeech/lib/python3.10/site-packages/torchmetrics/metric.py", line 482, in wrapped_func
    update(*args, **kwargs)
  File "/home/veera/anaconda3/envs/zspeech/lib/python3.10/site-packages/torchmetrics/audio/pesq.py", line 124, in update
    pesq_batch = perceptual_evaluation_speech_quality(
  File "/home/veera/anaconda3/envs/zspeech/lib/python3.10/site-packages/torchmetrics/functional/audio/pesq.py", line 108, in perceptual_evaluation_speech_quality
    pesq_val = torch.from_numpy(pesq_val_np)
TypeError: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.
```
- The variable `pesq_val_np` holds the values below:

```
[1.5404247045516968 1.4244357347488403 1.7014210224151611 1.3647973537445068
 1.1401857137680054 1.4725890159606934 2.0076751708984375 1.088727355003357
 1.5617468357086182 1.2922636270523071 1.2720121145248413 1.284945011138916
 1.2110638618469238 1.5236833095550537 1.9427233934402466 1.3616420030593872
 1.1479936838150024 1.2334833145141602 2.5019659996032715 1.1108096837997437
 1.4132858514785767 1.6151965856552124 3.488539218902588 1.12986159324646
 NoUtterancesError(b'No utterances detected') 1.1446586847305298
 1.2475643157958984 1.3870091438293457 1.4756184816360474 1.3193098306655884
 1.5343947410583496 1.2668492794036865]
```

- When this `pesq_val_np` is converted to a tensor, the error above is thrown, because the `NoUtterancesError` entry forces the array to dtype `object`.
- The failed entry could either be replaced with 0 or with the average of the valid values.
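The suggested fallback can be sketched in plain Python (library-free, so the function name `sanitize_pesq_batch` and the local `NoUtterancesError` stand-in are hypothetical, not part of torchmetrics or pesq): filter out exception entries and fill them with either 0.0 or the mean of the successful scores before converting to a tensor.

```python
class NoUtterancesError(Exception):
    """Stand-in for the exception raised by the pesq package (hypothetical here)."""


def sanitize_pesq_batch(values, fallback="mean"):
    """Replace exception entries in a per-item PESQ result list.

    fallback="mean" fills failures with the mean of the valid scores;
    anything else fills them with 0.0.
    """
    valid = [v for v in values if not isinstance(v, Exception)]
    if fallback == "mean" and valid:
        fill = sum(valid) / len(valid)
    else:
        fill = 0.0
    return [fill if isinstance(v, Exception) else v for v in values]


# Example batch mixing scores and one failure, mirroring the dump above.
batch = [1.54, 1.42, NoUtterancesError(b"No utterances detected"), 1.36]
clean = sanitize_pesq_batch(batch, fallback="zero")
# clean now contains only floats and can be passed to torch.tensor(...)
```

Whether 0.0 or the batch mean is the right fill is a judgment call: 0.0 visibly penalizes silent clips, while the mean keeps the batch average unbiased by the failure.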

Expected behavior

Environment

Additional context

Borda commented 2 months ago

> TorchMetrics version 2.0.0

are you sure?

veera-puthiran-14082 commented 2 months ago

My bad, apologies. Updated now: the torchmetrics version in my environment was 1.4.0.post0.