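6 out of 8 tests fail: TruncatedNormal.quantile returns -inf at u = 0 whenever scale = 0.01, and +inf at u = 1 in every case, which falls outside [low, high] whenever high = 10.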
========================= 6 failed, 2 passed in 1.72s ==========================
FAILED [ 12%]
debug/error.py:8 (test_truncated_normal[10-0.00-0.01])
Array([ -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,
0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,
0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,
0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,
0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,
0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,
0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,
0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,
0.997571 , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,
0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,
1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,
1.0013971 , 1.0016533 , 1.0019106 , 1.002169 , 1.002429 ,
1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758 ,
1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157 ,
1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,
1.0069853 , 1.0073122 , 1.007647 , 1.0079908 , 1.0083443 ,
1.0087085 , 1.0090846 , 1.009474 , 1.0098784 , 1.0102996 ,
1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,
1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,
1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 , inf], dtype=float32) != 0.0
low = 0.0, high = 10, scale = 0.01
@pytest.mark.parametrize("scale", [0.01, 0.1])
@pytest.mark.parametrize("low", [0.0, 0.])
@pytest.mark.parametrize("high", [10, jnp.inf])
def test_truncated_normal(low, high, scale):
    dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
    u = jnp.linspace(0., 1., 100)
    samples = jax.vmap(dist.quantile)(u)
>   assert jnp.all(samples >= low)
E assert Array(False, dtype=bool)
E + where Array(False, dtype=bool) = <function all at 0x7f09ad5271a0>(Array([ -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,\n 0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,\n 0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,\n 0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,\n 0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,\n 0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,\n 0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,\n 0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,\n 0.997571 , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,\n 0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,\n 1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,\n 1.0013971 , 1.0016533 , 1.0019106 , 1.002169 , 1.002429 ,\n 1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758 ,\n 1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157 ,\n 1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,\n 1.0069853 , 1.0073122 , 1.007647 , 1.0079908 , 1.0083443 ,\n 1.0087085 , 1.0090846 , 1.009474 , 1.0098784 , 1.0102996 ,\n 1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,\n 1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,\n 1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 , inf], dtype=float32) >= 0.0)
E + where <function all at 0x7f09ad5271a0> = jnp.all
error.py:17: AssertionError
FAILED [ 25%]
debug/error.py:8 (test_truncated_normal[10-0.00-0.1])
Array([0. , 0.7677425 , 0.79504055, 0.8123641 , 0.8253983 ,
0.83600235, 0.8450294 , 0.8529455 , 0.86003435, 0.8664822 ,
0.8724183 , 0.87793595, 0.8831051 , 0.8879795 , 0.89260125,
0.89700437, 0.90121627, 0.90525985, 0.90915424, 0.91291547,
0.9165573 , 0.92009175, 0.923529 , 0.9268783 , 0.9301474 ,
0.9333436 , 0.936473 , 0.93954146, 0.942554 , 0.9455153 ,
0.9484295 , 0.9513006 , 0.9541321 , 0.9569273 , 0.9596892 ,
0.9624207 , 0.9651244 , 0.9678029 , 0.9704585 , 0.9730934 ,
0.97570974, 0.9783096 , 0.9808948 , 0.98346734, 0.98602897,
0.9885815 , 0.99112654, 0.9936659 , 0.99620116, 0.998734 ,
1.001266 , 1.0037988 , 1.0063341 , 1.0088735 , 1.0114186 ,
1.0139711 , 1.0165327 , 1.0191052 , 1.0216905 , 1.0242903 ,
1.0269066 , 1.0295415 , 1.0321971 , 1.0348755 , 1.0375793 ,
1.0403109 , 1.0430727 , 1.0458679 , 1.0486994 , 1.0515704 ,
1.0544847 , 1.057446 , 1.0604585 , 1.063527 , 1.0666565 ,
1.0698526 , 1.0731218 , 1.076471 , 1.0799083 , 1.0834427 ,
1.0870845 , 1.0908458 , 1.0947402 , 1.0987837 , 1.1029956 ,
1.1073987 , 1.1120205 , 1.116895 , 1.122064 , 1.1275817 ,
1.1335177 , 1.1399657 , 1.1470546 , 1.1549706 , 1.1639977 ,
1.1746017 , 1.1876359 , 1.2049594 , 1.2322574 , inf], dtype=float32) != 10
low = 0.0, high = 10, scale = 0.1
@pytest.mark.parametrize("scale", [0.01, 0.1])
@pytest.mark.parametrize("low", [0.0, 0.])
@pytest.mark.parametrize("high", [10, jnp.inf])
def test_truncated_normal(low, high, scale):
    dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
    u = jnp.linspace(0., 1., 100)
    samples = jax.vmap(dist.quantile)(u)
    assert jnp.all(samples >= low)
>   assert jnp.all(samples <= high)
E assert Array(False, dtype=bool)
E + where Array(False, dtype=bool) = <function all at 0x7f09ad5271a0>(Array([0. , 0.7677425 , 0.79504055, 0.8123641 , 0.8253983 ,\n 0.83600235, 0.8450294 , 0.8529455 , 0.86003435, 0.8664822 ,\n 0.8724183 , 0.87793595, 0.8831051 , 0.8879795 , 0.89260125,\n 0.89700437, 0.90121627, 0.90525985, 0.90915424, 0.91291547,\n 0.9165573 , 0.92009175, 0.923529 , 0.9268783 , 0.9301474 ,\n 0.9333436 , 0.936473 , 0.93954146, 0.942554 , 0.9455153 ,\n 0.9484295 , 0.9513006 , 0.9541321 , 0.9569273 , 0.9596892 ,\n 0.9624207 , 0.9651244 , 0.9678029 , 0.9704585 , 0.9730934 ,\n 0.97570974, 0.9783096 , 0.9808948 , 0.98346734, 0.98602897,\n 0.9885815 , 0.99112654, 0.9936659 , 0.99620116, 0.998734 ,\n 1.001266 , 1.0037988 , 1.0063341 , 1.0088735 , 1.0114186 ,\n 1.0139711 , 1.0165327 , 1.0191052 , 1.0216905 , 1.0242903 ,\n 1.0269066 , 1.0295415 , 1.0321971 , 1.0348755 , 1.0375793 ,\n 1.0403109 , 1.0430727 , 1.0458679 , 1.0486994 , 1.0515704 ,\n 1.0544847 , 1.057446 , 1.0604585 , 1.063527 , 1.0666565 ,\n 1.0698526 , 1.0731218 , 1.076471 , 1.0799083 , 1.0834427 ,\n 1.0870845 , 1.0908458 , 1.0947402 , 1.0987837 , 1.1029956 ,\n 1.1073987 , 1.1120205 , 1.116895 , 1.122064 , 1.1275817 ,\n 1.1335177 , 1.1399657 , 1.1470546 , 1.1549706 , 1.1639977 ,\n 1.1746017 , 1.1876359 , 1.2049594 , 1.2322574 , inf], dtype=float32) <= 10)
E + where <function all at 0x7f09ad5271a0> = jnp.all
error.py:18: AssertionError
FAILED [ 37%]
debug/error.py:8 (test_truncated_normal[10-0.01-0.01])
Array([ -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,
0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,
0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,
0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,
0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,
0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,
0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,
0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,
0.997571 , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,
0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,
1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,
1.0013971 , 1.0016533 , 1.0019106 , 1.002169 , 1.002429 ,
1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758 ,
1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157 ,
1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,
1.0069853 , 1.0073122 , 1.007647 , 1.0079908 , 1.0083443 ,
1.0087085 , 1.0090846 , 1.009474 , 1.0098784 , 1.0102996 ,
1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,
1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,
1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 , inf], dtype=float32) != 0.0
low = 0.0, high = 10, scale = 0.01
@pytest.mark.parametrize("scale", [0.01, 0.1])
@pytest.mark.parametrize("low", [0.0, 0.])
@pytest.mark.parametrize("high", [10, jnp.inf])
def test_truncated_normal(low, high, scale):
    dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
    u = jnp.linspace(0., 1., 100)
    samples = jax.vmap(dist.quantile)(u)
>   assert jnp.all(samples >= low)
E assert Array(False, dtype=bool)
E + where Array(False, dtype=bool) = <function all at 0x7f09ad5271a0>(Array([ -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,\n 0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,\n 0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,\n 0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,\n 0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,\n 0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,\n 0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,\n 0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,\n 0.997571 , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,\n 0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,\n 1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,\n 1.0013971 , 1.0016533 , 1.0019106 , 1.002169 , 1.002429 ,\n 1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758 ,\n 1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157 ,\n 1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,\n 1.0069853 , 1.0073122 , 1.007647 , 1.0079908 , 1.0083443 ,\n 1.0087085 , 1.0090846 , 1.009474 , 1.0098784 , 1.0102996 ,\n 1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,\n 1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,\n 1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 , inf], dtype=float32) >= 0.0)
E + where <function all at 0x7f09ad5271a0> = jnp.all
error.py:17: AssertionError
FAILED [ 50%]
debug/error.py:8 (test_truncated_normal[10-0.01-0.1])
Array([0. , 0.7677425 , 0.79504055, 0.8123641 , 0.8253983 ,
0.83600235, 0.8450294 , 0.8529455 , 0.86003435, 0.8664822 ,
0.8724183 , 0.87793595, 0.8831051 , 0.8879795 , 0.89260125,
0.89700437, 0.90121627, 0.90525985, 0.90915424, 0.91291547,
0.9165573 , 0.92009175, 0.923529 , 0.9268783 , 0.9301474 ,
0.9333436 , 0.936473 , 0.93954146, 0.942554 , 0.9455153 ,
0.9484295 , 0.9513006 , 0.9541321 , 0.9569273 , 0.9596892 ,
0.9624207 , 0.9651244 , 0.9678029 , 0.9704585 , 0.9730934 ,
0.97570974, 0.9783096 , 0.9808948 , 0.98346734, 0.98602897,
0.9885815 , 0.99112654, 0.9936659 , 0.99620116, 0.998734 ,
1.001266 , 1.0037988 , 1.0063341 , 1.0088735 , 1.0114186 ,
1.0139711 , 1.0165327 , 1.0191052 , 1.0216905 , 1.0242903 ,
1.0269066 , 1.0295415 , 1.0321971 , 1.0348755 , 1.0375793 ,
1.0403109 , 1.0430727 , 1.0458679 , 1.0486994 , 1.0515704 ,
1.0544847 , 1.057446 , 1.0604585 , 1.063527 , 1.0666565 ,
1.0698526 , 1.0731218 , 1.076471 , 1.0799083 , 1.0834427 ,
1.0870845 , 1.0908458 , 1.0947402 , 1.0987837 , 1.1029956 ,
1.1073987 , 1.1120205 , 1.116895 , 1.122064 , 1.1275817 ,
1.1335177 , 1.1399657 , 1.1470546 , 1.1549706 , 1.1639977 ,
1.1746017 , 1.1876359 , 1.2049594 , 1.2322574 , inf], dtype=float32) != 10
low = 0.0, high = 10, scale = 0.1
@pytest.mark.parametrize("scale", [0.01, 0.1])
@pytest.mark.parametrize("low", [0.0, 0.])
@pytest.mark.parametrize("high", [10, jnp.inf])
def test_truncated_normal(low, high, scale):
    dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
    u = jnp.linspace(0., 1., 100)
    samples = jax.vmap(dist.quantile)(u)
    assert jnp.all(samples >= low)
>   assert jnp.all(samples <= high)
E assert Array(False, dtype=bool)
E + where Array(False, dtype=bool) = <function all at 0x7f09ad5271a0>(Array([0. , 0.7677425 , 0.79504055, 0.8123641 , 0.8253983 ,\n 0.83600235, 0.8450294 , 0.8529455 , 0.86003435, 0.8664822 ,\n 0.8724183 , 0.87793595, 0.8831051 , 0.8879795 , 0.89260125,\n 0.89700437, 0.90121627, 0.90525985, 0.90915424, 0.91291547,\n 0.9165573 , 0.92009175, 0.923529 , 0.9268783 , 0.9301474 ,\n 0.9333436 , 0.936473 , 0.93954146, 0.942554 , 0.9455153 ,\n 0.9484295 , 0.9513006 , 0.9541321 , 0.9569273 , 0.9596892 ,\n 0.9624207 , 0.9651244 , 0.9678029 , 0.9704585 , 0.9730934 ,\n 0.97570974, 0.9783096 , 0.9808948 , 0.98346734, 0.98602897,\n 0.9885815 , 0.99112654, 0.9936659 , 0.99620116, 0.998734 ,\n 1.001266 , 1.0037988 , 1.0063341 , 1.0088735 , 1.0114186 ,\n 1.0139711 , 1.0165327 , 1.0191052 , 1.0216905 , 1.0242903 ,\n 1.0269066 , 1.0295415 , 1.0321971 , 1.0348755 , 1.0375793 ,\n 1.0403109 , 1.0430727 , 1.0458679 , 1.0486994 , 1.0515704 ,\n 1.0544847 , 1.057446 , 1.0604585 , 1.063527 , 1.0666565 ,\n 1.0698526 , 1.0731218 , 1.076471 , 1.0799083 , 1.0834427 ,\n 1.0870845 , 1.0908458 , 1.0947402 , 1.0987837 , 1.1029956 ,\n 1.1073987 , 1.1120205 , 1.116895 , 1.122064 , 1.1275817 ,\n 1.1335177 , 1.1399657 , 1.1470546 , 1.1549706 , 1.1639977 ,\n 1.1746017 , 1.1876359 , 1.2049594 , 1.2322574 , inf], dtype=float32) <= 10)
E + where <function all at 0x7f09ad5271a0> = jnp.all
error.py:18: AssertionError
FAILED [ 62%]
debug/error.py:8 (test_truncated_normal[inf-0.00-0.01])
Array([ -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,
0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,
0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,
0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,
0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,
0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,
0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,
0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,
0.997571 , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,
0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,
1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,
1.0013971 , 1.0016533 , 1.0019106 , 1.002169 , 1.002429 ,
1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758 ,
1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157 ,
1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,
1.0069853 , 1.0073122 , 1.007647 , 1.0079908 , 1.0083443 ,
1.0087085 , 1.0090846 , 1.009474 , 1.0098784 , 1.0102996 ,
1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,
1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,
1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 , inf], dtype=float32) != 0.0
low = 0.0, high = inf, scale = 0.01
@pytest.mark.parametrize("scale", [0.01, 0.1])
@pytest.mark.parametrize("low", [0.0, 0.])
@pytest.mark.parametrize("high", [10, jnp.inf])
def test_truncated_normal(low, high, scale):
    dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
    u = jnp.linspace(0., 1., 100)
    samples = jax.vmap(dist.quantile)(u)
>   assert jnp.all(samples >= low)
E assert Array(False, dtype=bool)
E + where Array(False, dtype=bool) = <function all at 0x7f09ad5271a0>(Array([ -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,\n 0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,\n 0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,\n 0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,\n 0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,\n 0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,\n 0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,\n 0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,\n 0.997571 , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,\n 0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,\n 1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,\n 1.0013971 , 1.0016533 , 1.0019106 , 1.002169 , 1.002429 ,\n 1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758 ,\n 1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157 ,\n 1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,\n 1.0069853 , 1.0073122 , 1.007647 , 1.0079908 , 1.0083443 ,\n 1.0087085 , 1.0090846 , 1.009474 , 1.0098784 , 1.0102996 ,\n 1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,\n 1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,\n 1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 , inf], dtype=float32) >= 0.0)
E + where <function all at 0x7f09ad5271a0> = jnp.all
error.py:17: AssertionError
PASSED [ 75%]
FAILED [ 87%]
debug/error.py:8 (test_truncated_normal[inf-0.01-0.01])
Array([ -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,
0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,
0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,
0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,
0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,
0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,
0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,
0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,
0.997571 , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,
0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,
1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,
1.0013971 , 1.0016533 , 1.0019106 , 1.002169 , 1.002429 ,
1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758 ,
1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157 ,
1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,
1.0069853 , 1.0073122 , 1.007647 , 1.0079908 , 1.0083443 ,
1.0087085 , 1.0090846 , 1.009474 , 1.0098784 , 1.0102996 ,
1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,
1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,
1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 , inf], dtype=float32) != 0.0
low = 0.0, high = inf, scale = 0.01
@pytest.mark.parametrize("scale", [0.01, 0.1])
@pytest.mark.parametrize("low", [0.0, 0.])
@pytest.mark.parametrize("high", [10, jnp.inf])
def test_truncated_normal(low, high, scale):
    dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
    u = jnp.linspace(0., 1., 100)
    samples = jax.vmap(dist.quantile)(u)
>   assert jnp.all(samples >= low)
E assert Array(False, dtype=bool)
E + where Array(False, dtype=bool) = <function all at 0x7f09ad5271a0>(Array([ -inf, 0.9767743 , 0.97950405, 0.9812364 , 0.98253983,\n 0.98360026, 0.9845029 , 0.9852945 , 0.98600346, 0.9866482 ,\n 0.9872418 , 0.9877936 , 0.9883105 , 0.98879796, 0.98926014,\n 0.98970044, 0.9901216 , 0.99052596, 0.9909154 , 0.9912915 ,\n 0.9916557 , 0.99200916, 0.9923529 , 0.9926878 , 0.99301475,\n 0.99333435, 0.9936473 , 0.9939541 , 0.9942554 , 0.99455154,\n 0.99484295, 0.99513006, 0.9954132 , 0.99569273, 0.99596894,\n 0.99624205, 0.9965125 , 0.9967803 , 0.9970459 , 0.9973093 ,\n 0.997571 , 0.9978309 , 0.9980895 , 0.99834675, 0.9986029 ,\n 0.99885815, 0.99911267, 0.9993666 , 0.99962014, 0.9998734 ,\n 1.0001266 , 1.0003799 , 1.0006334 , 1.0008874 , 1.0011419 ,\n 1.0013971 , 1.0016533 , 1.0019106 , 1.002169 , 1.002429 ,\n 1.0026907 , 1.0029541 , 1.0032197 , 1.0034876 , 1.003758 ,\n 1.0040311 , 1.0043073 , 1.0045868 , 1.0048699 , 1.005157 ,\n 1.0054485 , 1.0057446 , 1.0060458 , 1.0063527 , 1.0066656 ,\n 1.0069853 , 1.0073122 , 1.007647 , 1.0079908 , 1.0083443 ,\n 1.0087085 , 1.0090846 , 1.009474 , 1.0098784 , 1.0102996 ,\n 1.0107399 , 1.0112021 , 1.0116895 , 1.0122064 , 1.0127581 ,\n 1.0133518 , 1.0139966 , 1.0147054 , 1.0154971 , 1.0163997 ,\n 1.0174601 , 1.0187635 , 1.0204959 , 1.0232258 , inf], dtype=float32) >= 0.0)
E + where <function all at 0x7f09ad5271a0> = jnp.all
error.py:17: AssertionError
PASSED [100%]
from probability.
Hey! Thanks for opening this issue. It looks like the problem is only at the boundaries (u = 0 and u = 1), as we might expect:
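import jax.numpy as jnp
import numpy.testing as npt
import scipy.stats as st
from tensorflow_probability.substrates import jax as tfp

tfpd = tfp.distributions

low = 0.0
u = jnp.linspace(0., 1., 100)
for scale in [0.01, 0.1]:
    for high in [10, jnp.inf]:
        # SciPy's truncnorm takes the bounds in standardized units: (bound - loc) / scale.
        rv = st.truncnorm((low - 1.) / scale, (high - 1.) / scale, 1.0, scale)
        dist = tfpd.TruncatedNormal(1.0, scale, low=low, high=high)
        print(scale, low, high)
        print(dist.quantile(jnp.array([0, 1.])), rv.ppf(jnp.array([0, 1.])))
        # Compare only the interior points (excluding u = 0 and u = 1) against SciPy.
        npt.assert_allclose(dist.quantile(u[1:-1]), rv.ppf(u[1:-1]), atol=1e-7)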
Output (for each case: scale, low, high, then the TFP quantiles at u = 0 and u = 1 followed by SciPy's):
0.01 0.0 10
[-inf inf] [ 0. 10.]
0.01 0.0 inf
[-inf inf] [ 0. inf]
0.1 0.0 10
[ 0. inf] [ 0. 10.]
0.1 0.0 inf
[ 0. inf] [ 0. inf]
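The boundary values above are consistent with float32 underflow. The sketch below assumes, as the discussion in this thread suggests, that the quantile passes a CDF value through ndtri; that is an assumption about TFP's internals, not a description of its code. It evaluates the standard-normal CDF at the standardized bounds with jax.scipy.special: for scale = 0.01 the CDF at the lower bound rounds to exactly 0 (so ndtri gives -inf) while its log stays finite, and with high = 10 the CDF at the upper bound rounds to 1 (so ndtri gives +inf).

```python
import jax.numpy as jnp
from jax.scipy.special import log_ndtr, ndtr, ndtri

loc, low, high = 1.0, 0.0, 10.0

for scale in [0.01, 0.1]:
    # Standardized bounds, the same a and b that SciPy's truncnorm takes.
    a = jnp.float32((low - loc) / scale)
    b = jnp.float32((high - loc) / scale)
    # Standard-normal CDF at each bound, computed in float32.
    p_low, p_high = ndtr(a), ndtr(b)
    print("scale", scale, "CDF at bounds:", p_low, p_high, "log CDF at low:", log_ndtr(a))
    # Edge quantiles under a naive "loc + scale * ndtri(p)" formula.
    print("edge quantiles:", loc + scale * ndtri(p_low), loc + scale * ndtri(p_high))
```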
What's interesting is that if you go to log space, the argument to ndtri(...) in the quantile is finite at both ends; it is just extremely large in magnitude. I think following up with a few steps of bisection would solve this, because ndtr is more stable than ndtri. Does that make sense? WDYT?
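A minimal sketch of the ndtr-based bisection idea in plain JAX, assuming finite low and high so they can serve as the initial bracket. Note that it swaps in full bisection over [low, high] rather than refining the output of TFP's ndtri-based quantile, and that truncnorm_cdf and bisect_quantile are hypothetical helpers, not TFP API.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr


def truncnorm_cdf(x, loc, scale, low, high):
    """CDF of Normal(loc, scale) truncated to [low, high], built only from ndtr."""
    cdf_low = ndtr((low - loc) / scale)
    cdf_high = ndtr((high - loc) / scale)
    p = (ndtr((x - loc) / scale) - cdf_low) / (cdf_high - cdf_low)
    # Guard against rounding pushing the ratio slightly outside [0, 1].
    return jnp.clip(p, 0.0, 1.0)


def bisect_quantile(u, loc, scale, low, high, num_steps=64):
    """Hypothetical helper: find x in [low, high] with truncnorm_cdf(x) ~= u.

    Assumes low and high are finite, since they serve as the initial bracket.
    """
    def step(_, bracket):
        lo, hi = bracket
        mid = 0.5 * (lo + hi)
        # If the CDF at mid is still below u, the root lies in [mid, hi].
        go_right = truncnorm_cdf(mid, loc, scale, low, high) < u
        return jnp.where(go_right, mid, lo), jnp.where(go_right, hi, mid)

    bracket = (jnp.asarray(low, jnp.float32), jnp.asarray(high, jnp.float32))
    lo, hi = jax.lax.fori_loop(0, num_steps, step, bracket)
    return 0.5 * (lo + hi)


u = jnp.linspace(0., 1., 100)
x = jax.vmap(lambda q: bisect_quantile(q, 1.0, 0.01, 0.0, 10.0))(u)
print(x.min(), x.max())
```

Because every iterate stays inside the bracket, the result cannot leave [low, high]; the cost is num_steps CDF evaluations per quantile, and an infinite bound would need the caller to pick a finite bracket instead.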
Or, thinking about this again, perhaps the best option would be to clip the output of the quantile to [low, high] and then define a safe custom gradient rule.
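A minimal sketch of clip-plus-custom-gradient, written in plain JAX so it runs without TFP. It assumes a naive inverse-CDF formula (not necessarily how TFP computes TruncatedNormal.quantile), and the "safe" rule here, returning a zero derivative wherever the analytic slope 1/density overflows, is one possible choice rather than TFP's. make_clipped_quantile is a hypothetical helper.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr, ndtri


def make_clipped_quantile(loc, scale, low, high):
    """Hypothetical helper: truncated-normal quantile clipped to [low, high]."""
    cdf_low = ndtr((low - loc) / scale)
    mass = ndtr((high - loc) / scale) - cdf_low  # normalizer Z

    @jax.custom_jvp
    def quantile(u):
        # Naive inverse CDF: yields -inf/inf when the CDF value at a bound
        # rounds to exactly 0 or 1 in float32, so clip back into the support.
        x = loc + scale * ndtri(cdf_low + u * mass)
        return jnp.clip(x, low, high)

    @quantile.defjvp
    def quantile_jvp(primals, tangents):
        (u,), (u_dot,) = primals, tangents
        x = quantile(u)
        z = (x - loc) / scale
        # dx/du = 1 / density(x) = scale * Z / phi(z), formed in log space.
        log_slope = jnp.log(scale * mass) + 0.5 * z**2 + 0.5 * math.log(2 * math.pi)
        slope = jnp.exp(log_slope)
        # Assumed "safe" rule: use a zero derivative wherever the slope overflows.
        slope = jnp.where(jnp.isfinite(slope), slope, 0.0)
        return x, slope * u_dot

    return quantile


q = make_clipped_quantile(1.0, 0.01, 0.0, 10.0)
u = jnp.linspace(0., 1., 100)
x = jax.vmap(q)(u)
g = jax.vmap(jax.grad(q))(u)
print(x[0], x[-1], g[0], g[-1])
```

Clipping alone fixes the forward values; the custom JVP matters because differentiating straight through jnp.clip at an infinite input would likely produce nan (a zero clip derivative times an infinite ndtri derivative).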