tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

TruncatedNormal gives wrong results sometimes #1788

Open Joshuaalbert opened 4 months ago

Joshuaalbert commented 4 months ago

TruncatedNormal sometimes gives wrong quantile values. It seems to happen when the scale is relatively small, but it fails even in situations where you'd expect it to work, like TruncatedNormal(1, 0.1, 0, 10).

MVCE

import jax
import jax.numpy as jnp
import pytest
import tensorflow_probability.substrates.jax as tfp

tfpd = tfp.distributions

@pytest.mark.parametrize("scale", [0.01, 0.1])
@pytest.mark.parametrize("low", [0.0, 0.])
@pytest.mark.parametrize("high", [10, jnp.inf])
def test_truncated_normal(low, high, scale):
    dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
    u = jnp.linspace(0., 1., 100)

    samples = jax.vmap(dist.quantile)(u)
    assert jnp.all(samples >= low)
    assert jnp.all(samples <= high)
Joshuaalbert commented 4 months ago

6 out of 8 tests fail

========================= 6 failed, 2 passed in 1.72s ==========================
FAILED                     [ 12%]
debug/error.py:8 (test_truncated_normal[10-0.00-0.01])
Array([      -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,
       0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,
       0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,
       0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,
       0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,
       0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,
       0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,
       0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,
       0.997571  , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,
       0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,
       1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,
       1.0013971 , 1.0016533 , 1.0019106 , 1.002169  , 1.002429  ,
       1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758  ,
       1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157  ,
       1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,
       1.0069853 , 1.0073122 , 1.007647  , 1.0079908 , 1.0083443 ,
       1.0087085 , 1.0090846 , 1.009474  , 1.0098784 , 1.0102996 ,
       1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,
       1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,
       1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 ,        inf],      dtype=float32) != 0.0


low = 0.0, high = 10, scale = 0.01

    @pytest.mark.parametrize("scale", [0.01, 0.1])
    @pytest.mark.parametrize("low", [0.0, 0.])
    @pytest.mark.parametrize("high", [10, jnp.inf])
    def test_truncated_normal(low, high, scale):
        dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
        u = jnp.linspace(0., 1., 100)

        samples = jax.vmap(dist.quantile)(u)
>       assert jnp.all(samples >= low)
E       assert Array(False, dtype=bool)
E        +  where Array(False, dtype=bool) = <function all at 0x7f09ad5271a0>(<the array above> >= 0.0)
E        +    where <function all at 0x7f09ad5271a0> = jnp.all

error.py:17: AssertionError
FAILED                      [ 25%]
debug/error.py:8 (test_truncated_normal[10-0.00-0.1])
Array([0.        , 0.7677425 , 0.79504055, 0.8123641 , 0.8253983 ,
       0.83600235, 0.8450294 , 0.8529455 , 0.86003435, 0.8664822 ,
       0.8724183 , 0.87793595, 0.8831051 , 0.8879795 , 0.89260125,
       0.89700437, 0.90121627, 0.90525985, 0.90915424, 0.91291547,
       0.9165573 , 0.92009175, 0.923529  , 0.9268783 , 0.9301474 ,
       0.9333436 , 0.936473  , 0.93954146, 0.942554  , 0.9455153 ,
       0.9484295 , 0.9513006 , 0.9541321 , 0.9569273 , 0.9596892 ,
       0.9624207 , 0.9651244 , 0.9678029 , 0.9704585 , 0.9730934 ,
       0.97570974, 0.9783096 , 0.9808948 , 0.98346734, 0.98602897,
       0.9885815 , 0.99112654, 0.9936659 , 0.99620116, 0.998734  ,
       1.001266  , 1.0037988 , 1.0063341 , 1.0088735 , 1.0114186 ,
       1.0139711 , 1.0165327 , 1.0191052 , 1.0216905 , 1.0242903 ,
       1.0269066 , 1.0295415 , 1.0321971 , 1.0348755 , 1.0375793 ,
       1.0403109 , 1.0430727 , 1.0458679 , 1.0486994 , 1.0515704 ,
       1.0544847 , 1.057446  , 1.0604585 , 1.063527  , 1.0666565 ,
       1.0698526 , 1.0731218 , 1.076471  , 1.0799083 , 1.0834427 ,
       1.0870845 , 1.0908458 , 1.0947402 , 1.0987837 , 1.1029956 ,
       1.1073987 , 1.1120205 , 1.116895  , 1.122064  , 1.1275817 ,
       1.1335177 , 1.1399657 , 1.1470546 , 1.1549706 , 1.1639977 ,
       1.1746017 , 1.1876359 , 1.2049594 , 1.2322574 ,        inf],      dtype=float32) != 10


low = 0.0, high = 10, scale = 0.1

    @pytest.mark.parametrize("scale", [0.01, 0.1])
    @pytest.mark.parametrize("low", [0.0, 0.])
    @pytest.mark.parametrize("high", [10, jnp.inf])
    def test_truncated_normal(low, high, scale):
        dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
        u = jnp.linspace(0., 1., 100)

        samples = jax.vmap(dist.quantile)(u)
        assert jnp.all(samples >= low)
>       assert jnp.all(samples <= high)
E       assert Array(False, dtype=bool)
E        +  where Array(False, dtype=bool) = <function all at 0x7f09ad5271a0>(<the array above> <= 10)
E        +    where <function all at 0x7f09ad5271a0> = jnp.all

error.py:18: AssertionError
FAILED                     [ 37%]
debug/error.py:8 (test_truncated_normal[10-0.01-0.01])

low = 0.0, high = 10, scale = 0.01

(identical to the [10-0.00-0.01] failure above: samples[0] is -inf, so assert jnp.all(samples >= low) fails)

error.py:17: AssertionError
FAILED                      [ 50%]
debug/error.py:8 (test_truncated_normal[10-0.01-0.1])

low = 0.0, high = 10, scale = 0.1

(identical to the [10-0.00-0.1] failure above: samples[-1] is inf, so assert jnp.all(samples <= high) fails)

error.py:18: AssertionError
FAILED                    [ 62%]
debug/error.py:8 (test_truncated_normal[inf-0.00-0.01])

low = 0.0, high = inf, scale = 0.01

(same array as the [10-0.00-0.01] failure above: samples[0] is -inf, so assert jnp.all(samples >= low) fails)

error.py:17: AssertionError
PASSED                     [ 75%]
FAILED                    [ 87%]
debug/error.py:8 (test_truncated_normal[inf-0.01-0.01])

low = 0.0, high = inf, scale = 0.01

(identical to the [inf-0.00-0.01] failure above: samples[0] is -inf, so assert jnp.all(samples >= low) fails)

error.py:17: AssertionError
PASSED                     [100%]
ColCarroll commented 4 months ago

Hey! Thanks for opening this issue -- it looks like the problem is with the boundaries here, as we might expect:

import jax.numpy as jnp
import numpy.testing as npt
import scipy.stats as st
import tensorflow_probability.substrates.jax as tfp

tfpd = tfp.distributions

low = 0.0
u = jnp.linspace(0., 1., 100)
for scale in [0.01, 0.1]:
  for high in [10, jnp.inf]:
    rv = st.truncnorm((low - 1.) / scale, (high - 1.) / scale, 1.0, scale)
    dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
    print(scale, low, high)
    print(dist.quantile(jnp.array([0, 1.])), rv.ppf(jnp.array([0, 1.])))
    npt.assert_allclose(dist.quantile(u[1:-1]), rv.ppf(u[1:-1]), atol=1e-7)

Outputs

0.01 0.0 10
[-inf  inf] [ 0. 10.]
0.01 0.0 inf
[-inf  inf] [ 0. inf]
0.1 0.0 10
[ 0. inf] [ 0. 10.]
0.1 0.0 inf
[ 0. inf] [ 0. inf]
Joshuaalbert commented 4 months ago

What's interesting is that if you work in log space, the argument to ndtri(...) in the quantile is finite at both ends; it's just extreme enough that ndtri saturates to ±inf. I think following up with a few steps of bisection would solve this, because ndtr is more numerically stable than ndtri. Make sense? WDYT?
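A rough sketch of the bisection idea (not TFP's implementation; the helper name and the fixed 10-sigma search window are my own choices): bracket the quantile inside the truncation interval and bisect on the CDF, which only ever evaluates the stable ndtr:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr


def truncnorm_quantile_bisect(u, loc, scale, low, high, n_steps=30):
    """Quantile of TruncatedNormal(loc, scale, low, high) via bisection on the CDF."""
    # Normalization constants of the truncated CDF, computed once.
    a = ndtr((low - loc) / scale)
    b = ndtr((high - loc) / scale)

    def cdf(x):
        return (ndtr((x - loc) / scale) - a) / (b - a)

    # Bracket the root; clip infinite bounds to a wide finite window.
    lo = jnp.maximum(low, loc - 10.0 * scale)
    hi = jnp.minimum(high, loc + 10.0 * scale)

    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        go_right = cdf(mid) < u
        # Keep whichever half still brackets the target probability u.
        return jnp.where(go_right, mid, lo), jnp.where(go_right, hi, mid)

    lo, hi = jax.lax.fori_loop(0, n_steps, body, (lo, hi))
    return 0.5 * (lo + hi)
```

By construction the result can never leave [low, high]; in practice you'd presumably seed the bracket from the existing ndtri-based quantile rather than a fixed window.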

Joshuaalbert commented 4 months ago

Or, thinking about this again, perhaps the best option would be to clip the output of the quantile to [low, high], and then define a safe custom gradient rule.
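A standalone sketch of that clip-plus-custom-gradient idea (not TFP's code; make_safe_quantile is a hypothetical helper and scalar parameters are assumed): clip the closed-form quantile into the support, and define the derivative via the inverse-function rule dq/du = 1/pdf(q), so clipping never poisons the gradient with nan:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr, ndtri


def make_safe_quantile(loc, scale, low, high):
    # CDF values of the untruncated normal at the truncation bounds.
    a = ndtr((low - loc) / scale)
    b = ndtr((high - loc) / scale)

    @jax.custom_jvp
    def quantile(u):
        # Closed-form quantile, clipped so it can never escape [low, high].
        q = loc + scale * ndtri(a + u * (b - a))
        return jnp.clip(q, low, high)

    @quantile.defjvp
    def quantile_jvp(primals, tangents):
        (u,), (du,) = primals, tangents
        q = quantile(u)
        # Inverse-function rule: dq/du = scale * (b - a) / phi(z),
        # the derivative of the unclipped quantile, evaluated at the
        # (clipped) primal output.
        z = (q - loc) / scale
        phi_over_scale = jnp.exp(-0.5 * z * z) / (scale * jnp.sqrt(2.0 * jnp.pi))
        return q, du * (b - a) / phi_over_scale

    return quantile
```

The gradient still blows up for a u that maps exactly onto a bound (the pdf vanishes there), so the rule would need extra care at the endpoints; but interior quantiles and their gradients stay finite.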