Zhanwei-Liu closed this pull request 4 weeks ago.
Hey, thanks for the PR. Can you explain the context a little more? In what situation does `var` return `inf`?
Certainly! The variance function, `np.var`, calculates the average of the squared differences from the mean. With extremely large values, these squared differences can exceed the largest finite `float64` value (about 1.8e308), so the computation overflows and the result is infinite (`inf`). For example:
```python
import numpy as np

# Example with large values that cause overflow
data = np.array([0, 1e308, -1e308])

# Calculate variance
variance = np.var(data)
print(variance)  # inf, with "RuntimeWarning: overflow encountered in multiply"
```
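To see where the overflow happens, the sketch below (an illustration, not code from this thread) repeats the steps of `np.var` by hand on the same data. The deviations from the mean are still finite; it is the squaring step that overflows, since (1e308)² = 1e616 is far above `np.finfo(np.float64).max`. Warnings are silenced with `np.errstate` only to keep the output readable.

```python
import numpy as np

data = np.array([0.0, 1e308, -1e308])

# Largest finite float64 value (about 1.8e308).
FLOAT64_MAX = np.finfo(np.float64).max

# Silence the overflow warning for this illustration only.
with np.errstate(over="ignore"):
    deviations = data - data.mean()  # [0, 1e308, -1e308]: all finite
    squared = deviations ** 2        # 1e308 ** 2 = 1e616 > FLOAT64_MAX, so inf
    manual_var = squared.mean()      # mean of [0, inf, inf] is inf
    numpy_var = np.var(data)         # np.var overflows in the same squaring step

print("float64 max:", FLOAT64_MAX)
print("deviations finite:", np.isfinite(deviations).all())
print("manual variance:", manual_var, "np.var:", numpy_var)
```

In other words, `inf` is not a bug in the formula: the true variance of this data (about 6.7e615) simply cannot be represented as a `float64`.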
I am encountering this problem, and I think the program should emit a `RuntimeWarning` instead of stopping.
Thanks for the clarification. Such high values mean that something isn't working as expected at some point, right? I can't imagine a scenario in which you would absolutely need such values.
That said, `explained_variance` supports this and currently returns `nan` after a warning, which I find useful.
```
>>> import numpy as np
>>> from stable_baselines3.common.utils import explained_variance
>>> explained_variance(np.ones(3), np.array([0, 1e308, -1e308]))
/Users/quentingallouedec/stable-baselines3/env/lib/python3.10/site-packages/numpy/core/_methods.py:176: RuntimeWarning: overflow encountered in multiply
  x = um.multiply(x, x, out=x)
/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/utils.py:65: RuntimeWarning: invalid value encountered in scalar divide
  return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
nan
```
Implementing the suggested change would only remove that warning, and I think the warning is justified. wdyt @Zhanwei-Liu?
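The `nan` in the session above follows from the expression on the last source line of the traceback, 1 - Var[y_true - y_pred] / Var[y_true]: both variances overflow to `inf`, and `inf / inf` is `nan`. The sketch below mirrors only that line; the name `explained_variance_sketch` is hypothetical, and the `(y_pred, y_true)` argument order is an assumption consistent with the warnings shown (if `np.ones(3)` were `y_true`, its variance would be 0 and the function would return `nan` without any division).

```python
import numpy as np


def explained_variance_sketch(y_pred: np.ndarray, y_true: np.ndarray) -> float:
    """Hypothetical mirror of the return line shown in the traceback:
    1 - Var[y_true - y_pred] / Var[y_true], or nan when Var[y_true] == 0."""
    var_y = np.var(y_true)
    return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y


y_true = np.array([0.0, 1e308, -1e308])

# Silence the warnings so only the values are printed.
with np.errstate(over="ignore", invalid="ignore"):
    print(explained_variance_sketch(np.ones(3), y_true))      # nan, from inf / inf
    print(explained_variance_sketch(np.ones(3), np.ones(3)))  # nan, because Var[y_true] == 0
```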
Thank you for your explanation. I have identified the cause of the `FloatingPointError` in my program: it comes from this line at the beginning of my script:

```python
np.seterr(invalid='raise')
```
I mistakenly believed this error was related to `stable_baselines3`. The code snippet below reproduces the error:
```python
import numpy as np
from stable_baselines3.common.utils import explained_variance

np.seterr(invalid='raise')
# Raises FloatingPointError instead of warning and returning nan
explained_variance(np.ones(3), np.array([0, 1e308, -1e308]))
```
Therefore, I believe `stable_baselines3` is functioning correctly and does not require modification. You can close this pull request. Thank you very much.
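The two outcomes in this thread (`nan` with a warning versus a `FloatingPointError`) come from NumPy's floating-point error policy. By default, an invalid operation such as `inf / inf` only emits a `RuntimeWarning` and yields `nan`; `np.seterr(invalid='raise')` turns it into an exception for the rest of the program. The sketch below (the helper `residual_ratio` is hypothetical) runs the same division under both policies, using the `np.errstate` context manager so the setting does not leak globally.

```python
import numpy as np

y_true = np.array([0.0, 1e308, -1e308])
y_pred = np.ones(3)


def residual_ratio(invalid: str) -> float:
    """Var[y_true - y_pred] / Var[y_true] under a scoped policy for invalid ops."""
    # np.errstate restores the previous settings on exit, whereas np.seterr
    # keeps the new settings until they are changed again.
    # Overflow inside np.var is ignored, so only the inf / inf division,
    # an "invalid" operation, is governed by the `invalid` argument.
    with np.errstate(over="ignore", invalid=invalid):
        return np.var(y_true - y_pred) / np.var(y_true)


print(residual_ratio("ignore"))  # nan

try:
    residual_ratio("raise")
except FloatingPointError as err:
    print("FloatingPointError:", err)
```

Scoping the policy this way keeps strict checks on the code you want to debug without changing how third-party code such as `explained_variance` behaves elsewhere.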
Description
I updated the `explained_variance` function to handle cases where the variance of `y_true` is infinite. The function now returns `np.nan` if the variance of `y_true` is either 0 or infinite, which avoids dividing by an infinite variance (for example `inf / inf`, which yields `nan` with a `RuntimeWarning`).
Motivation and Context
This change is necessary to address an issue where the variance of `y_true` could be infinite, leading to an invalid calculation of the explained variance. By returning `np.nan` in such cases, we ensure the function behaves predictably and avoids errors related to infinite values.
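For reference, a minimal sketch of the change described above; `explained_variance_guarded` is a hypothetical name and this is not the code from this pull request, which was closed. It adds an `np.isinf` check next to the existing zero check.

```python
import numpy as np


def explained_variance_guarded(y_pred: np.ndarray, y_true: np.ndarray) -> float:
    """Hypothetical version of the proposed change: nan when Var[y_true] is 0 or inf,
    otherwise 1 - Var[y_true - y_pred] / Var[y_true]."""
    var_y = np.var(y_true)  # may still overflow (and warn) for huge inputs
    if var_y == 0 or np.isinf(var_y):
        return np.nan
    return 1 - np.var(y_true - y_pred) / var_y


y_true = np.array([0.0, 1e308, -1e308])

with np.errstate(over="ignore"):  # hide the overflow warning from np.var(y_true)
    print(explained_variance_guarded(np.ones(3), y_true))  # nan, no division performed
    print(explained_variance_guarded(y_true, y_true))      # nan (unguarded: 1 - 0 / inf = 1.0)
```

Two consequences are worth noting. The overflow warning raised inside `np.var(y_true)` is unaffected; only the "invalid value encountered in scalar divide" warning disappears. The guard also changes one case that previously produced a finite value: when `Var[y_true]` overflows but the residual variance is finite (for instance a perfect prediction), the unguarded expression gives 1 - 0 / inf = 1.0, while the guarded version returns `nan`.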
Types of changes
Checklist
- `make format` (required)
- `make check-codestyle` and `make lint` (required)
- `make pytest` and `make type` both pass. (required)
- `make doc` (required)

Note: You can run most of the checks using `make commit-checks`.

Note: we are using a maximum length of 127 characters per line.