[Open] Issue opened by Joshuaalbert 4 months ago
6 out of 8 tests fail
========================= 6 failed, 2 passed in 1.72s ==========================
FAILED [ 12%]
debug/error.py:8 (test_truncated_normal[10-0.00-0.01])
Array([ -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,
0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,
0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,
0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,
0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,
0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,
0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,
0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,
0.997571 , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,
0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,
1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,
1.0013971 , 1.0016533 , 1.0019106 , 1.002169 , 1.002429 ,
1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758 ,
1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157 ,
1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,
1.0069853 , 1.0073122 , 1.007647 , 1.0079908 , 1.0083443 ,
1.0087085 , 1.0090846 , 1.009474 , 1.0098784 , 1.0102996 ,
1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,
1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,
1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 , inf], dtype=float32) != 0.0
<Click to see difference>
low = 0.0, high = 10, scale = 0.01
@pytest.mark.parametrize("scale", [0.01, 0.1])
@pytest.mark.parametrize("low", [0.0, 0.])
@pytest.mark.parametrize("high", [10, jnp.inf])
def test_truncated_normal(low, high, scale):
dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
u = jnp.linspace(0., 1., 100)
samples = jax.vmap(dist.quantile)(u)
> assert jnp.all(samples >= low)
E assert Array(False, dtype=bool)
E + where Array(False, dtype=bool) = <function all at 0x7f09ad5271a0>(Array([ -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,\n 0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,\n 0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,\n 0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,\n 0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,\n 0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,\n 0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,\n 0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,\n 0.997571 , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,\n 0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,\n 1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,\n 1.0013971 , 1.0016533 , 1.0019106 , 1.002169 , 1.002429 ,\n 1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758 ,\n 1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157 ,\n 1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,\n 1.0069853 , 1.0073122 , 1.007647 , 1.0079908 , 1.0083443 ,\n 1.0087085 , 1.0090846 , 1.009474 , 1.0098784 , 1.0102996 ,\n 1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,\n 1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,\n 1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 , inf], dtype=float32) >= 0.0)
E + where <function all at 0x7f09ad5271a0> = jnp.all
error.py:17: AssertionError
FAILED [ 25%]
debug/error.py:8 (test_truncated_normal[10-0.00-0.1])
Array([0. , 0.7677425 , 0.79504055, 0.8123641 , 0.8253983 ,
0.83600235, 0.8450294 , 0.8529455 , 0.86003435, 0.8664822 ,
0.8724183 , 0.87793595, 0.8831051 , 0.8879795 , 0.89260125,
0.89700437, 0.90121627, 0.90525985, 0.90915424, 0.91291547,
0.9165573 , 0.92009175, 0.923529 , 0.9268783 , 0.9301474 ,
0.9333436 , 0.936473 , 0.93954146, 0.942554 , 0.9455153 ,
0.9484295 , 0.9513006 , 0.9541321 , 0.9569273 , 0.9596892 ,
0.9624207 , 0.9651244 , 0.9678029 , 0.9704585 , 0.9730934 ,
0.97570974, 0.9783096 , 0.9808948 , 0.98346734, 0.98602897,
0.9885815 , 0.99112654, 0.9936659 , 0.99620116, 0.998734 ,
1.001266 , 1.0037988 , 1.0063341 , 1.0088735 , 1.0114186 ,
1.0139711 , 1.0165327 , 1.0191052 , 1.0216905 , 1.0242903 ,
1.0269066 , 1.0295415 , 1.0321971 , 1.0348755 , 1.0375793 ,
1.0403109 , 1.0430727 , 1.0458679 , 1.0486994 , 1.0515704 ,
1.0544847 , 1.057446 , 1.0604585 , 1.063527 , 1.0666565 ,
1.0698526 , 1.0731218 , 1.076471 , 1.0799083 , 1.0834427 ,
1.0870845 , 1.0908458 , 1.0947402 , 1.0987837 , 1.1029956 ,
1.1073987 , 1.1120205 , 1.116895 , 1.122064 , 1.1275817 ,
1.1335177 , 1.1399657 , 1.1470546 , 1.1549706 , 1.1639977 ,
1.1746017 , 1.1876359 , 1.2049594 , 1.2322574 , inf], dtype=float32) != 10
<Click to see difference>
low = 0.0, high = 10, scale = 0.1
@pytest.mark.parametrize("scale", [0.01, 0.1])
@pytest.mark.parametrize("low", [0.0, 0.])
@pytest.mark.parametrize("high", [10, jnp.inf])
def test_truncated_normal(low, high, scale):
dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
u = jnp.linspace(0., 1., 100)
samples = jax.vmap(dist.quantile)(u)
assert jnp.all(samples >= low)
> assert jnp.all(samples <= high)
E assert Array(False, dtype=bool)
E + where Array(False, dtype=bool) = <function all at 0x7f09ad5271a0>(Array([0. , 0.7677425 , 0.79504055, 0.8123641 , 0.8253983 ,\n 0.83600235, 0.8450294 , 0.8529455 , 0.86003435, 0.8664822 ,\n 0.8724183 , 0.87793595, 0.8831051 , 0.8879795 , 0.89260125,\n 0.89700437, 0.90121627, 0.90525985, 0.90915424, 0.91291547,\n 0.9165573 , 0.92009175, 0.923529 , 0.9268783 , 0.9301474 ,\n 0.9333436 , 0.936473 , 0.93954146, 0.942554 , 0.9455153 ,\n 0.9484295 , 0.9513006 , 0.9541321 , 0.9569273 , 0.9596892 ,\n 0.9624207 , 0.9651244 , 0.9678029 , 0.9704585 , 0.9730934 ,\n 0.97570974, 0.9783096 , 0.9808948 , 0.98346734, 0.98602897,\n 0.9885815 , 0.99112654, 0.9936659 , 0.99620116, 0.998734 ,\n 1.001266 , 1.0037988 , 1.0063341 , 1.0088735 , 1.0114186 ,\n 1.0139711 , 1.0165327 , 1.0191052 , 1.0216905 , 1.0242903 ,\n 1.0269066 , 1.0295415 , 1.0321971 , 1.0348755 , 1.0375793 ,\n 1.0403109 , 1.0430727 , 1.0458679 , 1.0486994 , 1.0515704 ,\n 1.0544847 , 1.057446 , 1.0604585 , 1.063527 , 1.0666565 ,\n 1.0698526 , 1.0731218 , 1.076471 , 1.0799083 , 1.0834427 ,\n 1.0870845 , 1.0908458 , 1.0947402 , 1.0987837 , 1.1029956 ,\n 1.1073987 , 1.1120205 , 1.116895 , 1.122064 , 1.1275817 ,\n 1.1335177 , 1.1399657 , 1.1470546 , 1.1549706 , 1.1639977 ,\n 1.1746017 , 1.1876359 , 1.2049594 , 1.2322574 , inf], dtype=float32) <= 10)
E + where <function all at 0x7f09ad5271a0> = jnp.all
error.py:18: AssertionError
FAILED [ 37%]
debug/error.py:8 (test_truncated_normal[10-0.01-0.01])
Array([ -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,
0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,
0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,
0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,
0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,
0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,
0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,
0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,
0.997571 , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,
0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,
1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,
1.0013971 , 1.0016533 , 1.0019106 , 1.002169 , 1.002429 ,
1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758 ,
1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157 ,
1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,
1.0069853 , 1.0073122 , 1.007647 , 1.0079908 , 1.0083443 ,
1.0087085 , 1.0090846 , 1.009474 , 1.0098784 , 1.0102996 ,
1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,
1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,
1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 , inf], dtype=float32) != 0.0
<Click to see difference>
low = 0.0, high = 10, scale = 0.01
@pytest.mark.parametrize("scale", [0.01, 0.1])
@pytest.mark.parametrize("low", [0.0, 0.])
@pytest.mark.parametrize("high", [10, jnp.inf])
def test_truncated_normal(low, high, scale):
dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
u = jnp.linspace(0., 1., 100)
samples = jax.vmap(dist.quantile)(u)
> assert jnp.all(samples >= low)
E assert Array(False, dtype=bool)
E + where Array(False, dtype=bool) = <function all at 0x7f09ad5271a0>(Array([ -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,\n 0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,\n 0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,\n 0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,\n 0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,\n 0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,\n 0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,\n 0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,\n 0.997571 , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,\n 0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,\n 1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,\n 1.0013971 , 1.0016533 , 1.0019106 , 1.002169 , 1.002429 ,\n 1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758 ,\n 1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157 ,\n 1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,\n 1.0069853 , 1.0073122 , 1.007647 , 1.0079908 , 1.0083443 ,\n 1.0087085 , 1.0090846 , 1.009474 , 1.0098784 , 1.0102996 ,\n 1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,\n 1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,\n 1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 , inf], dtype=float32) >= 0.0)
E + where <function all at 0x7f09ad5271a0> = jnp.all
error.py:17: AssertionError
FAILED [ 50%]
debug/error.py:8 (test_truncated_normal[10-0.01-0.1])
Array([0. , 0.7677425 , 0.79504055, 0.8123641 , 0.8253983 ,
0.83600235, 0.8450294 , 0.8529455 , 0.86003435, 0.8664822 ,
0.8724183 , 0.87793595, 0.8831051 , 0.8879795 , 0.89260125,
0.89700437, 0.90121627, 0.90525985, 0.90915424, 0.91291547,
0.9165573 , 0.92009175, 0.923529 , 0.9268783 , 0.9301474 ,
0.9333436 , 0.936473 , 0.93954146, 0.942554 , 0.9455153 ,
0.9484295 , 0.9513006 , 0.9541321 , 0.9569273 , 0.9596892 ,
0.9624207 , 0.9651244 , 0.9678029 , 0.9704585 , 0.9730934 ,
0.97570974, 0.9783096 , 0.9808948 , 0.98346734, 0.98602897,
0.9885815 , 0.99112654, 0.9936659 , 0.99620116, 0.998734 ,
1.001266 , 1.0037988 , 1.0063341 , 1.0088735 , 1.0114186 ,
1.0139711 , 1.0165327 , 1.0191052 , 1.0216905 , 1.0242903 ,
1.0269066 , 1.0295415 , 1.0321971 , 1.0348755 , 1.0375793 ,
1.0403109 , 1.0430727 , 1.0458679 , 1.0486994 , 1.0515704 ,
1.0544847 , 1.057446 , 1.0604585 , 1.063527 , 1.0666565 ,
1.0698526 , 1.0731218 , 1.076471 , 1.0799083 , 1.0834427 ,
1.0870845 , 1.0908458 , 1.0947402 , 1.0987837 , 1.1029956 ,
1.1073987 , 1.1120205 , 1.116895 , 1.122064 , 1.1275817 ,
1.1335177 , 1.1399657 , 1.1470546 , 1.1549706 , 1.1639977 ,
1.1746017 , 1.1876359 , 1.2049594 , 1.2322574 , inf], dtype=float32) != 10
<Click to see difference>
low = 0.0, high = 10, scale = 0.1
@pytest.mark.parametrize("scale", [0.01, 0.1])
@pytest.mark.parametrize("low", [0.0, 0.])
@pytest.mark.parametrize("high", [10, jnp.inf])
def test_truncated_normal(low, high, scale):
dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
u = jnp.linspace(0., 1., 100)
samples = jax.vmap(dist.quantile)(u)
assert jnp.all(samples >= low)
> assert jnp.all(samples <= high)
E assert Array(False, dtype=bool)
E + where Array(False, dtype=bool) = <function all at 0x7f09ad5271a0>(Array([0. , 0.7677425 , 0.79504055, 0.8123641 , 0.8253983 ,\n 0.83600235, 0.8450294 , 0.8529455 , 0.86003435, 0.8664822 ,\n 0.8724183 , 0.87793595, 0.8831051 , 0.8879795 , 0.89260125,\n 0.89700437, 0.90121627, 0.90525985, 0.90915424, 0.91291547,\n 0.9165573 , 0.92009175, 0.923529 , 0.9268783 , 0.9301474 ,\n 0.9333436 , 0.936473 , 0.93954146, 0.942554 , 0.9455153 ,\n 0.9484295 , 0.9513006 , 0.9541321 , 0.9569273 , 0.9596892 ,\n 0.9624207 , 0.9651244 , 0.9678029 , 0.9704585 , 0.9730934 ,\n 0.97570974, 0.9783096 , 0.9808948 , 0.98346734, 0.98602897,\n 0.9885815 , 0.99112654, 0.9936659 , 0.99620116, 0.998734 ,\n 1.001266 , 1.0037988 , 1.0063341 , 1.0088735 , 1.0114186 ,\n 1.0139711 , 1.0165327 , 1.0191052 , 1.0216905 , 1.0242903 ,\n 1.0269066 , 1.0295415 , 1.0321971 , 1.0348755 , 1.0375793 ,\n 1.0403109 , 1.0430727 , 1.0458679 , 1.0486994 , 1.0515704 ,\n 1.0544847 , 1.057446 , 1.0604585 , 1.063527 , 1.0666565 ,\n 1.0698526 , 1.0731218 , 1.076471 , 1.0799083 , 1.0834427 ,\n 1.0870845 , 1.0908458 , 1.0947402 , 1.0987837 , 1.1029956 ,\n 1.1073987 , 1.1120205 , 1.116895 , 1.122064 , 1.1275817 ,\n 1.1335177 , 1.1399657 , 1.1470546 , 1.1549706 , 1.1639977 ,\n 1.1746017 , 1.1876359 , 1.2049594 , 1.2322574 , inf], dtype=float32) <= 10)
E + where <function all at 0x7f09ad5271a0> = jnp.all
error.py:18: AssertionError
FAILED [ 62%]
debug/error.py:8 (test_truncated_normal[inf-0.00-0.01])
Array([ -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,
0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,
0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,
0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,
0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,
0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,
0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,
0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,
0.997571 , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,
0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,
1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,
1.0013971 , 1.0016533 , 1.0019106 , 1.002169 , 1.002429 ,
1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758 ,
1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157 ,
1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,
1.0069853 , 1.0073122 , 1.007647 , 1.0079908 , 1.0083443 ,
1.0087085 , 1.0090846 , 1.009474 , 1.0098784 , 1.0102996 ,
1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,
1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,
1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 , inf], dtype=float32) != 0.0
<Click to see difference>
low = 0.0, high = inf, scale = 0.01
@pytest.mark.parametrize("scale", [0.01, 0.1])
@pytest.mark.parametrize("low", [0.0, 0.])
@pytest.mark.parametrize("high", [10, jnp.inf])
def test_truncated_normal(low, high, scale):
dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
u = jnp.linspace(0., 1., 100)
samples = jax.vmap(dist.quantile)(u)
> assert jnp.all(samples >= low)
E assert Array(False, dtype=bool)
E + where Array(False, dtype=bool) = <function all at 0x7f09ad5271a0>(Array([ -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,\n 0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,\n 0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,\n 0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,\n 0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,\n 0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,\n 0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,\n 0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,\n 0.997571 , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,\n 0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,\n 1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,\n 1.0013971 , 1.0016533 , 1.0019106 , 1.002169 , 1.002429 ,\n 1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758 ,\n 1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157 ,\n 1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,\n 1.0069853 , 1.0073122 , 1.007647 , 1.0079908 , 1.0083443 ,\n 1.0087085 , 1.0090846 , 1.009474 , 1.0098784 , 1.0102996 ,\n 1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,\n 1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,\n 1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 , inf], dtype=float32) >= 0.0)
E + where <function all at 0x7f09ad5271a0> = jnp.all
error.py:17: AssertionError
PASSED [ 75%]FAILED [ 87%]
debug/error.py:8 (test_truncated_normal[inf-0.01-0.01])
Array([ -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,
0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,
0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,
0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,
0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,
0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,
0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,
0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,
0.997571 , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,
0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,
1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,
1.0013971 , 1.0016533 , 1.0019106 , 1.002169 , 1.002429 ,
1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758 ,
1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157 ,
1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,
1.0069853 , 1.0073122 , 1.007647 , 1.0079908 , 1.0083443 ,
1.0087085 , 1.0090846 , 1.009474 , 1.0098784 , 1.0102996 ,
1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,
1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,
1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 , inf], dtype=float32) != 0.0
<Click to see difference>
low = 0.0, high = inf, scale = 0.01
@pytest.mark.parametrize("scale", [0.01, 0.1])
@pytest.mark.parametrize("low", [0.0, 0.])
@pytest.mark.parametrize("high", [10, jnp.inf])
def test_truncated_normal(low, high, scale):
dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
u = jnp.linspace(0., 1., 100)
samples = jax.vmap(dist.quantile)(u)
> assert jnp.all(samples >= low)
E assert Array(False, dtype=bool)
E + where Array(False, dtype=bool) = <function all at 0x7f09ad5271a0>(Array([ -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,\n 0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,\n 0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,\n 0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,\n 0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,\n 0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,\n 0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,\n 0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,\n 0.997571 , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,\n 0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,\n 1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,\n 1.0013971 , 1.0016533 , 1.0019106 , 1.002169 , 1.002429 ,\n 1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758 ,\n 1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157 ,\n 1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,\n 1.0069853 , 1.0073122 , 1.007647 , 1.0079908 , 1.0083443 ,\n 1.0087085 , 1.0090846 , 1.009474 , 1.0098784 , 1.0102996 ,\n 1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,\n 1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,\n 1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 , inf], dtype=float32) >= 0.0)
E + where <function all at 0x7f09ad5271a0> = jnp.all
error.py:17: AssertionError
PASSED [100%]
Hey! Thanks for opening this issue — it looks like the problem is at the boundaries of the quantile (u = 0 and u = 1), as we might expect. Here is a comparison against scipy's truncnorm:
import numpy.testing as npt
import scipy.stats as st
low = 0.0
u = jnp.linspace(0., 1., 100)
for scale in [0.01, 0.1]:
for high in [10, jnp.inf]:
rv = st.truncnorm((low - 1.) / scale, (high - 1.) / scale, 1.0, scale)
dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
print(scale, low, high)
print(dist.quantile(jnp.array([0, 1.])), rv.ppf(jnp.array([0, 1.])))
npt.assert_allclose(dist.quantile(u[1:-1]), rv.ppf(u[1:-1]), atol=1e-7)
Outputs
0.01 0.0 10
[-inf inf] [ 0. 10.]
0.01 0.0 inf
[-inf inf] [ 0. inf]
0.1 0.0 10
[ 0. inf] [ 0. 10.]
0.1 0.0 inf
[ 0. inf] [ 0. inf]
What's interesting is that if you go to log space, the argument to ndtri(...) in the quantile is finite at both ends; it is just very close to the point where the result becomes infinite. I think following up with a few steps of bisection would solve this, because ndtr is more numerically stable than ndtri. Does that make sense? WDYT?
Or, thinking about this again, perhaps the best approach would be to clip the output of the quantile to [low, high], and then define a safe custom gradient rule.
TruncatedNormal.quantile sometimes returns values outside [low, high] (e.g. -inf at u = 0 or inf at u = 1). This seems to happen when the scale is relatively small, but also in surprising situations where you'd expect it to work, such as TruncatedNormal(1, 0.1, 0, 10).
MVCE (minimal, complete, verifiable example):