DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License
8.35k stars · 1.6k forks

Fix variance issue where ypred explains y as infinity #1942

Closed Zhanwei-Liu closed 4 weeks ago

Zhanwei-Liu commented 1 month ago

Description

I updated the explained_variance function to handle the case where the variance of y_true is infinite. The function now returns np.nan if the variance of y_true is either 0 or infinity, avoiding the invalid inf/inf division that otherwise occurs.

Motivation and Context

This change is necessary to address an issue where the variance of y_true could be infinity, leading to an invalid calculation of the explained variance. By returning np.nan in such cases, we ensure the function behaves predictably and avoids errors related to infinite values.
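A minimal sketch of the change described above. It mirrors the signature and the existing zero-variance guard visible in stable_baselines3/common/utils.py later in this thread; the np.isinf check is the proposed addition, not the current library code:

```python
import numpy as np


def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> float:
    # Fraction of the variance of y_true explained by y_pred.
    # Existing behavior: return NaN when var(y_true) == 0.
    # Proposed addition: also return NaN when var(y_true) overflows to
    # infinity, so the invalid inf/inf division (and its RuntimeWarning)
    # is never evaluated.
    var_y = np.var(y_true)
    if var_y == 0 or np.isinf(var_y):
        return np.nan
    return float(1 - np.var(y_true - y_pred) / var_y)
```

With this guard, the overflow example discussed below returns nan without the invalid-divide warning (the overflow warning emitted by np.var itself would still fire).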


qgallouedec commented 1 month ago

Hey, thanks for the PR. Can you explain the context a little more? In what situation does var return inf?

Zhanwei-Liu commented 1 month ago

Certainly! The variance function, np.var, calculates the average of the squared differences from the mean. When dealing with extremely large values, these squared differences can become so large that they exceed the maximum representable floating-point number in Python, resulting in an infinite value (inf). For example:

import numpy as np

# Example with large values that cause overflow
data = np.array([0, 1e308, -1e308])

# Calculate variance
variance = np.var(data)

print(variance)  # Output will be inf

I am running into this problem myself, but I suppose the function could raise a RuntimeWarning instead of stopping the program.

qgallouedec commented 1 month ago

Thanks for the clarification. Such high values mean that something isn't working as expected at some point, right? I can't imagine a scenario in which you would absolutely need such values.

That said, explained_variance supports this and currently returns nan after a warning, which I find useful.

>>> import numpy as np
>>> from stable_baselines3.common.utils import explained_variance
>>> explained_variance(np.ones(3), np.array([0, 1e308, -1e308]))
/Users/quentingallouedec/stable-baselines3/env/lib/python3.10/site-packages/numpy/core/_methods.py:176: RuntimeWarning: overflow encountered in multiply
  x = um.multiply(x, x, out=x)
/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/utils.py:65: RuntimeWarning: invalid value encountered in scalar divide
  return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
nan

Implementing the suggested change would just suppress that warning, and I think the warning is justified. wdyt @Zhanwei-Liu?

Zhanwei-Liu commented 4 weeks ago

Thank you for your explanation. I have identified the cause of the FloatingPointError in my program: I had added the following line at the beginning of my script:

np.seterr(invalid='raise')

This setting promotes invalid floating-point operations to exceptions process-wide, and I mistakenly believed the resulting error came from stable_baselines3. The code snippet below reproduces it:

import numpy as np
from stable_baselines3.common.utils import explained_variance

np.seterr(invalid='raise')
explained_variance(np.ones(3), np.array([0, 1e308, -1e308]))

Therefore, I believe stable_baselines3 is functioning correctly and does not require modification. You can close this pull request. Thank you very much.
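As a side note for anyone who hits the same trap: np.seterr changes NumPy's floating-point error handling for the whole process, so "invalid" operations inside third-party code (like the inf/inf division here) raise FloatingPointError too. A scoped alternative is np.errstate, which restores the previous settings on exit. A small illustration (not SB3 code):

```python
import numpy as np

# np.seterr applies globally: after this call, any "invalid" floating-point
# operation (such as inf/inf) raises FloatingPointError instead of
# producing NaN with a RuntimeWarning -- including inside libraries.
old = np.seterr(invalid="raise")
try:
    np.float64(np.inf) / np.float64(np.inf)
    raised = False
except FloatingPointError:
    raised = True
finally:
    np.seterr(**old)  # restore the previous settings

# np.errstate limits the change to a single block instead:
with np.errstate(invalid="ignore"):
    result = np.float64(np.inf) / np.float64(np.inf)  # silently NaN

print(raised, np.isnan(result))
```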