TheAlgorithms / Python

All Algorithms implemented in Python
https://thealgorithms.github.io/Python/
MIT License

Avoid log(0) in KL divergence #12233

Open kevin1kevin1k opened 4 days ago

kevin1kevin1k commented 4 days ago

Repository commit

03a42510b01c574292ca9c6525cbf0572ff5a2a5

Python version (python --version)

Python 3.10.15

Dependencies version (pip freeze)

absl-py==2.1.0 astunparse==1.6.3 beautifulsoup4==4.12.3 certifi==2024.8.30 charset-normalizer==3.4.0 contourpy==1.3.0 cycler==0.12.1 dill==0.3.9 dom_toml==2.0.0 domdf-python-tools==3.9.0 fake-useragent==1.5.1 flatbuffers==24.3.25 fonttools==4.54.1 gast==0.6.0 google-pasta==0.2.0 grpcio==1.67.0 h5py==3.12.1 idna==3.10 imageio==2.36.0 joblib==1.4.2 keras==3.6.0 kiwisolver==1.4.7 libclang==18.1.1 lxml==5.3.0 Markdown==3.7 markdown-it-py==3.0.0 MarkupSafe==3.0.2 matplotlib==3.9.2 mdurl==0.1.2 ml-dtypes==0.3.2 mpmath==1.3.0 namex==0.0.8 natsort==8.4.0 numpy==1.26.4 oauthlib==3.2.2 opencv-python==4.10.0.84 opt_einsum==3.4.0 optree==0.13.0 packaging==24.1 pandas==2.2.3 patsy==0.5.6 pbr==6.1.0 pillow==11.0.0 pip==24.2 protobuf==4.25.5 psutil==6.1.0 Pygments==2.18.0 pyparsing==3.2.0 python-dateutil==2.9.0.post0 pytz==2024.2 qiskit==1.2.4 qiskit-aer==0.15.1 requests==2.32.3 requests-oauthlib==1.3.1 rich==13.9.2 rustworkx==0.15.1 scikit-learn==1.5.2 scipy==1.14.1 setuptools==74.1.2 six==1.16.0 soupsieve==2.6 sphinx-pyproject==0.3.0 statsmodels==0.14.4 stevedore==5.3.0 symengine==0.13.0 sympy==1.13.3 tensorboard==2.16.2 tensorboard-data-server==0.7.2 tensorflow==2.16.2 tensorflow-io-gcs-filesystem==0.37.1 termcolor==2.5.0 threadpoolctl==3.5.0 tomli==2.0.2 tweepy==4.14.0 typing_extensions==4.12.2 tzdata==2024.2 urllib3==2.2.3 Werkzeug==3.0.4 wheel==0.44.0 wrapt==1.16.0 xgboost==2.1.1

Expected behavior

Entries where y_true is 0 should be ignored in the summation, since the KL divergence uses the convention 0 * log(0) = 0 (see Actual behavior).

Actual behavior

In https://github.com/TheAlgorithms/Python/blob/03a42510b01c574292ca9c6525cbf0572ff5a2a5/machine_learning/loss_functions.py#L662-L663, if any entry of y_true is 0, np.log produces -inf for that entry; the product 0 * -inf then evaluates to nan, so the summation (and therefore the method's return value) becomes nan. Maybe it would be better to exclude those entries where y_true is 0?
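For concreteness, here is a minimal reproduction of the failure mode (the linked lines compute the elementwise product y_true * np.log(y_true / y_pred) and sum it; the input values below are made up):

```python
import numpy as np

y_true = np.array([0.0, 0.3, 0.7])
y_pred = np.array([0.1, 0.3, 0.6])

# At the first entry, log(0 / 0.1) = -inf and 0 * -inf evaluates to nan,
# so the whole sum is nan (NumPy also emits RuntimeWarnings along the way).
kl_loss = y_true * np.log(y_true / y_pred)
print(np.sum(kl_loss))  # nan
```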

bz-e commented 3 days ago

I would like to work on this.

vedprakash226 commented 2 days ago

If y_true is 0, then what should we return?

kevin1kevin1k commented 2 days ago

> I would like to work on this.

I left comments on your PR. Basically, we can only use the entries where y_true is nonzero, without adding an epsilon.

> If y_true is 0, then what should we return?

y_true is an array rather than a single number here, so we can still use the remaining entries.
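A sketch of that masking approach, for illustration (the name and signature mirror the kl_divergence function under discussion, but this is an assumption-laden illustration, not the merged fix):

```python
import numpy as np

def kl_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    # Entries with y_true == 0 contribute 0 by the convention 0 * log(0) = 0,
    # so drop them entirely instead of adding an epsilon.
    mask = y_true != 0
    return float(np.sum(y_true[mask] * np.log(y_true[mask] / y_pred[mask])))
```

For example, with y_true = [0.0, 0.3, 0.7] and y_pred = [0.1, 0.3, 0.6] this returns a finite value (roughly 0.108) instead of nan.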

brambhattabhishek commented 2 days ago

1. Ensure that y_pred doesn't have zero values, to avoid division by zero
2. Clip y_pred to a small positive value, to avoid log(0)
3. Calculate the KL divergence only for the non-zero y_true entries

```python
import numpy as np

def kl_divergence(y_true, y_pred):
    # Clip y_pred to a small positive value to avoid division by zero and log(0)
    y_pred = np.clip(y_pred, 1e-10, None)
    # Compute the contribution only where y_true is nonzero; elsewhere use 0
    kl_loss = np.where(y_true != 0, y_true * np.log(y_true / y_pred), 0)
    return np.sum(kl_loss)
```
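One behavioural note on this snippet: np.where evaluates both branches eagerly, so y_true * np.log(y_true / y_pred) is still computed at the zero entries, where it yields 0 * -inf = nan with a RuntimeWarning before np.where discards it in favour of the literal 0. The clip on y_pred does not prevent this, because the log(0) comes from y_true itself. The final sum is still finite. A small made-up example:

```python
import numpy as np

y_true = np.array([0.0, 1.0])
y_pred = np.array([0.5, 0.5])

# The first branch is evaluated for every entry, so 0 * log(0 / 0.5) = nan
# is produced (with a RuntimeWarning) and then replaced by the literal 0.
out = np.where(y_true != 0, y_true * np.log(y_true / y_pred), 0)
print(out)        # [0.         0.69314718]
print(out.sum())  # 0.6931471805599453
```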
brambhattabhishek commented 1 day ago

> > I would like to work on this.
>
> I left comments on your PR. Basically, we can only use the entries where y_true is nonzero, without adding an epsilon.
>
> > If y_true is 0, then what should we return?
>
> y_true is an array rather than a single number here, so we can still use the remaining entries.

Is my solution correct?

brambhattabhishek commented 1 day ago

/assign