hand
This commit is contained in:
@@ -0,0 +1,13 @@
|
||||
# Copyright 2020 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,311 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import NamedTuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import dtypes
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.numpy.util import check_arraylike, promote_args_inexact
|
||||
from jax._src.typing import ArrayLike, Array
|
||||
from jax._src.util import canonicalize_axis
|
||||
|
||||
|
||||
class ModeResult(NamedTuple):
|
||||
mode: Array
|
||||
count: Array
|
||||
|
||||
|
||||
@api.jit(static_argnames=['axis', 'nan_policy', 'keepdims'])
|
||||
def mode(a: ArrayLike, axis: int | None = 0, nan_policy: str = "propagate", keepdims: bool = False) -> ModeResult:
|
||||
"""Compute the mode (most common value) along an axis of an array.
|
||||
|
||||
JAX implementation of :func:`scipy.stats.mode`.
|
||||
|
||||
Args:
|
||||
a: arraylike
|
||||
axis: int, default=0. Axis along which to compute the mode.
|
||||
nan_policy: str. JAX only supports ``"propagate"``.
|
||||
keepdims: bool, default=False. If true, reduced axes are left in the result
|
||||
with size 1.
|
||||
|
||||
Returns:
|
||||
A tuple of arrays, ``(mode, count)``. ``mode`` is the array of modal values,
|
||||
and ``count`` is the number of times each value appears in the input array.
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
|
||||
>>> mode, count = jax.scipy.stats.mode(x)
|
||||
>>> mode, count
|
||||
(Array(4, dtype=int32), Array(3, dtype=int32))
|
||||
|
||||
For multi dimensional arrays, ``jax.scipy.stats.mode`` computes the ``mode``
|
||||
and the corresponding ``count`` along ``axis=0``:
|
||||
|
||||
>>> x1 = jnp.array([[1, 2, 1, 3, 2, 1],
|
||||
... [3, 1, 3, 2, 1, 3],
|
||||
... [1, 2, 2, 3, 1, 2]])
|
||||
>>> mode, count = jax.scipy.stats.mode(x1)
|
||||
>>> mode, count
|
||||
(Array([1, 2, 1, 3, 1, 1], dtype=int32), Array([2, 2, 1, 2, 2, 1], dtype=int32))
|
||||
|
||||
If ``axis=1``, ``mode`` and ``count`` will be computed along ``axis 1``.
|
||||
|
||||
>>> mode, count = jax.scipy.stats.mode(x1, axis=1)
|
||||
>>> mode, count
|
||||
(Array([1, 3, 2], dtype=int32), Array([3, 3, 3], dtype=int32))
|
||||
|
||||
By default, ``jax.scipy.stats.mode`` reduces the dimension of the result.
|
||||
To keep the dimensions same as that of the input array, the argument
|
||||
``keepdims`` must be set to ``True``.
|
||||
|
||||
>>> mode, count = jax.scipy.stats.mode(x1, axis=1, keepdims=True)
|
||||
>>> mode, count
|
||||
(Array([[1],
|
||||
[3],
|
||||
[2]], dtype=int32), Array([[3],
|
||||
[3],
|
||||
[3]], dtype=int32))
|
||||
"""
|
||||
check_arraylike("mode", a)
|
||||
x = jnp.atleast_1d(a)
|
||||
|
||||
if nan_policy not in ["propagate", "omit", "raise"]:
|
||||
raise ValueError(
|
||||
f"Illegal nan_policy value {nan_policy!r}; expected one of "
|
||||
"{'propagate', 'omit', 'raise'}"
|
||||
)
|
||||
if nan_policy == "omit":
|
||||
# TODO: return answer without nans included.
|
||||
raise NotImplementedError(
|
||||
f"Logic for `nan_policy` of {nan_policy} is not implemented"
|
||||
)
|
||||
if nan_policy == "raise":
|
||||
raise NotImplementedError(
|
||||
"In order to best JIT compile `mode`, we cannot know whether `x` contains nans. "
|
||||
"Please check if nans exist in `x` outside of the `mode` function."
|
||||
)
|
||||
if axis is not None:
|
||||
axis = canonicalize_axis(axis, x.ndim)
|
||||
|
||||
input_shape = x.shape
|
||||
if keepdims:
|
||||
if axis is None:
|
||||
output_shape = tuple(1 for i in input_shape)
|
||||
else:
|
||||
output_shape = tuple(1 if i == axis else s for i, s in enumerate(input_shape))
|
||||
else:
|
||||
if axis is None:
|
||||
output_shape = ()
|
||||
else:
|
||||
output_shape = tuple(s for i, s in enumerate(input_shape) if i != axis)
|
||||
|
||||
if axis is None:
|
||||
axis = 0
|
||||
x = x.ravel()
|
||||
|
||||
def _mode_helper(x: Array) -> tuple[Array, Array]:
|
||||
"""Helper function to return mode and count of a given array."""
|
||||
if x.size == 0:
|
||||
return (jnp.array(np.nan, dtype=dtypes.default_float_dtype()),
|
||||
jnp.array(0, dtype=dtypes.default_float_dtype()))
|
||||
else:
|
||||
vals, counts = jnp.unique(x, return_counts=True, size=x.size)
|
||||
return vals[jnp.argmax(counts)], counts.max()
|
||||
|
||||
x = jnp.moveaxis(x, axis, 0)
|
||||
x = x.reshape(x.shape[0], math.prod(x.shape[1:]))
|
||||
vals, counts = api.vmap(_mode_helper, in_axes=1)(x)
|
||||
return ModeResult(vals.reshape(output_shape), counts.reshape(output_shape))
|
||||
|
||||
def invert_permutation(i: Array) -> Array:
|
||||
"""Helper function that inverts a permutation array."""
|
||||
return jnp.empty_like(i).at[i].set(jnp.arange(i.size, dtype=i.dtype))
|
||||
|
||||
|
||||
@api.jit(static_argnames=["method", "axis", "nan_policy"])
|
||||
def rankdata(
|
||||
a: ArrayLike,
|
||||
method: str = "average",
|
||||
*,
|
||||
axis: int | None = None,
|
||||
nan_policy: str = "propagate",
|
||||
) -> Array:
|
||||
"""Compute the rank of data along an array axis.
|
||||
|
||||
JAX implementation of :func:`scipy.stats.rankdata`.
|
||||
|
||||
Ranks begin at 1, and the *method* argument controls how ties are handled.
|
||||
|
||||
Args:
|
||||
a: arraylike
|
||||
method: str, default="average". Supported methods are
|
||||
``("average", "min", "max", "dense", "ordinal")``
|
||||
For details, see the :func:`scipy.stats.rankdata` documentation.
|
||||
axis: optional integer. If not specified, the input array is flattened.
|
||||
nan_policy: str, JAX's implementation only supports ``"propagate"``.
|
||||
|
||||
Returns:
|
||||
array of ranks along the specified axis.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> x = jnp.array([10, 30, 20])
|
||||
>>> rankdata(x)
|
||||
Array([1., 3., 2.], dtype=float32)
|
||||
|
||||
>>> x = jnp.array([1, 3, 2, 3])
|
||||
>>> rankdata(x)
|
||||
Array([1. , 3.5, 2. , 3.5], dtype=float32)
|
||||
"""
|
||||
check_arraylike("rankdata", a)
|
||||
|
||||
if nan_policy not in ["propagate", "omit", "raise"]:
|
||||
raise ValueError(
|
||||
f"Illegal nan_policy value {nan_policy!r}; expected one of "
|
||||
"{'propagate', 'omit', 'raise'}"
|
||||
)
|
||||
if nan_policy == "omit":
|
||||
raise NotImplementedError(
|
||||
f"Logic for `nan_policy` of {nan_policy} is not implemented"
|
||||
)
|
||||
if nan_policy == "raise":
|
||||
raise NotImplementedError(
|
||||
"In order to best JIT compile `rankdata`, we cannot know whether `x` "
|
||||
"contains nans. Please check if nans exist in `x` outside of the "
|
||||
"`rankdata` function."
|
||||
)
|
||||
|
||||
if method not in ("average", "min", "max", "dense", "ordinal"):
|
||||
raise ValueError(f"unknown method '{method}'")
|
||||
|
||||
if axis is not None:
|
||||
return jnp.apply_along_axis(rankdata, axis, a, method)
|
||||
|
||||
a = jnp.ravel(a)
|
||||
out_dtype = dtypes.default_float_dtype()
|
||||
|
||||
def _rankdata(a: Array) -> Array:
|
||||
arr, sorter = lax.sort_key_val(a, jnp.arange(a.size))
|
||||
inv = invert_permutation(sorter)
|
||||
|
||||
if method == "ordinal":
|
||||
return (inv + 1).astype(out_dtype)
|
||||
obs = jnp.concatenate([jnp.array([True]), arr[1:] != arr[:-1]])
|
||||
dense = obs.cumsum()[inv]
|
||||
if method == "dense":
|
||||
return dense.astype(out_dtype)
|
||||
count = jnp.nonzero(obs, size=arr.size + 1, fill_value=obs.size)[0].astype(out_dtype)
|
||||
if method == "max":
|
||||
return count[dense]
|
||||
if method == "min":
|
||||
return count[dense - 1] + 1
|
||||
if method == "average":
|
||||
return .5 * (count[dense] + count[dense - 1] + 1)
|
||||
raise ValueError(f"unknown method '{method}'")
|
||||
|
||||
return lax.cond(jnp.any(jnp.isnan(a)),
|
||||
lambda a: jnp.full_like(a, jnp.nan, out_dtype),
|
||||
_rankdata, a)
|
||||
|
||||
|
||||
@api.jit(static_argnames=['axis', 'nan_policy', 'keepdims'])
|
||||
def sem(a: ArrayLike, axis: int | None = 0, ddof: int = 1, nan_policy: str = "propagate", *, keepdims: bool = False) -> Array:
|
||||
"""Compute the standard error of the mean.
|
||||
|
||||
JAX implementation of :func:`scipy.stats.sem`.
|
||||
|
||||
Args:
|
||||
a: arraylike
|
||||
axis: optional integer. If not specified, the input array is flattened.
|
||||
ddof: integer, default=1. The degrees of freedom in the SEM computation.
|
||||
nan_policy: str, default="propagate". JAX supports only "propagate" and
|
||||
"omit".
|
||||
keepdims: bool, default=False. If true, reduced axes are left in the result
|
||||
with size 1.
|
||||
|
||||
Returns:
|
||||
array
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... jax.scipy.stats.sem(x)
|
||||
Array(0.41, dtype=float32)
|
||||
|
||||
For multi dimensional arrays, ``sem`` computes standard error of mean along
|
||||
``axis=0``:
|
||||
|
||||
>>> x1 = jnp.array([[1, 2, 1, 3, 2, 1],
|
||||
... [3, 1, 3, 2, 1, 3],
|
||||
... [1, 2, 2, 3, 1, 2]])
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... jax.scipy.stats.sem(x1)
|
||||
Array([0.67, 0.33, 0.58, 0.33, 0.33, 0.58], dtype=float32)
|
||||
|
||||
If ``axis=1``, standard error of mean will be computed along ``axis 1``.
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... jax.scipy.stats.sem(x1, axis=1)
|
||||
Array([0.33, 0.4 , 0.31], dtype=float32)
|
||||
|
||||
If ``axis=None``, standard error of mean will be computed along all the axes.
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... jax.scipy.stats.sem(x1, axis=None)
|
||||
Array(0.2, dtype=float32)
|
||||
|
||||
By default, ``sem`` reduces the dimension of the result. To keep the
|
||||
dimensions same as that of the input array, the argument ``keepdims`` must
|
||||
be set to ``True``.
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... jax.scipy.stats.sem(x1, axis=1, keepdims=True)
|
||||
Array([[0.33],
|
||||
[0.4 ],
|
||||
[0.31]], dtype=float32)
|
||||
|
||||
Since, by default, ``nan_policy='propagate'``, ``sem`` propagates the ``nan``
|
||||
values in the result.
|
||||
|
||||
>>> nan = np.nan
|
||||
>>> x2 = jnp.array([[1, 2, 3, nan, 4, 2],
|
||||
... [4, 5, 4, 3, nan, 1],
|
||||
... [7, nan, 8, 7, 9, nan]])
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... jax.scipy.stats.sem(x2)
|
||||
Array([1.73, nan, 1.53, nan, nan, nan], dtype=float32)
|
||||
|
||||
If ``nan_policy='omit```, ``sem`` omits the ``nan`` values and computes the error
|
||||
for the remaining values along the specified axis.
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... jax.scipy.stats.sem(x2, nan_policy='omit')
|
||||
Array([1.73, 1.5 , 1.53, 2. , 2.5 , 0.5 ], dtype=float32)
|
||||
"""
|
||||
b, = promote_args_inexact("sem", a)
|
||||
if nan_policy == "propagate":
|
||||
size = b.size if axis is None else b.shape[axis]
|
||||
return b.std(axis, ddof=ddof, keepdims=keepdims) / jnp.sqrt(size).astype(b.dtype)
|
||||
elif nan_policy == "omit":
|
||||
count = (~jnp.isnan(b)).sum(axis, keepdims=keepdims)
|
||||
return jnp.nanstd(b, axis, ddof=ddof, keepdims=keepdims) / jnp.sqrt(count).astype(b.dtype)
|
||||
else:
|
||||
raise ValueError(f"{nan_policy} is not supported")
|
||||
@@ -0,0 +1,159 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.scipy.special import xlogy, xlog1py
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
r"""Bernoulli log probability mass function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.bernoulli` ``logpmf``
|
||||
|
||||
The Bernoulli probability mass function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f(k) = \begin{cases}
|
||||
1 - p, & k = 0 \\
|
||||
p, & k = 1 \\
|
||||
0, & \mathrm{otherwise}
|
||||
\end{cases}
|
||||
|
||||
Args:
|
||||
k: arraylike, value at which to evaluate the PMF
|
||||
p: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset
|
||||
|
||||
Returns:
|
||||
array of logpmf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.bernoulli.cdf`
|
||||
- :func:`jax.scipy.stats.bernoulli.pmf`
|
||||
- :func:`jax.scipy.stats.bernoulli.ppf`
|
||||
"""
|
||||
k, p, loc = promote_args_inexact("bernoulli.logpmf", k, p, loc)
|
||||
zero = _lax_const(k, 0)
|
||||
one = _lax_const(k, 1)
|
||||
x = lax.sub(k, loc)
|
||||
log_probs = xlogy(x, p) + xlog1py(lax.sub(one, x), -p)
|
||||
return jnp.where(jnp.logical_or(lax.lt(x, zero), lax.gt(x, one)),
|
||||
-np.inf, log_probs)
|
||||
|
||||
|
||||
def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
r"""Bernoulli probability mass function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.bernoulli` ``pmf``
|
||||
|
||||
The Bernoulli probability mass function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f(k) = \begin{cases}
|
||||
1 - p, & k = 0 \\
|
||||
p, & k = 1 \\
|
||||
0, & \mathrm{otherwise}
|
||||
\end{cases}
|
||||
|
||||
Args:
|
||||
k: arraylike, value at which to evaluate the PMF
|
||||
p: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset
|
||||
|
||||
Returns:
|
||||
array of pmf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.bernoulli.cdf`
|
||||
- :func:`jax.scipy.stats.bernoulli.logpmf`
|
||||
- :func:`jax.scipy.stats.bernoulli.ppf`
|
||||
"""
|
||||
return jnp.exp(logpmf(k, p, loc))
|
||||
|
||||
|
||||
def cdf(k: ArrayLike, p: ArrayLike) -> Array:
|
||||
r"""Bernoulli cumulative distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.bernoulli` ``cdf``
|
||||
|
||||
The Bernoulli cumulative distribution function is defined as:
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(k, p) = \sum_{i=0}^k f_{pmf}(k, p)
|
||||
|
||||
where :math:`f_{pmf}(k, p)` is the Bernoulli probability mass function
|
||||
:func:`jax.scipy.stats.bernoulli.pmf`.
|
||||
|
||||
Args:
|
||||
k: arraylike, value at which to evaluate the CDF
|
||||
p: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset
|
||||
|
||||
Returns:
|
||||
array of cdf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.bernoulli.logpmf`
|
||||
- :func:`jax.scipy.stats.bernoulli.pmf`
|
||||
- :func:`jax.scipy.stats.bernoulli.ppf`
|
||||
"""
|
||||
k, p = promote_args_inexact('bernoulli.cdf', k, p)
|
||||
zero, one = _lax_const(k, 0), _lax_const(k, 1)
|
||||
conds = [
|
||||
jnp.isnan(k) | jnp.isnan(p) | (p < zero) | (p > one),
|
||||
lax.lt(k, zero),
|
||||
jnp.logical_and(lax.ge(k, zero), lax.lt(k, one)),
|
||||
lax.ge(k, one)
|
||||
]
|
||||
vals = [jnp.nan, zero, one - p, one]
|
||||
return jnp.select(conds, vals)
|
||||
|
||||
|
||||
def ppf(q: ArrayLike, p: ArrayLike) -> Array:
|
||||
"""Bernoulli percent point function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.bernoulli` ``ppf``
|
||||
|
||||
The percent point function is the inverse of the cumulative
|
||||
distribution function, :func:`jax.scipy.stats.bernoulli.cdf`.
|
||||
|
||||
Args:
|
||||
k: arraylike, value at which to evaluate the PPF
|
||||
p: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset
|
||||
|
||||
Returns:
|
||||
array of ppf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.bernoulli.cdf`
|
||||
- :func:`jax.scipy.stats.bernoulli.logpmf`
|
||||
- :func:`jax.scipy.stats.bernoulli.pmf`
|
||||
"""
|
||||
q, p = promote_args_inexact('bernoulli.ppf', q, p)
|
||||
zero, one = _lax_const(q, 0), _lax_const(q, 1)
|
||||
return jnp.where(
|
||||
jnp.isnan(q) | jnp.isnan(p) | (p < zero) | (p > one) | (q < zero) | (q > one),
|
||||
jnp.nan,
|
||||
jnp.where(lax.le(q, one - p), zero, one)
|
||||
)
|
||||
@@ -0,0 +1,262 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.scipy.special import betaln, betainc, xlogy, xlog1py
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Beta log probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.beta` ``logpdf``.
|
||||
|
||||
The pdf of the beta function is:
|
||||
|
||||
.. math::
|
||||
|
||||
f(x, a, b) = \frac{\Gamma(a + b)}{\Gamma(a)\Gamma(b)} x^{a-1}(1-x)^{b-1}
|
||||
|
||||
where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function,
|
||||
It is defined for :math:`0\le x\le 1` and :math:`b>0`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
a: arraylike, distribution shape parameter
|
||||
b: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logpdf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.beta.cdf`
|
||||
- :func:`jax.scipy.stats.beta.pdf`
|
||||
- :func:`jax.scipy.stats.beta.sf`
|
||||
- :func:`jax.scipy.stats.beta.logcdf`
|
||||
- :func:`jax.scipy.stats.beta.logsf`
|
||||
"""
|
||||
x, a, b, loc, scale = promote_args_inexact("beta.logpdf", x, a, b, loc, scale)
|
||||
one = _lax_const(x, 1)
|
||||
zero = _lax_const(a, 0)
|
||||
shape_term = lax.neg(betaln(a, b))
|
||||
y = lax.div(lax.sub(x, loc), scale)
|
||||
log_linear_term = lax.add(xlogy(lax.sub(a, one), y),
|
||||
xlog1py(lax.sub(b, one), lax.neg(y)))
|
||||
log_probs = lax.sub(lax.add(shape_term, log_linear_term), lax.log(scale))
|
||||
result = jnp.where(jnp.logical_or(lax.gt(x, lax.add(loc, scale)),
|
||||
lax.lt(x, loc)), -np.inf, log_probs)
|
||||
result_positive_constants = jnp.where(jnp.logical_or(jnp.logical_or(lax.le(a, zero), lax.le(b, zero)),
|
||||
lax.le(scale, zero)), np.nan, result)
|
||||
return result_positive_constants
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Beta probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.beta` ``pdf``.
|
||||
|
||||
The pdf of the beta function is:
|
||||
|
||||
.. math::
|
||||
|
||||
f(x, a, b) = \frac{\Gamma(a + b)}{\Gamma(a)\Gamma(b)} x^{a-1}(1-x)^{b-1}
|
||||
|
||||
where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function.
|
||||
It is defined for :math:`0\le x\le 1` and :math:`b>0`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
a: arraylike, distribution shape parameter
|
||||
b: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.beta.cdf`
|
||||
- :func:`jax.scipy.stats.beta.sf`
|
||||
- :func:`jax.scipy.stats.beta.logcdf`
|
||||
- :func:`jax.scipy.stats.beta.logpdf`
|
||||
- :func:`jax.scipy.stats.beta.logsf`
|
||||
"""
|
||||
return lax.exp(logpdf(x, a, b, loc, scale))
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Beta cumulative distribution function
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.beta` ``cdf``.
|
||||
|
||||
The cdf is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(x, a, b) = \int_{-\infty}^x f_{pdf}(y, a, b)\mathrm{d}y
|
||||
|
||||
where :math:`f_{pdf}` is the beta distribution probability density function,
|
||||
:func:`jax.scipy.stats.beta.pdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the CDF
|
||||
a: arraylike, distribution shape parameter
|
||||
b: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of cdf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.beta.pdf`
|
||||
- :func:`jax.scipy.stats.beta.sf`
|
||||
- :func:`jax.scipy.stats.beta.logcdf`
|
||||
- :func:`jax.scipy.stats.beta.logpdf`
|
||||
- :func:`jax.scipy.stats.beta.logsf`
|
||||
"""
|
||||
x, a, b, loc, scale = promote_args_inexact("beta.cdf", x, a, b, loc, scale)
|
||||
return betainc(
|
||||
a,
|
||||
b,
|
||||
lax.clamp(
|
||||
_lax_const(x, 0),
|
||||
lax.div(lax.sub(x, loc), scale),
|
||||
_lax_const(x, 1),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Beta log cumulative distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.beta` ``logcdf``.
|
||||
|
||||
The cdf is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(x, a, b) = \int_{-\infty}^x f_{pdf}(y, a, b)\mathrm{d}y
|
||||
|
||||
where :math:`f_{pdf}` is the beta distribution probability density function,
|
||||
:func:`jax.scipy.stats.beta.pdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the CDF
|
||||
a: arraylike, distribution shape parameter
|
||||
b: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logcdf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.beta.cdf`
|
||||
- :func:`jax.scipy.stats.beta.pdf`
|
||||
- :func:`jax.scipy.stats.beta.sf`
|
||||
- :func:`jax.scipy.stats.beta.logpdf`
|
||||
- :func:`jax.scipy.stats.beta.logsf`
|
||||
"""
|
||||
return lax.log(cdf(x, a, b, loc, scale))
|
||||
|
||||
|
||||
def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Beta distribution survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.beta` ``sf``.
|
||||
|
||||
The survival function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{sf}(x, a, b) = 1 - f_{cdf}(x, a, b)
|
||||
|
||||
where :math:`f_{cdf}(x, a, b)` is the beta cumulative distribution function,
|
||||
:func:`jax.scipy.stats.beta.cdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the SF
|
||||
a: arraylike, distribution shape parameter
|
||||
b: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of sf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.beta.cdf`
|
||||
- :func:`jax.scipy.stats.beta.pdf`
|
||||
- :func:`jax.scipy.stats.beta.logcdf`
|
||||
- :func:`jax.scipy.stats.beta.logpdf`
|
||||
- :func:`jax.scipy.stats.beta.logsf`
|
||||
"""
|
||||
x, a, b, loc, scale = promote_args_inexact("beta.sf", x, a, b, loc, scale)
|
||||
return betainc(
|
||||
b,
|
||||
a,
|
||||
1 - lax.clamp(
|
||||
_lax_const(x, 0),
|
||||
lax.div(lax.sub(x, loc), scale),
|
||||
_lax_const(x, 1),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def logsf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Beta distribution log survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.beta` ``logsf``.
|
||||
|
||||
The survival function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{sf}(x, a, b) = 1 - f_{cdf}(x, a, b)
|
||||
|
||||
where :math:`f_{cdf}(x, a, b)` is the beta cumulative distribution function,
|
||||
:func:`jax.scipy.stats.beta.cdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the SF
|
||||
a: arraylike, distribution shape parameter
|
||||
b: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logsf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.beta.cdf`
|
||||
- :func:`jax.scipy.stats.beta.pdf`
|
||||
- :func:`jax.scipy.stats.beta.sf`
|
||||
- :func:`jax.scipy.stats.beta.logcdf`
|
||||
- :func:`jax.scipy.stats.beta.logpdf`
|
||||
"""
|
||||
return lax.log(sf(x, a, b, loc, scale))
|
||||
@@ -0,0 +1,98 @@
|
||||
# Copyright 2021 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.scipy.special import betaln
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@api.jit
|
||||
def logpmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0) -> Array:
|
||||
r"""Beta-binomial log probability mass function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.betabinom` ``logpmf``
|
||||
|
||||
The beta-binomial distribution's probability mass function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f(k, n, a, b) = {n \choose k}\frac{B(k+a,n-k-b)}{B(a,b)}
|
||||
|
||||
where :math:`B(a, b)` is the :func:`~jax.scipy.special.beta` function. It is
|
||||
defined for :math:`n\ge 0`, :math:`a>0`, :math:`b>0`, and non-negative integers `k`.
|
||||
|
||||
Args:
|
||||
k: arraylike, value at which to evaluate the PMF
|
||||
n: arraylike, distribution shape parameter
|
||||
a: arraylike, distribution shape parameter
|
||||
b: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
|
||||
Returns:
|
||||
array of logpmf values
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.betabinom.pmf`
|
||||
"""
|
||||
k, n, a, b, loc = promote_args_inexact("betabinom.logpmf", k, n, a, b, loc)
|
||||
y = lax.sub(lax.floor(k), loc)
|
||||
one = _lax_const(y, 1)
|
||||
zero = _lax_const(y, 0)
|
||||
combiln = lax.neg(lax.add(lax.log1p(n), betaln(lax.add(lax.sub(n,y), one), lax.add(y,one))))
|
||||
beta_lns = lax.sub(betaln(lax.add(y,a), lax.add(lax.sub(n,y),b)), betaln(a,b))
|
||||
log_probs = lax.add(combiln, beta_lns)
|
||||
log_probs = jnp.where(jnp.logical_and(lax.eq(y, zero), lax.eq(n, zero)), 0., log_probs)
|
||||
y_cond = jnp.logical_or(jnp.logical_or(lax.lt(y, lax.neg(loc)), lax.gt(y, n)),
|
||||
lax.le(lax.add(y, a), zero))
|
||||
log_probs = jnp.where(y_cond, -np.inf, log_probs)
|
||||
n_a_b_cond = jnp.logical_or(jnp.logical_or(lax.lt(n, zero), lax.le(a, zero)), lax.le(b, zero))
|
||||
return jnp.where(n_a_b_cond, np.nan, log_probs)
|
||||
|
||||
|
||||
def pmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0) -> Array:
|
||||
r"""Beta-binomial probability mass function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.betabinom` ``pmf``.
|
||||
|
||||
The beta-binomial distribution's probability mass function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f(k, n, a, b) = {n \choose k}\frac{B(k+a,n-k-b)}{B(a,b)}
|
||||
|
||||
where :math:`B(a, b)` is the :func:`~jax.scipy.special.beta` function. It is
|
||||
defined for :math:`n\ge 0`, :math:`a>0`, :math:`b>0`, and non-negative integers `k`.
|
||||
|
||||
Args:
|
||||
k: arraylike, value at which to evaluate the PMF
|
||||
n: arraylike, distribution shape parameter
|
||||
a: arraylike, distribution shape parameter
|
||||
b: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
|
||||
Returns:
|
||||
array of pmf values
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.betabinom.logpmf`
|
||||
"""
|
||||
return lax.exp(logpmf(k, n, a, b, loc))
|
||||
@@ -0,0 +1,90 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.scipy.special import gammaln, xlogy, xlog1py
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
r"""Binomial log probability mass function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.binom` ``logpmf``.
|
||||
|
||||
The binomial probability mass function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f(k, n, p) = {n \choose k}p^k(1-p)^{n-k}
|
||||
|
||||
for :math:`0\le p\le 1` and non-negative integers :math:`k`.
|
||||
|
||||
Args:
|
||||
k: arraylike, value at which to evaluate the PMF
|
||||
n: arraylike, distribution shape parameter
|
||||
p: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
|
||||
Returns:
|
||||
array of logpmf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.binom.pmf`
|
||||
"""
|
||||
k, n, p, loc = promote_args_inexact("binom.logpmf", k, n, p, loc)
|
||||
y = lax.sub(k, loc)
|
||||
zero = _lax_const(y, 0)
|
||||
comb_term = lax.sub(
|
||||
gammaln(n + 1),
|
||||
lax.add(gammaln(y + 1), gammaln(n - y + 1))
|
||||
)
|
||||
log_linear_term = lax.add(xlogy(y, p), xlog1py(lax.sub(n, y), lax.neg(p)))
|
||||
log_probs = lax.add(comb_term, log_linear_term)
|
||||
y_n_cond = jnp.logical_or(jnp.logical_and(lax.eq(y, zero), lax.eq(n, zero)),
|
||||
lax.eq(log_linear_term, zero))
|
||||
log_probs = jnp.where(y_n_cond, 0., log_probs)
|
||||
return jnp.where(lax.ge(k, loc) & lax.lt(k, loc + n + 1), log_probs, -np.inf)
|
||||
|
||||
|
||||
def pmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
r"""Binomial probability mass function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.binom` ``pmf``.
|
||||
|
||||
The binomial probability mass function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f(k, n, p) = {n \choose k}p^k(1-p)^{n-k}
|
||||
|
||||
for :math:`0\le p\le 1` and non-negative integers :math:`k`.
|
||||
|
||||
Args:
|
||||
k: arraylike, value at which to evaluate the PMF
|
||||
n: arraylike, distribution shape parameter
|
||||
p: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
|
||||
Returns:
|
||||
array of pmf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.binom.logpmf`
|
||||
"""
|
||||
return lax.exp(logpmf(k, n, p, loc))
|
||||
@@ -0,0 +1,293 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.ufuncs import arctan
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Cauchy log probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.cauchy` ``logpdf``.
|
||||
|
||||
The Cauchy probability distribution function is
|
||||
|
||||
.. math::
|
||||
|
||||
f(x) = \frac{1}{\pi(1 + x^2)}
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logpdf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.cauchy.cdf`
|
||||
- :func:`jax.scipy.stats.cauchy.pdf`
|
||||
- :func:`jax.scipy.stats.cauchy.sf`
|
||||
- :func:`jax.scipy.stats.cauchy.logcdf`
|
||||
- :func:`jax.scipy.stats.cauchy.logsf`
|
||||
- :func:`jax.scipy.stats.cauchy.isf`
|
||||
- :func:`jax.scipy.stats.cauchy.ppf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("cauchy.logpdf", x, loc, scale)
|
||||
pi = _lax_const(x, np.pi)
|
||||
scaled_x = lax.div(lax.sub(x, loc), scale)
|
||||
normalize_term = lax.log(lax.mul(pi, scale))
|
||||
return lax.neg(lax.add(normalize_term, lax.log1p(lax.mul(scaled_x, scaled_x))))
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Cauchy probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.cauchy` ``pdf``.
|
||||
|
||||
The Cauchy probability distribution function is
|
||||
|
||||
.. math::
|
||||
|
||||
f(x) = \frac{1}{\pi(1 + x^2)}
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.cauchy.cdf`
|
||||
- :func:`jax.scipy.stats.cauchy.sf`
|
||||
- :func:`jax.scipy.stats.cauchy.logcdf`
|
||||
- :func:`jax.scipy.stats.cauchy.logpdf`
|
||||
- :func:`jax.scipy.stats.cauchy.logsf`
|
||||
- :func:`jax.scipy.stats.cauchy.isf`
|
||||
- :func:`jax.scipy.stats.cauchy.ppf`
|
||||
"""
|
||||
return lax.exp(logpdf(x, loc, scale))
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Cauchy cumulative distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.cauchy` ``cdf``.
|
||||
|
||||
The cdf is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y
|
||||
|
||||
where here :math:`f_{pdf}` is the Cauchy probability distribution function,
|
||||
:func:`jax.scipy.stats.cauchy.pdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the CDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of cdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.cauchy.pdf`
|
||||
- :func:`jax.scipy.stats.cauchy.sf`
|
||||
- :func:`jax.scipy.stats.cauchy.logcdf`
|
||||
- :func:`jax.scipy.stats.cauchy.logpdf`
|
||||
- :func:`jax.scipy.stats.cauchy.logsf`
|
||||
- :func:`jax.scipy.stats.cauchy.isf`
|
||||
- :func:`jax.scipy.stats.cauchy.ppf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("cauchy.cdf", x, loc, scale)
|
||||
pi = _lax_const(x, np.pi)
|
||||
scaled_x = lax.div(lax.sub(x, loc), scale)
|
||||
return lax.add(_lax_const(x, 0.5), lax.mul(lax.div(_lax_const(x, 1.), pi), arctan(scaled_x)))
|
||||
|
||||
|
||||
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Cauchy log cumulative distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.cauchy` ``logcdf``
|
||||
|
||||
The cdf is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y
|
||||
|
||||
where here :math:`f_{pdf}` is the Cauchy probability distribution function,
|
||||
:func:`jax.scipy.stats.cauchy.pdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the CDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logcdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.cauchy.cdf`
|
||||
- :func:`jax.scipy.stats.cauchy.pdf`
|
||||
- :func:`jax.scipy.stats.cauchy.sf`
|
||||
- :func:`jax.scipy.stats.cauchy.logpdf`
|
||||
- :func:`jax.scipy.stats.cauchy.logsf`
|
||||
- :func:`jax.scipy.stats.cauchy.isf`
|
||||
- :func:`jax.scipy.stats.cauchy.ppf`
|
||||
"""
|
||||
return lax.log(cdf(x, loc, scale))
|
||||
|
||||
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Cauchy distribution log survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.cauchy` ``sf``.
|
||||
|
||||
The survival function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{sf}(x) = 1 - f_{cdf}(x)
|
||||
|
||||
where :math:`f_{cdf}(x)` is the cumulative distribution function,
|
||||
:func:`jax.scipy.stats.cauchy.cdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the SF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of sf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.cauchy.cdf`
|
||||
- :func:`jax.scipy.stats.cauchy.pdf`
|
||||
- :func:`jax.scipy.stats.cauchy.logcdf`
|
||||
- :func:`jax.scipy.stats.cauchy.logpdf`
|
||||
- :func:`jax.scipy.stats.cauchy.logsf`
|
||||
- :func:`jax.scipy.stats.cauchy.isf`
|
||||
- :func:`jax.scipy.stats.cauchy.ppf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("cauchy.sf", x, loc, scale)
|
||||
return cdf(-x, -loc, scale)
|
||||
|
||||
|
||||
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Cauchy distribution log survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.cauchy` ``logsf``
|
||||
|
||||
The survival function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{sf}(x) = 1 - f_{cdf}(x)
|
||||
|
||||
where :math:`f_{cdf}(x)` is the cumulative distribution function,
|
||||
:func:`jax.scipy.stats.cauchy.cdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the SF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logsf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.cauchy.cdf`
|
||||
- :func:`jax.scipy.stats.cauchy.pdf`
|
||||
- :func:`jax.scipy.stats.cauchy.sf`
|
||||
- :func:`jax.scipy.stats.cauchy.logcdf`
|
||||
- :func:`jax.scipy.stats.cauchy.logpdf`
|
||||
- :func:`jax.scipy.stats.cauchy.isf`
|
||||
- :func:`jax.scipy.stats.cauchy.ppf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("cauchy.logsf", x, loc, scale)
|
||||
return logcdf(-x, -loc, scale)
|
||||
|
||||
|
||||
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Cauchy distribution inverse survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.cauchy` ``isf``.
|
||||
|
||||
Returns the inverse of the survival function,
|
||||
:func:`jax.scipy.stats.cauchy.sf`.
|
||||
|
||||
Args:
|
||||
q: arraylike, value at which to evaluate the ISF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of isf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.cauchy.cdf`
|
||||
- :func:`jax.scipy.stats.cauchy.pdf`
|
||||
- :func:`jax.scipy.stats.cauchy.sf`
|
||||
- :func:`jax.scipy.stats.cauchy.logcdf`
|
||||
- :func:`jax.scipy.stats.cauchy.logpdf`
|
||||
- :func:`jax.scipy.stats.cauchy.logsf`
|
||||
- :func:`jax.scipy.stats.cauchy.ppf`
|
||||
"""
|
||||
q, loc, scale = promote_args_inexact("cauchy.isf", q, loc, scale)
|
||||
pi = _lax_const(q, np.pi)
|
||||
half_pi = _lax_const(q, np.pi / 2)
|
||||
unscaled = lax.tan(lax.sub(half_pi, lax.mul(pi, q)))
|
||||
return lax.add(lax.mul(unscaled, scale), loc)
|
||||
|
||||
|
||||
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Cauchy distribution percent point function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.cauchy` ``ppf``.
|
||||
|
||||
The percent point function is defined as the inverse of the
|
||||
cumulative distribution function, :func:`jax.scipy.stats.cauchy.cdf`.
|
||||
|
||||
Args:
|
||||
q: arraylike, value at which to evaluate the PPF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of ppf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.cauchy.cdf`
|
||||
- :func:`jax.scipy.stats.cauchy.pdf`
|
||||
- :func:`jax.scipy.stats.cauchy.sf`
|
||||
- :func:`jax.scipy.stats.cauchy.logcdf`
|
||||
- :func:`jax.scipy.stats.cauchy.logpdf`
|
||||
- :func:`jax.scipy.stats.cauchy.logsf`
|
||||
- :func:`jax.scipy.stats.cauchy.isf`
|
||||
"""
|
||||
q, loc, scale = promote_args_inexact("cauchy.ppf", q, loc, scale)
|
||||
pi = _lax_const(q, np.pi)
|
||||
half_pi = _lax_const(q, np.pi / 2)
|
||||
unscaled = lax.tan(lax.sub(lax.mul(pi, q), half_pi))
|
||||
return lax.add(lax.mul(unscaled, scale), loc)
|
||||
@@ -0,0 +1,267 @@
|
||||
# Copyright 2021 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.scipy.special import gammainc, gammaincc
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Chi-square log probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.chi2` ``logpdf``.
|
||||
|
||||
The chi-square probability distribution function is given by:
|
||||
|
||||
.. math::
|
||||
|
||||
f(x, k) = \begin{cases}
|
||||
\frac{x^{k/2-1}e^{-x/2}}{2^{k/2}\Gamma(k/2)} & x \ge 0 \\
|
||||
0 & \mathrm{otherwise}
|
||||
\end{cases}
|
||||
|
||||
for :math:`k` degrees of freedom, and where :math:`\Gamma` is the
|
||||
:func:`~jax.scipy.special.gamma` function. JAX follows the scipy
|
||||
convention of using ``df`` to denote degrees of freedom.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
df: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logpdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.chi2.cdf`
|
||||
- :func:`jax.scipy.stats.chi2.pdf`
|
||||
- :func:`jax.scipy.stats.chi2.sf`
|
||||
- :func:`jax.scipy.stats.chi2.logcdf`
|
||||
- :func:`jax.scipy.stats.chi2.logsf`
|
||||
"""
|
||||
x, df, loc, scale = promote_args_inexact("chi2.logpdf", x, df, loc, scale)
|
||||
one = _lax_const(x, 1)
|
||||
two = _lax_const(x, 2)
|
||||
y = lax.div(lax.sub(x, loc), scale)
|
||||
df_on_two = lax.div(df, two)
|
||||
|
||||
kernel = lax.sub(lax.mul(lax.sub(df_on_two, one), lax.log(y)), lax.div(y,two))
|
||||
|
||||
nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two)))
|
||||
|
||||
log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
|
||||
return jnp.where(lax.lt(x, loc), -np.inf, log_probs)
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Chi-square probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.chi2` ``pdf``.
|
||||
|
||||
The chi-square probability distribution function is given by:
|
||||
|
||||
.. math::
|
||||
|
||||
f(x, k) = \begin{cases}
|
||||
\frac{x^{k/2-1}e^{-x/2}}{2^{k/2}\Gamma(k/2)} & x \ge 0 \\
|
||||
0 & \mathrm{otherwise}
|
||||
\end{cases}
|
||||
|
||||
for :math:`k` degrees of freedom, and where :math:`\Gamma` is the
|
||||
:func:`~jax.scipy.special.gamma` function. JAX follows the scipy
|
||||
convention of using ``df`` to denote degrees of freedom.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
df: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.chi2.cdf`
|
||||
- :func:`jax.scipy.stats.chi2.sf`
|
||||
- :func:`jax.scipy.stats.chi2.logcdf`
|
||||
- :func:`jax.scipy.stats.chi2.logpdf`
|
||||
- :func:`jax.scipy.stats.chi2.logsf`
|
||||
"""
|
||||
return lax.exp(logpdf(x, df, loc, scale))
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Chi-square cumulative distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.chi2` ``cdf``.
|
||||
|
||||
The cdf is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y
|
||||
|
||||
where :math:`f_{pdf}` is the probability density function,
|
||||
:func:`jax.scipy.stats.chi2.pdf`. JAX follows the scipy
|
||||
convention of using ``df`` to denote degrees of freedom.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the CDF
|
||||
df: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of cdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.chi2.pdf`
|
||||
- :func:`jax.scipy.stats.chi2.sf`
|
||||
- :func:`jax.scipy.stats.chi2.logcdf`
|
||||
- :func:`jax.scipy.stats.chi2.logpdf`
|
||||
- :func:`jax.scipy.stats.chi2.logsf`
|
||||
"""
|
||||
x, df, loc, scale = promote_args_inexact("chi2.cdf", x, df, loc, scale)
|
||||
two = _lax_const(scale, 2)
|
||||
return gammainc(
|
||||
lax.div(df, two),
|
||||
lax.clamp(
|
||||
_lax_const(x, 0),
|
||||
lax.div(
|
||||
lax.sub(x, loc),
|
||||
lax.mul(scale, two),
|
||||
),
|
||||
_lax_const(x, np.inf),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Chi-square log cumulative distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.chi2` ``logcdf``.
|
||||
|
||||
The cdf is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y
|
||||
|
||||
where :math:`f_{pdf}` is the probability density function,
|
||||
:func:`jax.scipy.stats.chi2.pdf`. JAX follows the scipy
|
||||
convention of using ``df`` to denote degrees of freedom.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the CDF
|
||||
df: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logcdf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.chi2.cdf`
|
||||
- :func:`jax.scipy.stats.chi2.pdf`
|
||||
- :func:`jax.scipy.stats.chi2.sf`
|
||||
- :func:`jax.scipy.stats.chi2.logpdf`
|
||||
- :func:`jax.scipy.stats.chi2.logsf`
|
||||
"""
|
||||
return lax.log(cdf(x, df, loc, scale))
|
||||
|
||||
|
||||
def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Chi-square survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.chi2` ``sf``.
|
||||
|
||||
The survival function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{sf}(x, k) = 1 - f_{cdf}(x, k)
|
||||
|
||||
where :math:`f_{cdf}(x, k)` is the cumulative distribution function,
|
||||
:func:`jax.scipy.stats.chi2.cdf`. JAX follows the scipy
|
||||
convention of using ``df`` to denote degrees of freedom.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the SF
|
||||
df: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of sf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.chi2.cdf`
|
||||
- :func:`jax.scipy.stats.chi2.pdf`
|
||||
- :func:`jax.scipy.stats.chi2.logcdf`
|
||||
- :func:`jax.scipy.stats.chi2.logpdf`
|
||||
- :func:`jax.scipy.stats.chi2.logsf`
|
||||
"""
|
||||
x, df, loc, scale = promote_args_inexact("chi2.sf", x, df, loc, scale)
|
||||
two = _lax_const(scale, 2)
|
||||
return gammaincc(
|
||||
lax.div(df, two),
|
||||
lax.clamp(
|
||||
_lax_const(x, 0),
|
||||
lax.div(
|
||||
lax.sub(x, loc),
|
||||
lax.mul(scale, two),
|
||||
),
|
||||
_lax_const(x, np.inf),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def logsf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Chi-square log survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.chi2` ``logsf``.
|
||||
|
||||
The survival function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{sf}(x, k) = 1 - f_{cdf}(x, k)
|
||||
|
||||
where :math:`f_{cdf}(x, k)` is the cumulative distribution function,
|
||||
:func:`jax.scipy.stats.chi2.cdf`. JAX follows the scipy
|
||||
convention of using ``df`` to denote degrees of freedom.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the SF
|
||||
df: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logsf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.chi2.cdf`
|
||||
- :func:`jax.scipy.stats.chi2.pdf`
|
||||
- :func:`jax.scipy.stats.chi2.sf`
|
||||
- :func:`jax.scipy.stats.chi2.logcdf`
|
||||
- :func:`jax.scipy.stats.chi2.logpdf`
|
||||
"""
|
||||
return lax.log(sf(x, df, loc, scale))
|
||||
@@ -0,0 +1,100 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_dtypes_inexact
|
||||
from jax._src.scipy.special import gammaln, xlogy
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def _is_simplex(x: Array) -> Array:
|
||||
x_sum = jnp.sum(x, axis=0)
|
||||
return jnp.all(x > 0, axis=0) & (abs(x_sum - 1) < 1E-6)
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, alpha: ArrayLike) -> Array:
|
||||
r"""Dirichlet log probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.dirichlet` ``logpdf``.
|
||||
|
||||
The Dirichlet probability density function is
|
||||
|
||||
.. math::
|
||||
|
||||
f(\mathbf{x}) = \frac{1}{B(\mathbf{\alpha})} \prod_{i=1}^K x_i^{\alpha_i - 1}
|
||||
|
||||
where :math:`B(\mathbf{\alpha})` is the :func:`~jax.scipy.special.beta` function
|
||||
in a :math:`K`-dimensional vector space.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
alpha: arraylike, distribution shape parameter
|
||||
|
||||
Returns:
|
||||
array of logpdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.dirichlet.pdf`
|
||||
"""
|
||||
return _logpdf(*promote_dtypes_inexact(x, alpha))
|
||||
|
||||
def _logpdf(x: Array, alpha: Array) -> Array:
|
||||
if alpha.ndim != 1:
|
||||
raise ValueError(
|
||||
f"`alpha` must be one-dimensional; got alpha.shape={alpha.shape}"
|
||||
)
|
||||
if x.shape[0] not in (alpha.shape[0], alpha.shape[0] - 1):
|
||||
raise ValueError(
|
||||
"`x` must have either the same number of entries as `alpha` "
|
||||
f"or one entry fewer; got x.shape={x.shape}, alpha.shape={alpha.shape}"
|
||||
)
|
||||
one = _lax_const(x, 1)
|
||||
if x.shape[0] != alpha.shape[0]:
|
||||
x = jnp.concatenate([x, lax.sub(one, x.sum(0, keepdims=True))], axis=0)
|
||||
normalize_term = jnp.sum(gammaln(alpha)) - gammaln(jnp.sum(alpha))
|
||||
if x.ndim > 1:
|
||||
alpha = lax.broadcast_in_dim(alpha, alpha.shape + (1,) * (x.ndim - 1), (0,))
|
||||
log_probs = lax.sub(jnp.sum(xlogy(lax.sub(alpha, one), x), axis=0), normalize_term)
|
||||
return jnp.where(_is_simplex(x), log_probs, -np.inf)
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, alpha: ArrayLike) -> Array:
|
||||
r"""Dirichlet probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.dirichlet` ``pdf``.
|
||||
|
||||
The Dirichlet probability density function is
|
||||
|
||||
.. math::
|
||||
|
||||
f(\mathbf{x}) = \frac{1}{B(\mathbf{\alpha})} \prod_{i=1}^K x_i^{\alpha_i - 1}
|
||||
|
||||
where :math:`B(\mathbf{\alpha})` is the :func:`~jax.scipy.special.beta` function
|
||||
in a :math:`K`-dimensional vector space.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
alpha: arraylike, distribution shape parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.dirichlet.logpdf`
|
||||
"""
|
||||
return lax.exp(logpdf(x, alpha))
|
||||
@@ -0,0 +1,270 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Exponential log probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.expon` ``logpdf``.
|
||||
|
||||
The Exponential probability distribution function is
|
||||
|
||||
.. math::
|
||||
|
||||
f(x) = \begin{cases}
|
||||
e^{-x} & x \ge 0 \\
|
||||
0 & \mathrm{otherwise}
|
||||
\end{cases}
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logpdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.expon.cdf`
|
||||
:func:`jax.scipy.stats.expon.pdf`
|
||||
:func:`jax.scipy.stats.expon.ppf`
|
||||
:func:`jax.scipy.stats.expon.sf`
|
||||
:func:`jax.scipy.stats.expon.logcdf`
|
||||
:func:`jax.scipy.stats.expon.logpdf`
|
||||
:func:`jax.scipy.stats.expon.logsf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("expon.logpdf", x, loc, scale)
|
||||
log_scale = lax.log(scale)
|
||||
linear_term = lax.div(lax.sub(x, loc), scale)
|
||||
log_probs = lax.neg(lax.add(linear_term, log_scale))
|
||||
return jnp.where(lax.lt(x, loc), -np.inf, log_probs)
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Exponential probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.expon` ``pdf``.
|
||||
|
||||
The Exponential probability distribution function is
|
||||
|
||||
.. math::
|
||||
|
||||
f(x) = \begin{cases}
|
||||
e^{-x} & x \ge 0 \\
|
||||
0 & \mathrm{otherwise}
|
||||
\end{cases}
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.expon.cdf`
|
||||
:func:`jax.scipy.stats.expon.pdf`
|
||||
:func:`jax.scipy.stats.expon.ppf`
|
||||
:func:`jax.scipy.stats.expon.sf`
|
||||
:func:`jax.scipy.stats.expon.logcdf`
|
||||
:func:`jax.scipy.stats.expon.logpdf`
|
||||
:func:`jax.scipy.stats.expon.logsf`
|
||||
"""
|
||||
return lax.exp(logpdf(x, loc, scale))
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Exponential cumulative density function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.expon` ``cdf``.
|
||||
|
||||
The cdf is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(x) = \int_{-\infty}^x f_{pdf}(y)\mathrm{d}y
|
||||
|
||||
where :math:`f_{pdf}` is the exponential distribution probability density function,
|
||||
:func:`jax.scipy.stats.expon.pdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.expon.cdf`
|
||||
:func:`jax.scipy.stats.expon.pdf`
|
||||
:func:`jax.scipy.stats.expon.ppf`
|
||||
:func:`jax.scipy.stats.expon.sf`
|
||||
:func:`jax.scipy.stats.expon.logcdf`
|
||||
:func:`jax.scipy.stats.expon.logpdf`
|
||||
:func:`jax.scipy.stats.expon.logsf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("expon.cdf", x, loc, scale)
|
||||
neg_scaled_x = lax.div(lax.sub(loc, x), scale)
|
||||
return jnp.where(
|
||||
lax.lt(x, loc),
|
||||
jnp.zeros_like(neg_scaled_x),
|
||||
lax.neg(lax.expm1(neg_scaled_x)),
|
||||
)
|
||||
|
||||
|
||||
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Exponential log cumulative density function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.expon` ``logcdf``.
|
||||
|
||||
The cdf is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(x) = \int_{-\infty}^x f_{pdf}(y)\mathrm{d}y
|
||||
|
||||
where :math:`f_{pdf}` is the exponential distribution probability density function,
|
||||
:func:`jax.scipy.stats.expon.pdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.expon.cdf`
|
||||
:func:`jax.scipy.stats.expon.pdf`
|
||||
:func:`jax.scipy.stats.expon.ppf`
|
||||
:func:`jax.scipy.stats.expon.sf`
|
||||
:func:`jax.scipy.stats.expon.logcdf`
|
||||
:func:`jax.scipy.stats.expon.logpdf`
|
||||
:func:`jax.scipy.stats.expon.logsf`
|
||||
"""
|
||||
return lax.log1p(lax.neg(sf(x, loc, scale)))
|
||||
|
||||
|
||||
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Exponential log survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.expon` ``logsf``.
|
||||
|
||||
The survival function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{sf}(x) = 1 - f_{cdf}(x)
|
||||
|
||||
where :math:`f_{cdf}(x)` is the exponential cumulative distribution function,
|
||||
:func:`jax.scipy.stats.expon.cdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.expon.cdf`
|
||||
:func:`jax.scipy.stats.expon.pdf`
|
||||
:func:`jax.scipy.stats.expon.ppf`
|
||||
:func:`jax.scipy.stats.expon.sf`
|
||||
:func:`jax.scipy.stats.expon.logcdf`
|
||||
:func:`jax.scipy.stats.expon.logpdf`
|
||||
:func:`jax.scipy.stats.expon.logsf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("expon.sf", x, loc, scale)
|
||||
neg_scaled_x = lax.div(lax.sub(loc, x), scale)
|
||||
return jnp.where(lax.lt(x, loc), jnp.zeros_like(neg_scaled_x), neg_scaled_x)
|
||||
|
||||
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Exponential survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.expon` ``sf``.
|
||||
|
||||
The survival function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{sf}(x) = 1 - f_{cdf}(x)
|
||||
|
||||
where :math:`f_{cdf}(x)` is the exponential cumulative distribution function,
|
||||
:func:`jax.scipy.stats.expon.cdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.expon.cdf`
|
||||
:func:`jax.scipy.stats.expon.pdf`
|
||||
:func:`jax.scipy.stats.expon.ppf`
|
||||
:func:`jax.scipy.stats.expon.sf`
|
||||
:func:`jax.scipy.stats.expon.logcdf`
|
||||
:func:`jax.scipy.stats.expon.logpdf`
|
||||
:func:`jax.scipy.stats.expon.logsf`
|
||||
"""
|
||||
return lax.exp(logsf(x, loc, scale))
|
||||
|
||||
|
||||
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Exponential survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.expon` ``ppf``.
|
||||
|
||||
The percent point function is defined as the inverse of the
|
||||
cumulative distribution function, :func:`jax.scipy.stats.expon.cdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.expon.cdf`
|
||||
:func:`jax.scipy.stats.expon.pdf`
|
||||
:func:`jax.scipy.stats.expon.ppf`
|
||||
:func:`jax.scipy.stats.expon.sf`
|
||||
:func:`jax.scipy.stats.expon.logcdf`
|
||||
:func:`jax.scipy.stats.expon.logpdf`
|
||||
:func:`jax.scipy.stats.expon.logsf`
|
||||
"""
|
||||
q, loc, scale = promote_args_inexact("expon.ppf", q, loc, scale)
|
||||
neg_scaled_q = lax.div(lax.sub(loc, q), scale)
|
||||
return jnp.where(
|
||||
jnp.isnan(q) | (q < 0) | (q > 1),
|
||||
np.nan,
|
||||
lax.neg(lax.log1p(neg_scaled_q)),
|
||||
)
|
||||
@@ -0,0 +1,237 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.scipy.special import gammaln, xlogy, gammainc, gammaincc
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Gamma log probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gamma` ``logpdf``.
|
||||
|
||||
The Gamma probability distribution is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x, a) = \frac{1}{\Gamma(a)}x^{a-1}e^{-x}
|
||||
|
||||
Where :math:`\Gamma(a)` is the :func:`~jax.scipy.special.gamma` function.
|
||||
It is defined for :math:`x \ge 0` and :math:`a > 0`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
a: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logpdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gamma.cdf`
|
||||
- :func:`jax.scipy.stats.gamma.pdf`
|
||||
- :func:`jax.scipy.stats.gamma.sf`
|
||||
- :func:`jax.scipy.stats.gamma.logcdf`
|
||||
- :func:`jax.scipy.stats.gamma.logsf`
|
||||
"""
|
||||
x, a, loc, scale = promote_args_inexact("gamma.logpdf", x, a, loc, scale)
|
||||
ok = lax.ge(x, loc)
|
||||
one = _lax_const(x, 1)
|
||||
y = jnp.where(ok, lax.div(lax.sub(x, loc), scale), one)
|
||||
log_linear_term = lax.sub(xlogy(lax.sub(a, one), y), y)
|
||||
shape_terms = lax.add(gammaln(a), lax.log(scale))
|
||||
log_probs = lax.sub(log_linear_term, shape_terms)
|
||||
return jnp.where(ok, log_probs, -np.inf)
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Gamma probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gamma` ``pdf``.
|
||||
|
||||
The Gamma probability distribution is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x, a) = \frac{1}{\Gamma(a)}x^{a-1}e^{-x}
|
||||
|
||||
Where :math:`\Gamma(a)` is the :func:`~jax.scipy.special.gamma` function.
|
||||
It is defined for :math:`x \ge 0` and :math:`a > 0`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
a: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gamma.cdf`
|
||||
- :func:`jax.scipy.stats.gamma.sf`
|
||||
- :func:`jax.scipy.stats.gamma.logcdf`
|
||||
- :func:`jax.scipy.stats.gamma.logpdf`
|
||||
- :func:`jax.scipy.stats.gamma.logsf`
|
||||
"""
|
||||
return lax.exp(logpdf(x, a, loc, scale))
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Gamma cumulative distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gamma` ``cdf``.
|
||||
|
||||
The cdf is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(x, a) = \int_{-\infty}^x f_{pdf}(y, a)\mathrm{d}y
|
||||
|
||||
where :math:`f_{pdf}` is the probability density function,
|
||||
:func:`jax.scipy.stats.gamma.pdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the CDF
|
||||
a: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of cdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gamma.pdf`
|
||||
- :func:`jax.scipy.stats.gamma.sf`
|
||||
- :func:`jax.scipy.stats.gamma.logcdf`
|
||||
- :func:`jax.scipy.stats.gamma.logpdf`
|
||||
- :func:`jax.scipy.stats.gamma.logsf`
|
||||
"""
|
||||
x, a, loc, scale = promote_args_inexact("gamma.cdf", x, a, loc, scale)
|
||||
return gammainc(
|
||||
a,
|
||||
lax.clamp(
|
||||
_lax_const(x, 0),
|
||||
lax.div(lax.sub(x, loc), scale),
|
||||
_lax_const(x, np.inf),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Gamma log cumulative distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gamma` ``logcdf``.
|
||||
|
||||
The cdf is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(x, a) = \int_{-\infty}^x f_{pdf}(y, a)\mathrm{d}y
|
||||
|
||||
where :math:`f_{pdf}` is the probability density function,
|
||||
:func:`jax.scipy.stats.gamma.pdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the CDF
|
||||
a: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logcdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gamma.cdf`
|
||||
- :func:`jax.scipy.stats.gamma.pdf`
|
||||
- :func:`jax.scipy.stats.gamma.sf`
|
||||
- :func:`jax.scipy.stats.gamma.logpdf`
|
||||
- :func:`jax.scipy.stats.gamma.logsf`
|
||||
"""
|
||||
return lax.log(cdf(x, a, loc, scale))
|
||||
|
||||
|
||||
def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Gamma survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gamma` ``sf``.
|
||||
|
||||
The survival function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{sf}(x, k) = 1 - f_{cdf}(x, k)
|
||||
|
||||
where :math:`f_{cdf}(x, k)` is the cumulative distribution function,
|
||||
:func:`jax.scipy.stats.gamma.cdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the SF
|
||||
a: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of sf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gamma.cdf`
|
||||
- :func:`jax.scipy.stats.gamma.pdf`
|
||||
- :func:`jax.scipy.stats.gamma.logcdf`
|
||||
- :func:`jax.scipy.stats.gamma.logpdf`
|
||||
- :func:`jax.scipy.stats.gamma.logsf`
|
||||
"""
|
||||
x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale)
|
||||
y = lax.div(lax.sub(x, loc), scale)
|
||||
return jnp.where(lax.lt(y, _lax_const(y, 0)), 1, gammaincc(a, y))
|
||||
|
||||
|
||||
def logsf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Gamma log survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gamma` ``logsf``.
|
||||
|
||||
The survival function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{sf}(x, k) = 1 - f_{cdf}(x, k)
|
||||
|
||||
where :math:`f_{cdf}(x, k)` is the cumulative distribution function,
|
||||
:func:`jax.scipy.stats.gamma.cdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the SF
|
||||
a: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logsf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gamma.cdf`
|
||||
- :func:`jax.scipy.stats.gamma.pdf`
|
||||
- :func:`jax.scipy.stats.gamma.sf`
|
||||
- :func:`jax.scipy.stats.gamma.logcdf`
|
||||
- :func:`jax.scipy.stats.gamma.logpdf`
|
||||
"""
|
||||
return lax.log(sf(x, a, loc, scale))
|
||||
@@ -0,0 +1,103 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, beta: ArrayLike) -> Array:
|
||||
r"""Generalized normal log probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gennorm` ``logpdf``.
|
||||
|
||||
The generalized normal probability distribution function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f(x, \beta) = \frac{\beta}{2\Gamma(1/\beta)}\exp(-|x|^\beta)
|
||||
|
||||
where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function, and
|
||||
:math:`\beta > 0`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
beta: arraylike, distribution shape parameter
|
||||
|
||||
Returns:
|
||||
array of logpdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gennorm.cdf`
|
||||
- :func:`jax.scipy.stats.gennorm.pdf`
|
||||
"""
|
||||
x, beta = promote_args_inexact("gennorm.logpdf", x, beta)
|
||||
return lax.log(.5 * beta) - lax.lgamma(1/beta) - lax.abs(x)**beta
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, beta: ArrayLike) -> Array:
|
||||
r"""Generalized normal cumulative distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gennorm` ``cdf``.
|
||||
|
||||
The cdf is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y
|
||||
|
||||
where :math:`f_{pdf}` is the probability density function,
|
||||
:func:`jax.scipy.stats.gennorm.pdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the CDF
|
||||
beta: arraylike, distribution shape parameter
|
||||
|
||||
Returns:
|
||||
array of cdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gennorm.pdf`
|
||||
- :func:`jax.scipy.stats.gennorm.logpdf`
|
||||
"""
|
||||
x, beta = promote_args_inexact("gennorm.cdf", x, beta)
|
||||
return .5 * (1 + lax.sign(x) * lax.igamma(1/beta, lax.abs(x)**beta))
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, beta: ArrayLike) -> Array:
|
||||
r"""Generalized normal probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gennorm` ``pdf``.
|
||||
|
||||
The generalized normal probability distribution function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f(x, \beta) = \frac{\beta}{2\Gamma(1/\beta)}\exp(-|x|^\beta)
|
||||
|
||||
where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function, and
|
||||
:math:`\beta > 0`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
beta: arraylike, distribution shape parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gennorm.cdf`
|
||||
- :func:`jax.scipy.stats.gennorm.logpdf`
|
||||
"""
|
||||
return lax.exp(logpdf(x, beta))
|
||||
@@ -0,0 +1,81 @@
|
||||
# Copyright 2020 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.scipy.special import xlog1py
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
r"""Geometric log probability mass function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.geom` ``logpmf``.
|
||||
|
||||
The Geometric probability mass function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(k) = (1 - p)^{k-1}p
|
||||
|
||||
for :math:`k\ge 1` and :math:`0 \le p \le 1`.
|
||||
|
||||
Args:
|
||||
k: arraylike, value at which to evaluate the PMF
|
||||
p: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
|
||||
Returns:
|
||||
array of logpmf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.geom.pmf`
|
||||
"""
|
||||
k, p, loc = promote_args_inexact("geom.logpmf", k, p, loc)
|
||||
zero = _lax_const(k, 0)
|
||||
one = _lax_const(k, 1)
|
||||
x = lax.sub(k, loc)
|
||||
log_probs = xlog1py(lax.sub(x, one), -p) + lax.log(p)
|
||||
return jnp.where(lax.le(x, zero), -np.inf, log_probs)
|
||||
|
||||
|
||||
def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
r"""Geometric probability mass function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.geom` ``pmf``.
|
||||
|
||||
The Geometric probability mass function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(k) = (1 - p)^{k-1}p
|
||||
|
||||
for :math:`k\ge 1` and :math:`0 \le p \le 1`.
|
||||
|
||||
Args:
|
||||
k: arraylike, value at which to evaluate the PMF
|
||||
p: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
|
||||
Returns:
|
||||
array of pmf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.geom.logpmf`
|
||||
"""
|
||||
return jnp.exp(logpmf(k, p, loc))
|
||||
@@ -0,0 +1,256 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.scipy.special import xlogy, xlog1py
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike,
|
||||
loc: ArrayLike = 0,
|
||||
scale: ArrayLike = 1) -> Array:
|
||||
r"""
|
||||
Gumbel Distribution (Left Skewed) log probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gumbel_l` ``logpdf``.
|
||||
|
||||
.. math::
|
||||
|
||||
f_{pdf}(x; \mu, \beta) = \frac{1}{\beta} \exp\left( \frac{x - \mu}{\beta} - \exp\left( \frac{x - \mu}{\beta} \right) \right)
|
||||
|
||||
Args:
|
||||
x: ArrayLike, value at which to evaluate log(pdf)
|
||||
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
|
||||
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
|
||||
|
||||
Returns:
|
||||
array of logpdf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gumbel_l.pdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.logcdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.cdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.ppf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.logsf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.sf`
|
||||
"""
|
||||
|
||||
x, loc, scale = promote_args_inexact("gumbel_l.logpdf", x, loc, scale)
|
||||
ok = lax.gt(scale, _lax_const(scale, 0))
|
||||
# logpdf = -log(scale) + z - exp(z)
|
||||
z = lax.div(lax.sub(x, loc), scale)
|
||||
neg_log_scale = xlogy(-1, scale)
|
||||
t2 = lax.sub(z, lax.exp(z))
|
||||
log_pdf = lax.add(neg_log_scale, t2)
|
||||
return jnp.where(ok, log_pdf, np.nan)
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""
|
||||
Gumbel Distribution (Left Skewed) probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gumbel_l` ``pdf``.
|
||||
|
||||
.. math::
|
||||
|
||||
f_{pdf}(x; \mu, \beta) = \frac{1}{\beta} \exp\left( \frac{x - \mu}{\beta} - \exp\left( \frac{x - \mu}{\beta} \right) \right)
|
||||
|
||||
Args:
|
||||
x: ArrayLike, value at which to evaluate pdf
|
||||
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
|
||||
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
|
||||
|
||||
Returns:
|
||||
array of pdf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gumbel_l.logpdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.logcdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.cdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.ppf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.logsf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.sf`
|
||||
"""
|
||||
return lax.exp(logpdf(x, loc, scale))
|
||||
|
||||
|
||||
def logcdf(x: ArrayLike,
|
||||
loc: ArrayLike = 0,
|
||||
scale: ArrayLike = 1) -> Array:
|
||||
r"""
|
||||
Gumbel Distribution (Left Skewed) log cumulative density function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gumbel_l` ``logcdf``.
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(x; \mu, \beta) = 1 - \exp\left( -\exp\left( \frac{x - \mu}{\beta} \right) \right)
|
||||
|
||||
Args:
|
||||
x: ArrayLike, value at which to evaluate log(cdf)
|
||||
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
|
||||
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
|
||||
|
||||
Returns:
|
||||
array of logcdf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gumbel_l.logpdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.pdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.cdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.ppf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.logsf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.sf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("gumbel_l.logcdf", x, loc, scale)
|
||||
ok = lax.gt(scale, _lax_const(scale, 0))
|
||||
z = lax.div(lax.sub(x, loc), scale)
|
||||
neg_exp_z = lax.neg(lax.exp(z))
|
||||
# xlog1p fails here, that's why log1p is used here
|
||||
# even log1p fails for some cases when using float64 mode
|
||||
# so we're using this formula which is stable
|
||||
log_cdf = lax.log(-lax.expm1(neg_exp_z))
|
||||
return jnp.where(ok, log_cdf, np.nan)
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""
|
||||
Gumbel Distribution (Left Skewed) cumulative density function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gumbel_l` ``cdf``.
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(x; \mu, \beta) = 1 - \exp\left( -\exp\left( \frac{x - \mu}{\beta} \right) \right)
|
||||
|
||||
Args:
|
||||
x: ArrayLike, value at which to evaluate cdf
|
||||
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
|
||||
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
|
||||
|
||||
Returns:
|
||||
array of cdf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gumbel_l.logpdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.pdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.logcdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.ppf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.logsf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.sf`
|
||||
"""
|
||||
return lax.exp(logcdf(x, loc, scale))
|
||||
|
||||
|
||||
def ppf(p: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""
|
||||
Gumbel Distribution (Left Skewed) percent point function (inverse of CDF)
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gumbel_l` ``ppf``.
|
||||
|
||||
.. math::
|
||||
|
||||
F_{ppf}}(p; \mu, \beta) = \mu + \beta \log\left( -\log(1 - p) \right)
|
||||
|
||||
Args:
|
||||
p: ArrayLike, probability value (quantile) at which to evaluate ppf
|
||||
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
|
||||
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
|
||||
|
||||
Returns:
|
||||
array of ppf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gumbel_l.logpdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.pdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.logcdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.cdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.logsf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.sf`
|
||||
"""
|
||||
p, loc, scale = promote_args_inexact("gumbel_l.ppf", p, loc, scale)
|
||||
ok = lax.bitwise_and(lax.gt(p, _lax_const(p, 0)),
|
||||
lax.lt(p, _lax_const(p, 1)))
|
||||
# quantile = loc + (scale)*log(-log(1 - p))
|
||||
t1 = xlog1py(-1, lax.neg(p))
|
||||
# xlogp failed here too, that's why log is used
|
||||
t = lax.mul(scale, lax.log(t1))
|
||||
quantile = lax.add(loc, t)
|
||||
return jnp.where(ok, quantile, np.nan)
|
||||
|
||||
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""
|
||||
Gumbel Distribution (Left Skewed) survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gumbel_l` ``sf``.
|
||||
|
||||
.. math::
|
||||
|
||||
f_{sf}(x; \mu, \beta) = 1 - f_{cdf}(x, \mu, \beta)
|
||||
|
||||
Args:
|
||||
x: ArrayLike, value at which to evaluate survival function
|
||||
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
|
||||
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
|
||||
|
||||
Returns:
|
||||
array of sf values (1 - cdf)
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gumbel_l.logpdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.pdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.logcdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.cdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.logsf`
|
||||
"""
|
||||
return jnp.exp(logsf(x, loc, scale))
|
||||
|
||||
|
||||
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""
|
||||
Gumbel Distribution (Left Skewed) log survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gumbel_l` ``logsf``.
|
||||
|
||||
.. math::
|
||||
|
||||
f_{sf}(x; \mu, \beta) = 1 - f_{cdf}(x, \mu, \beta)
|
||||
|
||||
Args:
|
||||
x: ArrayLike, value at which to evaluate log survival function
|
||||
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
|
||||
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
|
||||
|
||||
Returns:
|
||||
array of logsf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gumbel_l.logpdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.pdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.logcdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.cdf`
|
||||
- :func:`jax.scipy.stats.gumbel_l.sf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("gumbel_l.logsf", x, loc, scale)
|
||||
ok = lax.gt(scale, _lax_const(scale, 0))
|
||||
# logsf = -exp(z)
|
||||
z = lax.div(lax.sub(x, loc), scale)
|
||||
log_sf = lax.neg(lax.exp(z))
|
||||
return jnp.where(ok, log_sf, np.nan)
|
||||
@@ -0,0 +1,257 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.scipy.special import xlogy
|
||||
from jax._src.nn.functions import log1mexp
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike,
|
||||
loc: ArrayLike = 0,
|
||||
scale: ArrayLike = 1) -> Array:
|
||||
r"""
|
||||
Gumbel Distribution (Right Skewed) log probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gumbel_l` ``logpdf``.
|
||||
|
||||
.. math::
|
||||
|
||||
f_{pdf}(x; \mu, \beta) = \frac{1}{\beta} \exp\left( -\frac{x - \mu}{\beta} - \exp\left( -\frac{x - \mu}{\beta} \right) \right)
|
||||
|
||||
Args:
|
||||
x: ArrayLike, value at which to evaluate log(pdf)
|
||||
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
|
||||
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
|
||||
|
||||
Returns:
|
||||
array of logpdf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gumbel_r.pdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.logcdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.cdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.ppf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.sf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.logsf`
|
||||
"""
|
||||
|
||||
x, loc, scale = promote_args_inexact("gumbel_r.logpdf", x, loc, scale)
|
||||
ok = lax.gt(scale, _lax_const(scale, 0))
|
||||
z = lax.div(lax.sub(x, loc), scale)
|
||||
# logpdf = -log(beta) - (z + exp(-z))
|
||||
neg_log_scale = xlogy(-1, scale)
|
||||
t2 = lax.neg(lax.add(z, lax.exp(lax.neg(z))))
|
||||
log_pdf = lax.add(neg_log_scale, t2)
|
||||
return jnp.where(ok, log_pdf, np.nan)
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""
|
||||
Gumbel Distribution (Right Skewed) probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gumbel_r` ``pdf``.
|
||||
|
||||
.. math::
|
||||
|
||||
f_{pdf}(x; \mu, \beta) = \frac{1}{\beta} \exp\left( -\frac{x - \mu}{\beta} - \exp\left( -\frac{x - \mu}{\beta} \right) \right)
|
||||
|
||||
Args:
|
||||
x: ArrayLike, value at which to evaluate pdf
|
||||
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
|
||||
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
|
||||
|
||||
Returns:
|
||||
array of pdf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gumbel_r.logpdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.logcdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.cdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.ppf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.sf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.logsf`
|
||||
"""
|
||||
return lax.exp(logpdf(x, loc, scale))
|
||||
|
||||
|
||||
def logcdf(x: ArrayLike,
|
||||
loc: ArrayLike = 0,
|
||||
scale: ArrayLike = 1) -> Array:
|
||||
r"""
|
||||
Gumbel Distribution (Right Skewed) log cumulative density function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gumbel_r` ``logcdf``.
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(x; \mu, \beta) = \exp\left( -\exp\left( -\frac{x - \mu}{\beta} \right) \right)
|
||||
|
||||
Args:
|
||||
x: ArrayLike, value at which to evaluate log(cdf)
|
||||
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
|
||||
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
|
||||
|
||||
Returns:
|
||||
array of logcdf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gumbel_r.logpdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.pdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.cdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.ppf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.sf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.logsf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("gumbel_r.logcdf", x, loc, scale)
|
||||
ok = lax.gt(scale, _lax_const(scale, 0))
|
||||
z = lax.div(lax.sub(x, loc), scale)
|
||||
# log cdf = -exp(-z)
|
||||
log_cdf = lax.neg(lax.exp(lax.neg(z)))
|
||||
return jnp.where(ok, log_cdf, np.nan)
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""
|
||||
Gumbel Distribution (Right Skewed) cumulative density function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gumbel_r` ``cdf``.
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(x; \mu, \beta) = \exp\left( -\exp\left( -\frac{x - \mu}{\beta} \right) \right)
|
||||
|
||||
Args:
|
||||
x: ArrayLike, value at which to evaluate cdf
|
||||
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
|
||||
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
|
||||
|
||||
Returns:
|
||||
array of cdf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gumbel_r.logpdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.pdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.logcdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.ppf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.sf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.logsf`
|
||||
"""
|
||||
return lax.exp(logcdf(x, loc, scale))
|
||||
|
||||
|
||||
def ppf(p: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""
|
||||
Gumbel Distribution (Right Skewed) percent point function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gumbel_r` ``ppf``.
|
||||
|
||||
.. math::
|
||||
|
||||
F(p; \mu, \beta) = \mu - \beta \log\left( -\log(p) \right)
|
||||
|
||||
Args:
|
||||
p: ArrayLike, probability value (quantile) at which to evaluate ppf
|
||||
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
|
||||
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
|
||||
|
||||
Returns:
|
||||
array of ppf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gumbel_r.logpdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.pdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.logcdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.cdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.sf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.logsf`
|
||||
"""
|
||||
p, loc, scale = promote_args_inexact("gumbel_r.ppf", p, loc, scale)
|
||||
# 0 < p < 1
|
||||
ok = lax.bitwise_and(lax.gt(p, _lax_const(p, 0)),
|
||||
lax.lt(p, _lax_const(p, 1)))
|
||||
|
||||
# quantile = loc - (scale)*log(-log(p))
|
||||
t1 = xlogy(-1, p)
|
||||
t = lax.mul(scale, lax.log(t1))
|
||||
quantile = lax.sub(loc, t)
|
||||
return jnp.where(ok, quantile, np.nan)
|
||||
|
||||
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""
|
||||
Gumbel Distribution (Right Skewed) survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gumbel_r` ``sf``.
|
||||
|
||||
.. math::
|
||||
|
||||
f_{sf}(x; \mu, \beta) = 1 - F_{cdf}(x; \mu, \beta)
|
||||
|
||||
Args:
|
||||
x: ArrayLike, value at which to evaluate survival function
|
||||
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
|
||||
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
|
||||
|
||||
Returns:
|
||||
array of sf values (1 - cdf)
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gumbel_r.logpdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.pdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.logcdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.cdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.logsf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("gumbel_r.sf", x, loc, scale)
|
||||
ok = lax.gt(scale, _lax_const(scale, 0))
|
||||
# sf = 1 - exp(-exp(-z))
|
||||
neg_z = lax.div(lax.sub(loc, x), scale)
|
||||
t1 = lax.exp(lax.neg(lax.exp(neg_z)))
|
||||
_sf = lax.sub(_lax_const(x, 1), t1)
|
||||
return jnp.where(ok, _sf, np.nan)
|
||||
|
||||
|
||||
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""
|
||||
Gumbel Distribution (Right Skewed) log survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gumbel_r` ``logsf``.
|
||||
|
||||
Args:
|
||||
x: ArrayLike, value at which to evaluate log survival function
|
||||
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
|
||||
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
|
||||
|
||||
Returns:
|
||||
array of logsf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gumbel_r.logpdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.pdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.logcdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.cdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.sf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("gumbel_r.logsf", x, loc, scale)
|
||||
ok = lax.gt(scale, _lax_const(scale, 0))
|
||||
# logsf = log(1 - exp(-exp(-z)))
|
||||
neg_z = lax.div(lax.sub(loc, x), scale)
|
||||
log_sf = log1mexp(lax.exp(neg_z))
|
||||
return jnp.where(ok, log_sf, np.nan)
|
||||
@@ -0,0 +1,284 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Callable, cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import dtypes
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src import random
|
||||
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact
|
||||
from jax._src.scipy import linalg, special
|
||||
from jax._src.tree_util import register_pytree_node_class
|
||||
from jax._src.typing import Array
|
||||
|
||||
BwMethod = None | str | Array | Callable[[Any], Array]
|
||||
|
||||
@register_pytree_node_class
|
||||
@dataclass(frozen=True, init=False)
|
||||
class gaussian_kde:
|
||||
"""Gaussian Kernel Density Estimator
|
||||
|
||||
JAX implementation of :class:`scipy.stats.gaussian_kde`.
|
||||
|
||||
Parameters:
|
||||
dataset: arraylike, real-valued. Data from which to estimate the distribution.
|
||||
If 1D, shape is (n_data,). If 2D, shape is (n_dimensions, n_data).
|
||||
bw_method: string, scalar, or callable. Either "scott", "silverman", a scalar
|
||||
value, or a callable function which takes ``self`` as a parameter.
|
||||
weights: arraylike, optional. Weights of the same shape as the dataset.
|
||||
"""
|
||||
neff: Any
|
||||
dataset: Any
|
||||
weights: Any
|
||||
covariance: Any
|
||||
inv_cov: Any
|
||||
|
||||
def __init__(self, dataset, bw_method: BwMethod = None, weights=None):
|
||||
check_arraylike("gaussian_kde", dataset)
|
||||
dataset = jnp.atleast_2d(dataset)
|
||||
if dtypes.issubdtype(lax.dtype(dataset), np.complexfloating):
|
||||
raise NotImplementedError("gaussian_kde does not support complex data")
|
||||
if not dataset.size > 1:
|
||||
raise ValueError("`dataset` input should have multiple elements.")
|
||||
|
||||
d, n = dataset.shape
|
||||
if weights is not None:
|
||||
check_arraylike("gaussian_kde", weights)
|
||||
dataset, weights = promote_dtypes_inexact(dataset, weights)
|
||||
weights = jnp.atleast_1d(weights)
|
||||
weights /= jnp.sum(weights)
|
||||
if weights.ndim != 1:
|
||||
raise ValueError("`weights` input should be one-dimensional.")
|
||||
if len(weights) != n:
|
||||
raise ValueError("`weights` input should be of length n")
|
||||
else:
|
||||
dataset, = promote_dtypes_inexact(dataset)
|
||||
weights = jnp.full(n, 1.0 / n, dtype=dataset.dtype)
|
||||
|
||||
self._setattr("dataset", dataset)
|
||||
self._setattr("weights", weights)
|
||||
neff = self._setattr("neff", 1 / jnp.sum(weights**2))
|
||||
|
||||
bw_method = "scott" if bw_method is None else bw_method
|
||||
if bw_method == "scott":
|
||||
factor = jnp.power(neff, -1. / (d + 4))
|
||||
elif bw_method == "silverman":
|
||||
factor = jnp.power(neff * (d + 2) / 4.0, -1. / (d + 4))
|
||||
elif jnp.isscalar(bw_method) and not isinstance(bw_method, str):
|
||||
factor = cast(Array, bw_method)
|
||||
elif callable(bw_method):
|
||||
factor = bw_method(self)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`bw_method` should be 'scott', 'silverman', a scalar, or a callable."
|
||||
)
|
||||
|
||||
data_covariance = jnp.atleast_2d(
|
||||
jnp.cov(dataset, rowvar=1, bias=False, aweights=weights))
|
||||
data_inv_cov = jnp.linalg.inv(data_covariance)
|
||||
covariance = data_covariance * factor**2
|
||||
inv_cov = data_inv_cov / factor**2
|
||||
self._setattr("covariance", covariance)
|
||||
self._setattr("inv_cov", inv_cov)
|
||||
|
||||
def _setattr(self, name, value):
|
||||
# Frozen dataclasses don't support setting attributes so we have to
|
||||
# overload that operation here as they do in the dataclass implementation
|
||||
object.__setattr__(self, name, value)
|
||||
return value
|
||||
|
||||
def tree_flatten(self):
|
||||
return ((self.neff, self.dataset, self.weights, self.covariance,
|
||||
self.inv_cov), None)
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
del aux_data
|
||||
kde = cls.__new__(cls)
|
||||
kde._setattr("neff", children[0])
|
||||
kde._setattr("dataset", children[1])
|
||||
kde._setattr("weights", children[2])
|
||||
kde._setattr("covariance", children[3])
|
||||
kde._setattr("inv_cov", children[4])
|
||||
return kde
|
||||
|
||||
@property
|
||||
def d(self):
|
||||
return self.dataset.shape[0]
|
||||
|
||||
@property
|
||||
def n(self):
|
||||
return self.dataset.shape[1]
|
||||
|
||||
def evaluate(self, points):
|
||||
"""Evaluate the Gaussian KDE on the given points."""
|
||||
check_arraylike("evaluate", points)
|
||||
points = self._reshape_points(points)
|
||||
result = _gaussian_kernel_eval(False, self.dataset.T, self.weights[:, None],
|
||||
points.T, self.inv_cov)
|
||||
return result[:, 0]
|
||||
|
||||
def __call__(self, points):
|
||||
return self.evaluate(points)
|
||||
|
||||
def integrate_gaussian(self, mean, cov):
|
||||
"""Integrate the distribution weighted by a Gaussian."""
|
||||
mean = jnp.atleast_1d(jnp.squeeze(mean))
|
||||
cov = jnp.atleast_2d(cov)
|
||||
|
||||
if mean.shape != (self.d,):
|
||||
raise ValueError(f"mean does not have dimension {self.d}")
|
||||
if cov.shape != (self.d, self.d):
|
||||
raise ValueError(f"covariance does not have dimension {self.d}")
|
||||
|
||||
chol = linalg.cho_factor(self.covariance + cov)
|
||||
norm = jnp.sqrt(2 * np.pi)**self.d * jnp.prod(jnp.diag(chol[0]))
|
||||
norm = 1.0 / norm
|
||||
return _gaussian_kernel_convolve(chol, norm, self.dataset, self.weights,
|
||||
mean)
|
||||
|
||||
@api.jit
|
||||
def integrate_box_1d(self, low, high):
|
||||
"""Integrate the distribution over the given limits."""
|
||||
if self.d != 1:
|
||||
raise ValueError("integrate_box_1d() only handles 1D pdfs")
|
||||
if np.ndim(low) != 0 or np.ndim(high) != 0:
|
||||
raise ValueError(
|
||||
"the limits of integration in integrate_box_1d must be scalars")
|
||||
sigma = jnp.squeeze(jnp.sqrt(self.covariance))
|
||||
low = jnp.squeeze((low - self.dataset) / sigma)
|
||||
high = jnp.squeeze((high - self.dataset) / sigma)
|
||||
return jnp.sum(self.weights * (special.ndtr(high) - special.ndtr(low)))
|
||||
|
||||
def integrate_kde(self, other):
|
||||
"""Integrate the product of two Gaussian KDE distributions."""
|
||||
if other.d != self.d:
|
||||
raise ValueError("KDEs are not the same dimensionality")
|
||||
|
||||
chol = linalg.cho_factor(self.covariance + other.covariance)
|
||||
norm = jnp.sqrt(2 * np.pi)**self.d * jnp.prod(jnp.diag(chol[0]))
|
||||
norm = 1.0 / norm
|
||||
|
||||
sm, lg = (self, other) if self.n < other.n else (other, self)
|
||||
result = api.vmap(partial(_gaussian_kernel_convolve, chol, norm, lg.dataset,
|
||||
lg.weights),
|
||||
in_axes=1)(sm.dataset)
|
||||
return jnp.sum(result * sm.weights)
|
||||
|
||||
@partial(api.jit, static_argnames=("shape",))
|
||||
def resample(self, key, shape=()):
|
||||
r"""Randomly sample a dataset from the estimated pdf
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
shape: optional, a tuple of nonnegative integers specifying the result
|
||||
batch shape; that is, the prefix of the result shape excluding the last
|
||||
axis.
|
||||
|
||||
Returns:
|
||||
The resampled dataset as an array with shape `(d,) + shape`.
|
||||
"""
|
||||
ind_key, eps_key = random.split(key)
|
||||
ind = random.choice(ind_key, self.n, shape=shape, p=self.weights)
|
||||
eps = random.multivariate_normal(eps_key,
|
||||
jnp.zeros(self.d, self.covariance.dtype),
|
||||
self.covariance,
|
||||
shape=shape,
|
||||
dtype=self.dataset.dtype).T
|
||||
return self.dataset[:, ind] + eps
|
||||
|
||||
def pdf(self, x):
|
||||
"""Probability density function"""
|
||||
return self.evaluate(x)
|
||||
|
||||
def logpdf(self, x):
|
||||
"""Log probability density function"""
|
||||
check_arraylike("logpdf", x)
|
||||
x = self._reshape_points(x)
|
||||
result = _gaussian_kernel_eval(True, self.dataset.T, self.weights[:, None],
|
||||
x.T, self.inv_cov)
|
||||
return result[:, 0]
|
||||
|
||||
def integrate_box(self, low_bounds, high_bounds, maxpts=None):
|
||||
"""This method is not implemented in the JAX interface."""
|
||||
del low_bounds, high_bounds, maxpts
|
||||
raise NotImplementedError(
|
||||
"only 1D box integrations are supported; use `integrate_box_1d`")
|
||||
|
||||
def set_bandwidth(self, bw_method=None):
|
||||
"""This method is not implemented in the JAX interface."""
|
||||
del bw_method
|
||||
raise NotImplementedError(
|
||||
"dynamically changing the bandwidth method is not supported")
|
||||
|
||||
def _reshape_points(self, points):
|
||||
if dtypes.issubdtype(lax.dtype(points), np.complexfloating):
|
||||
raise NotImplementedError(
|
||||
"gaussian_kde does not support complex coordinates")
|
||||
points = jnp.atleast_2d(points)
|
||||
d, m = points.shape
|
||||
if d != self.d:
|
||||
if d == 1 and m == self.d:
|
||||
points = jnp.reshape(points, (self.d, 1))
|
||||
else:
|
||||
raise ValueError(
|
||||
"points have dimension {}, dataset has dimension {}".format(
|
||||
d, self.d))
|
||||
return points
|
||||
|
||||
|
||||
def _gaussian_kernel_convolve(chol, norm, target, weights, mean):
|
||||
diff = target - mean[:, None]
|
||||
alpha = linalg.cho_solve(chol, diff)
|
||||
arg = 0.5 * jnp.sum(diff * alpha, axis=0)
|
||||
return norm * jnp.sum(jnp.exp(-arg) * weights)
|
||||
|
||||
|
||||
@api.jit(static_argnums=0)
|
||||
def _gaussian_kernel_eval(in_log, points, values, xi, precision):
|
||||
points, values, xi, precision = promote_dtypes_inexact(
|
||||
points, values, xi, precision)
|
||||
d = points.shape[1]
|
||||
|
||||
if xi.shape[1] != d:
|
||||
raise ValueError("points and xi must have same trailing dim")
|
||||
if precision.shape != (d, d):
|
||||
raise ValueError("precision matrix must match data dims")
|
||||
|
||||
whitening = linalg.cholesky(precision, lower=True)
|
||||
points = jnp.dot(points, whitening)
|
||||
xi = jnp.dot(xi, whitening)
|
||||
log_norm = jnp.sum(jnp.log(
|
||||
jnp.diag(whitening))) - 0.5 * d * jnp.log(2 * np.pi)
|
||||
|
||||
def kernel(x_test, x_train, y_train):
|
||||
arg = log_norm - 0.5 * jnp.sum(jnp.square(x_train - x_test))
|
||||
if in_log:
|
||||
return jnp.log(y_train) + arg
|
||||
else:
|
||||
return y_train * jnp.exp(arg)
|
||||
|
||||
reduce = special.logsumexp if in_log else jnp.sum
|
||||
reduced_kernel = lambda x: reduce(api.vmap(kernel, in_axes=(None, 0, 0))
|
||||
(x, points, values),
|
||||
axis=0)
|
||||
mapped_kernel = api.vmap(reduced_kernel)
|
||||
|
||||
return mapped_kernel(xi)
|
||||
@@ -0,0 +1,109 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Laplace log probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.laplace` ``logpdf``.
|
||||
|
||||
The Laplace probability distribution function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x) = \frac{1}{2} e^{-|x|}
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logpdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.laplace.cdf`
|
||||
- :func:`jax.scipy.stats.laplace.pdf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("laplace.logpdf", x, loc, scale)
|
||||
two = _lax_const(x, 2)
|
||||
linear_term = lax.div(lax.abs(lax.sub(x, loc)), scale)
|
||||
return lax.neg(lax.add(linear_term, lax.log(lax.mul(two, scale))))
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Laplace probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.laplace` ``pdf``.
|
||||
|
||||
The Laplace probability distribution function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x) = \frac{1}{2} e^{-|x|}
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.laplace.cdf`
|
||||
- :func:`jax.scipy.stats.laplace.logpdf`
|
||||
"""
|
||||
return lax.exp(logpdf(x, loc, scale))
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Laplace cumulative distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.laplace` ``cdf``.
|
||||
|
||||
The cdf is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y
|
||||
|
||||
where :math:`f_{pdf}` is the probability density function,
|
||||
:func:`jax.scipy.stats.laplace.pdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the CDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of cdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.laplace.pdf`
|
||||
- :func:`jax.scipy.stats.laplace.logpdf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("laplace.cdf", x, loc, scale)
|
||||
half = _lax_const(x, 0.5)
|
||||
one = _lax_const(x, 1)
|
||||
zero = _lax_const(x, 0)
|
||||
diff = lax.div(lax.sub(x, loc), scale)
|
||||
return lax.select(lax.le(diff, zero),
|
||||
lax.mul(half, lax.exp(diff)),
|
||||
lax.sub(one, lax.mul(half, lax.exp(lax.neg(diff)))))
|
||||
@@ -0,0 +1,204 @@
|
||||
# Copyright 2020 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.scipy.special import expit, logit
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Logistic log probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.logistic` ``logpdf``.
|
||||
|
||||
The logistic probability distribution function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x) = \frac{e^{-x}}{(1 + e^{-x})^2}
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
a: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logpdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.logistic.cdf`
|
||||
- :func:`jax.scipy.stats.logistic.pdf`
|
||||
- :func:`jax.scipy.stats.logistic.sf`
|
||||
- :func:`jax.scipy.stats.logistic.isf`
|
||||
- :func:`jax.scipy.stats.logistic.ppf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("logistic.logpdf", x, loc, scale)
|
||||
x = lax.div(lax.sub(x, loc), scale)
|
||||
two = _lax_const(x, 2)
|
||||
half_x = lax.div(x, two)
|
||||
return lax.sub(lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))), lax.log(scale))
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Logistic probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.logistic` ``pdf``.
|
||||
|
||||
The logistic probability distribution function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x) = \frac{e^{-x}}{(1 + e^{-x})^2}
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.logistic.cdf`
|
||||
- :func:`jax.scipy.stats.logistic.sf`
|
||||
- :func:`jax.scipy.stats.logistic.isf`
|
||||
- :func:`jax.scipy.stats.logistic.logpdf`
|
||||
- :func:`jax.scipy.stats.logistic.ppf`
|
||||
"""
|
||||
return lax.exp(logpdf(x, loc, scale))
|
||||
|
||||
|
||||
def ppf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
"""Logistic distribution percent point function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.logistic` ``ppf``.
|
||||
|
||||
The percent point function is defined as the inverse of the
|
||||
cumulative distribution function, :func:`jax.scipy.stats.logistic.cdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PPF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of ppf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.logistic.cdf`
|
||||
- :func:`jax.scipy.stats.logistic.pdf`
|
||||
- :func:`jax.scipy.stats.logistic.sf`
|
||||
- :func:`jax.scipy.stats.logistic.isf`
|
||||
- :func:`jax.scipy.stats.logistic.logpdf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("logistic.ppf", x, loc, scale)
|
||||
return lax.add(lax.mul(logit(x), scale), loc)
|
||||
|
||||
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
"""Logistic distribution survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.logistic` ``sf``
|
||||
|
||||
The survival function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{sf}(x, k) = 1 - f_{cdf}(x, k)
|
||||
|
||||
where :math:`f_{cdf}(x, k)` is the cumulative distribution function,
|
||||
:func:`jax.scipy.stats.logistic.cdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the SF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of sf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.logistic.cdf`
|
||||
- :func:`jax.scipy.stats.logistic.pdf`
|
||||
- :func:`jax.scipy.stats.logistic.isf`
|
||||
- :func:`jax.scipy.stats.logistic.logpdf`
|
||||
- :func:`jax.scipy.stats.logistic.ppf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("logistic.sf", x, loc, scale)
|
||||
return expit(lax.neg(lax.div(lax.sub(x, loc), scale)))
|
||||
|
||||
|
||||
def isf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
"""Logistic distribution inverse survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.logistic` ``isf``.
|
||||
|
||||
Returns the inverse of the survival function,
|
||||
:func:`jax.scipy.stats.logistic.sf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the ISF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of isf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.logistic.cdf`
|
||||
- :func:`jax.scipy.stats.logistic.pdf`
|
||||
- :func:`jax.scipy.stats.logistic.sf`
|
||||
- :func:`jax.scipy.stats.logistic.logpdf`
|
||||
- :func:`jax.scipy.stats.logistic.ppf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("logistic.isf", x, loc, scale)
|
||||
return lax.add(lax.mul(lax.neg(logit(x)), scale), loc)
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Logistic cumulative distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.logistic` ``cdf``.
|
||||
|
||||
The cdf is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y
|
||||
|
||||
where :math:`f_{pdf}` is the probability density function,
|
||||
:func:`jax.scipy.stats.logistic.pdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the CDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of cdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.logistic.pdf`
|
||||
- :func:`jax.scipy.stats.logistic.sf`
|
||||
- :func:`jax.scipy.stats.logistic.isf`
|
||||
- :func:`jax.scipy.stats.logistic.logpdf`
|
||||
- :func:`jax.scipy.stats.logistic.ppf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("logistic.cdf", x, loc, scale)
|
||||
return expit(lax.div(lax.sub(x, loc), scale))
|
||||
@@ -0,0 +1,83 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import dtypes
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.numpy.util import promote_args_inexact, promote_args_numeric
|
||||
from jax._src.scipy.special import gammaln, xlogy
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array:
|
||||
r"""Multinomial log probability mass function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.multinomial` ``logpdf``.
|
||||
|
||||
The multinomial probability distribution is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x, n, p) = n! \prod_{i=1}^k \frac{p_i^{x_i}}{x_i!}
|
||||
|
||||
with :math:`n = \sum_i x_i`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PMF
|
||||
n: arraylike, distribution shape parameter
|
||||
p: arraylike, distribution shape parameter
|
||||
|
||||
Returns:
|
||||
array of logpmf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.multinomial.pmf`
|
||||
"""
|
||||
p, = promote_args_inexact("multinomial.logpmf", p)
|
||||
x, n = promote_args_numeric("multinomial.logpmf", x, n)
|
||||
if not dtypes.issubdtype(x.dtype, np.integer):
|
||||
raise ValueError(f"x and n must be of integer type; got x.dtype={x.dtype}, n.dtype={n.dtype}")
|
||||
x = x.astype(p.dtype)
|
||||
n = n.astype(p.dtype)
|
||||
logprobs = gammaln(n + 1) + jnp.sum(xlogy(x, p) - gammaln(x + 1), axis=-1)
|
||||
return jnp.where(jnp.equal(jnp.sum(x), n), logprobs, -np.inf)
|
||||
|
||||
|
||||
def pmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array:
|
||||
r"""Multinomial probability mass function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.multinomial` ``pmf``.
|
||||
|
||||
The multinomial probability distribution is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x, n, p) = n! \prod_{i=1}^k \frac{p_i^{x_i}}{x_i!}
|
||||
|
||||
with :math:`n = \sum_i x_i`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PMF
|
||||
n: arraylike, distribution shape parameter
|
||||
p: arraylike, distribution shape parameter
|
||||
|
||||
Returns:
|
||||
array of pmf values
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.multinomial.logpmf`
|
||||
"""
|
||||
return lax.exp(logpmf(x, n, p))
|
||||
@@ -0,0 +1,103 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.numpy import einsum as jnp_einsum
|
||||
from jax._src.numpy import vectorize as jnp_vectorize
|
||||
from jax._src.numpy.util import promote_dtypes_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike, allow_singular: None = None) -> ArrayLike:
|
||||
r"""Multivariate normal log probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.multivariate_normal` ``logpdf``.
|
||||
|
||||
The multivariate normal PDF is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f(x) = \frac{1}{(2\pi)^k\det\Sigma}\exp\left(-\frac{(x-\mu)^T\Sigma^{-1}(x-\mu)}{2} \right)
|
||||
|
||||
where :math:`\mu` is the ``mean``, :math:`\Sigma` is the covariance matrix (``cov``), and
|
||||
:math:`k` is the rank of :math:`\Sigma`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
mean: arraylike, centroid of distribution
|
||||
cov: arraylike, covariance matrix of distribution
|
||||
allow_singular: not supported
|
||||
|
||||
Returns:
|
||||
array of logpdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.multivariate_normal.pdf`
|
||||
"""
|
||||
if allow_singular is not None:
|
||||
raise NotImplementedError("allow_singular argument of multivariate_normal.logpdf")
|
||||
x, mean, cov = promote_dtypes_inexact(x, mean, cov)
|
||||
if not mean.shape:
|
||||
return (-1/2 * jnp.square(x - mean) / cov
|
||||
- 1/2 * (jnp.log(2*np.pi) + jnp.log(cov)))
|
||||
else:
|
||||
n = mean.shape[-1]
|
||||
if not np.shape(cov):
|
||||
y = x - mean
|
||||
return (-1/2 * jnp_einsum.einsum('...i,...i->...', y, y) / cov
|
||||
- n/2 * (jnp.log(2*np.pi) + jnp.log(cov)))
|
||||
else:
|
||||
if cov.ndim < 2 or cov.shape[-2:] != (n, n):
|
||||
raise ValueError("multivariate_normal.logpdf got incompatible shapes")
|
||||
L = lax.linalg.cholesky(cov)
|
||||
y = jnp_vectorize.vectorize(
|
||||
partial(lax.linalg.triangular_solve, lower=True, transpose_a=True),
|
||||
signature="(n,n),(n)->(n)"
|
||||
)(L, x - mean)
|
||||
return (-1/2 * jnp_einsum.einsum('...i,...i->...', y, y) - n/2 * jnp.log(2*np.pi)
|
||||
- jnp.log(L.diagonal(axis1=-1, axis2=-2)).sum(-1))
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike) -> Array:
|
||||
r"""Multivariate normal probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.multivariate_normal` ``pdf``.
|
||||
|
||||
The multivariate normal PDF is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f(x) = \frac{1}{(2\pi)^k\det\Sigma}\exp\left(-\frac{(x-\mu)^T\Sigma^{-1}(x-\mu)}{2} \right)
|
||||
|
||||
where :math:`\mu` is the ``mean``, :math:`\Sigma` is the covariance matrix (``cov``), and
|
||||
:math:`k` is the rank of :math:`\Sigma`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
mean: arraylike, centroid of distribution
|
||||
cov: arraylike, covariance matrix of distribution
|
||||
allow_singular: not supported
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.multivariate_normal.logpdf`
|
||||
"""
|
||||
return lax.exp(logpdf(x, mean, cov))
|
||||
@@ -0,0 +1,86 @@
|
||||
# Copyright 2021 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.scipy.special import gammaln, xlogy
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
r"""Negative-binomial log probability mass function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.nbinom` ``logpmf``.
|
||||
|
||||
The negative-binomial probability mass function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(k) = {{k+n-1} \choose {n-1}}p^n(1-p)^k
|
||||
|
||||
for :math:`k \ge 0` and :math:`0 \le p \le 1`.
|
||||
|
||||
Args:
|
||||
k: arraylike, value at which to evaluate the PMF
|
||||
n: arraylike, distribution shape parameter
|
||||
p: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
|
||||
Returns:
|
||||
array of logpdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.nbinom.pmf`
|
||||
"""
|
||||
k, n, p, loc = promote_args_inexact("nbinom.logpmf", k, n, p, loc)
|
||||
one = _lax_const(k, 1)
|
||||
y = lax.sub(k, loc)
|
||||
comb_term = lax.sub(
|
||||
lax.sub(gammaln(lax.add(y, n)), gammaln(n)), gammaln(lax.add(y, one))
|
||||
)
|
||||
log_linear_term = lax.add(xlogy(n, p), xlogy(y, lax.sub(one, p)))
|
||||
log_probs = lax.add(comb_term, log_linear_term)
|
||||
return jnp.where(lax.lt(k, loc), -np.inf, log_probs)
|
||||
|
||||
|
||||
def pmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
r"""Negative-binomial probability mass function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.nbinom` ``pmf``.
|
||||
|
||||
The negative-binomial probability mass function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(k) = {{k+n-1} \choose {n-1}}p^n(1-p)^k
|
||||
|
||||
for :math:`k \ge 0` and :math:`0 \le p \le 1`.
|
||||
|
||||
Args:
|
||||
k: arraylike, value at which to evaluate the PMF
|
||||
n: arraylike, distribution shape parameter
|
||||
p: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
|
||||
Returns:
|
||||
array of pmf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.nbinom.logpmf`
|
||||
"""
|
||||
return lax.exp(logpmf(k, n, p, loc))
|
||||
@@ -0,0 +1,284 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.scipy import special
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Normal log probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.norm` ``logpdf``.
|
||||
|
||||
The normal distribution pdf is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2}
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logpdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.norm.cdf`
|
||||
- :func:`jax.scipy.stats.norm.pdf`
|
||||
- :func:`jax.scipy.stats.norm.sf`
|
||||
- :func:`jax.scipy.stats.norm.logcdf`
|
||||
- :func:`jax.scipy.stats.norm.logsf`
|
||||
- :func:`jax.scipy.stats.norm.isf`
|
||||
- :func:`jax.scipy.stats.norm.ppf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("norm.logpdf", x, loc, scale)
|
||||
scale_sqrd = lax.square(scale)
|
||||
log_normalizer = lax.log(lax.mul(_lax_const(x, 2 * np.pi), scale_sqrd))
|
||||
quadratic = lax.div(lax.square(lax.sub(x, loc)), scale_sqrd)
|
||||
return lax.div(lax.add(log_normalizer, quadratic), _lax_const(x, -2))
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Normal probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.norm` ``pdf``.
|
||||
|
||||
The normal distribution pdf is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2}
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.norm.cdf`
|
||||
- :func:`jax.scipy.stats.norm.sf`
|
||||
- :func:`jax.scipy.stats.norm.logcdf`
|
||||
- :func:`jax.scipy.stats.norm.logpdf`
|
||||
- :func:`jax.scipy.stats.norm.logsf`
|
||||
- :func:`jax.scipy.stats.norm.isf`
|
||||
- :func:`jax.scipy.stats.norm.ppf`
|
||||
"""
|
||||
return lax.exp(logpdf(x, loc, scale))
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Normal cumulative distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.norm` ``cdf``.
|
||||
|
||||
The cdf is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y
|
||||
|
||||
where :math:`f_{pdf}` is the probability density function,
|
||||
:func:`jax.scipy.stats.norm.pdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the CDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of cdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.norm.pdf`
|
||||
- :func:`jax.scipy.stats.norm.sf`
|
||||
- :func:`jax.scipy.stats.norm.logcdf`
|
||||
- :func:`jax.scipy.stats.norm.logpdf`
|
||||
- :func:`jax.scipy.stats.norm.logsf`
|
||||
- :func:`jax.scipy.stats.norm.isf`
|
||||
- :func:`jax.scipy.stats.norm.ppf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("norm.cdf", x, loc, scale)
|
||||
return special.ndtr(lax.div(lax.sub(x, loc), scale))
|
||||
|
||||
|
||||
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Normal log cumulative distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.norm` ``logcdf``.
|
||||
|
||||
The cdf is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y
|
||||
|
||||
where :math:`f_{pdf}` is the probability density function,
|
||||
:func:`jax.scipy.stats.norm.pdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the CDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logcdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.norm.cdf`
|
||||
- :func:`jax.scipy.stats.norm.pdf`
|
||||
- :func:`jax.scipy.stats.norm.sf`
|
||||
- :func:`jax.scipy.stats.norm.logpdf`
|
||||
- :func:`jax.scipy.stats.norm.logsf`
|
||||
- :func:`jax.scipy.stats.norm.isf`
|
||||
- :func:`jax.scipy.stats.norm.ppf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("norm.logcdf", x, loc, scale)
|
||||
return special.log_ndtr(lax.div(lax.sub(x, loc), scale))
|
||||
|
||||
|
||||
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
"""Normal distribution percent point function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.norm` ``ppf``.
|
||||
|
||||
The percent point function is defined as the inverse of the
|
||||
cumulative distribution function, :func:`jax.scipy.stats.norm.cdf`.
|
||||
|
||||
Args:
|
||||
q: arraylike, value at which to evaluate the PPF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of ppf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.norm.cdf`
|
||||
- :func:`jax.scipy.stats.norm.pdf`
|
||||
- :func:`jax.scipy.stats.norm.sf`
|
||||
- :func:`jax.scipy.stats.norm.logcdf`
|
||||
- :func:`jax.scipy.stats.norm.logpdf`
|
||||
- :func:`jax.scipy.stats.norm.logsf`
|
||||
- :func:`jax.scipy.stats.norm.isf`
|
||||
"""
|
||||
return jnp.asarray(special.ndtri(q) * scale + loc, float)
|
||||
|
||||
|
||||
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
"""Normal distribution log survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.norm` ``logsf``.
|
||||
|
||||
The survival function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{sf}(x) = 1 - f_{cdf}(x)
|
||||
|
||||
where :math:`f_{cdf}(x)` is the cumulative distribution function,
|
||||
:func:`jax.scipy.stats.norm.cdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the SF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logsf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.norm.cdf`
|
||||
- :func:`jax.scipy.stats.norm.pdf`
|
||||
- :func:`jax.scipy.stats.norm.sf`
|
||||
- :func:`jax.scipy.stats.norm.logcdf`
|
||||
- :func:`jax.scipy.stats.norm.logpdf`
|
||||
- :func:`jax.scipy.stats.norm.isf`
|
||||
- :func:`jax.scipy.stats.norm.ppf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("norm.logsf", x, loc, scale)
|
||||
return logcdf(-x, -loc, scale)
|
||||
|
||||
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
"""Normal distribution survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.norm` ``sf``.
|
||||
|
||||
The survival function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{sf}(x) = 1 - f_{cdf}(x)
|
||||
|
||||
where :math:`f_{cdf}(x)` is the cumulative distribution function,
|
||||
:func:`jax.scipy.stats.norm.cdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the SF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of sf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.norm.cdf`
|
||||
- :func:`jax.scipy.stats.norm.pdf`
|
||||
- :func:`jax.scipy.stats.norm.logcdf`
|
||||
- :func:`jax.scipy.stats.norm.logpdf`
|
||||
- :func:`jax.scipy.stats.norm.logsf`
|
||||
- :func:`jax.scipy.stats.norm.isf`
|
||||
- :func:`jax.scipy.stats.norm.ppf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("norm.sf", x, loc, scale)
|
||||
return cdf(-x, -loc, scale)
|
||||
|
||||
|
||||
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
"""Normal distribution inverse survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.norm` ``isf``.
|
||||
|
||||
Returns the inverse of the survival function,
|
||||
:func:`jax.scipy.stats.norm.sf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the ISF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of isf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.norm.cdf`
|
||||
- :func:`jax.scipy.stats.norm.pdf`
|
||||
- :func:`jax.scipy.stats.norm.sf`
|
||||
- :func:`jax.scipy.stats.norm.logcdf`
|
||||
- :func:`jax.scipy.stats.norm.logpdf`
|
||||
- :func:`jax.scipy.stats.norm.logsf`
|
||||
- :func:`jax.scipy.stats.norm.ppf`
|
||||
"""
|
||||
return ppf(lax.sub(_lax_const(q, 1), q), loc, scale)
|
||||
@@ -0,0 +1,313 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(
|
||||
x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
|
||||
) -> Array:
|
||||
r"""Pareto log probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.pareto` ``logpdf``.
|
||||
|
||||
The Pareto probability density function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x, b) = \begin{cases}
|
||||
bx^{-(b+1)} & x \ge 1\\
|
||||
0 & x < 1
|
||||
\end{cases}
|
||||
|
||||
and is defined for :math:`b > 0`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
b: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logpdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.pareto.logcdf`
|
||||
- :func:`jax.scipy.stats.pareto.logsf`
|
||||
- :func:`jax.scipy.stats.pareto.cdf`
|
||||
- :func:`jax.scipy.stats.pareto.pdf`
|
||||
- :func:`jax.scipy.stats.pareto.ppf`
|
||||
- :func:`jax.scipy.stats.pareto.sf`
|
||||
"""
|
||||
x, b, loc, scale = promote_args_inexact("pareto.logpdf", x, b, loc, scale)
|
||||
one = _lax_const(x, 1)
|
||||
scaled_x = lax.div(lax.sub(x, loc), scale)
|
||||
normalize_term = lax.log(lax.div(scale, b))
|
||||
log_probs = lax.neg(
|
||||
lax.add(normalize_term, lax.mul(lax.add(b, one), lax.log(scaled_x)))
|
||||
)
|
||||
return jnp.where(lax.lt(x, lax.add(loc, scale)), -np.inf, log_probs)
|
||||
|
||||
|
||||
def pdf(
|
||||
x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
|
||||
) -> Array:
|
||||
r"""Pareto probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.pareto` ``pdf``.
|
||||
|
||||
The Pareto probability density function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x, b) = \begin{cases}
|
||||
bx^{-(b+1)} & x \ge 1\\
|
||||
0 & x < 1
|
||||
\end{cases}
|
||||
|
||||
and is defined for :math:`b > 0`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
b: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.pareto.logcdf`
|
||||
- :func:`jax.scipy.stats.pareto.logpdf`
|
||||
- :func:`jax.scipy.stats.pareto.logsf`
|
||||
- :func:`jax.scipy.stats.pareto.cdf`
|
||||
- :func:`jax.scipy.stats.pareto.ppf`
|
||||
- :func:`jax.scipy.stats.pareto.sf`
|
||||
"""
|
||||
return lax.exp(logpdf(x, b, loc, scale))
|
||||
|
||||
|
||||
def cdf(
|
||||
x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
|
||||
) -> Array:
|
||||
r"""Pareto cumulative distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.pareto` ``cdf``.
|
||||
|
||||
The Pareto cumulative distribution function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
F(x, b) = \begin{cases}
|
||||
1 - x^{-b} & x \ge 1\\
|
||||
0 & x < 1
|
||||
\end{cases}
|
||||
|
||||
and is defined for :math:`b > 0`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the CDF
|
||||
b: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of CDF values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.pareto.logcdf`
|
||||
- :func:`jax.scipy.stats.pareto.logpdf`
|
||||
- :func:`jax.scipy.stats.pareto.logsf`
|
||||
- :func:`jax.scipy.stats.pareto.pdf`
|
||||
- :func:`jax.scipy.stats.pareto.ppf`
|
||||
- :func:`jax.scipy.stats.pareto.sf`
|
||||
"""
|
||||
x, b, loc, scale = promote_args_inexact("pareto.cdf", x, b, loc, scale)
|
||||
one = _lax_const(x, 1)
|
||||
zero = _lax_const(x, 0)
|
||||
scaled_x = lax.div(lax.sub(x, loc), scale)
|
||||
cdf = lax.sub(one, lax.pow(scaled_x, lax.neg(b)))
|
||||
return jnp.where(lax.lt(x, lax.add(loc, scale)), zero, cdf)
|
||||
|
||||
|
||||
def logcdf(
|
||||
x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
|
||||
) -> Array:
|
||||
r"""Pareto log cumulative distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.pareto` ``logcdf``.
|
||||
|
||||
The Pareto cumulative distribution function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
F(x, b) = \begin{cases}
|
||||
1 - x^{-b} & x \ge 1\\
|
||||
0 & x < 1
|
||||
\end{cases}
|
||||
|
||||
and is defined for :math:`b > 0`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the CDF
|
||||
b: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logCDF values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.pareto.logpdf`
|
||||
- :func:`jax.scipy.stats.pareto.logsf`
|
||||
- :func:`jax.scipy.stats.pareto.cdf`
|
||||
- :func:`jax.scipy.stats.pareto.pdf`
|
||||
- :func:`jax.scipy.stats.pareto.ppf`
|
||||
- :func:`jax.scipy.stats.pareto.sf`
|
||||
"""
|
||||
x, b, loc, scale = promote_args_inexact("pareto.logcdf", x, b, loc, scale)
|
||||
scaled_x = lax.div(lax.sub(x, loc), scale)
|
||||
logcdf_val = lax.log1p(lax.neg(lax.pow(scaled_x, lax.neg(b))))
|
||||
return jnp.where(lax.lt(x, lax.add(loc, scale)), -np.inf, logcdf_val)
|
||||
|
||||
|
||||
def logsf(
|
||||
x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
|
||||
) -> Array:
|
||||
r"""Pareto log survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.pareto` ``logsf``.
|
||||
|
||||
The Pareto survival function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
S(x, b) = \begin{cases}
|
||||
x^{-b} & x \ge 1\\
|
||||
1 & x < 1
|
||||
\end{cases}
|
||||
|
||||
and is defined for :math:`b > 0`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the survival function
|
||||
b: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of log survival function values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.pareto.logcdf`
|
||||
- :func:`jax.scipy.stats.pareto.logpdf`
|
||||
- :func:`jax.scipy.stats.pareto.cdf`
|
||||
- :func:`jax.scipy.stats.pareto.pdf`
|
||||
- :func:`jax.scipy.stats.pareto.ppf`
|
||||
- :func:`jax.scipy.stats.pareto.sf`
|
||||
"""
|
||||
x, b, loc, scale = promote_args_inexact("pareto.logsf", x, b, loc, scale)
|
||||
zero = _lax_const(x, 0)
|
||||
scaled_x = lax.div(lax.sub(x, loc), scale)
|
||||
logsf_val = lax.neg(lax.mul(b, lax.log(scaled_x)))
|
||||
return jnp.where(lax.lt(x, lax.add(loc, scale)), zero, logsf_val)
|
||||
|
||||
|
||||
def sf(
|
||||
x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
|
||||
) -> Array:
|
||||
r"""Pareto survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.pareto` ``sf``.
|
||||
|
||||
The Pareto survival function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
S(x, b) = \begin{cases}
|
||||
x^{-b} & x \ge 1\\
|
||||
1 & x < 1
|
||||
\end{cases}
|
||||
|
||||
and is defined for :math:`b > 0`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the survival function
|
||||
b: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of survival function values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.pareto.logcdf`
|
||||
- :func:`jax.scipy.stats.pareto.logpdf`
|
||||
- :func:`jax.scipy.stats.pareto.logsf`
|
||||
- :func:`jax.scipy.stats.pareto.cdf`
|
||||
- :func:`jax.scipy.stats.pareto.pdf`
|
||||
- :func:`jax.scipy.stats.pareto.ppf`
|
||||
"""
|
||||
return lax.exp(logsf(x, b, loc, scale))
|
||||
|
||||
|
||||
def ppf(
|
||||
q: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
|
||||
) -> Array:
|
||||
r"""Pareto percent point function (inverse CDF).
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.pareto` ``ppf``.
|
||||
|
||||
The Pareto percent point function is the inverse of the Pareto CDF, and is
|
||||
given by
|
||||
|
||||
.. math::
|
||||
|
||||
F^{-1}(q, b) = \begin{cases}
|
||||
(1 - q)^{-1/b} & 0 \le q < 1\\
|
||||
\text{NaN} & \text{otherwise}
|
||||
\end{cases}
|
||||
|
||||
and is defined for :math:`b > 0`.
|
||||
|
||||
Args:
|
||||
q: arraylike, value at which to evaluate the inverse CDF
|
||||
b: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of percent point function values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.pareto.logcdf`
|
||||
- :func:`jax.scipy.stats.pareto.logpdf`
|
||||
- :func:`jax.scipy.stats.pareto.logsf`
|
||||
- :func:`jax.scipy.stats.pareto.cdf`
|
||||
- :func:`jax.scipy.stats.pareto.pdf`
|
||||
- :func:`jax.scipy.stats.pareto.sf`
|
||||
"""
|
||||
q, b, loc, scale = promote_args_inexact("pareto.ppf", q, b, loc, scale)
|
||||
one = _lax_const(q, 1)
|
||||
ppf_val = lax.add(
|
||||
loc, lax.mul(scale, lax.pow(lax.sub(one, q), lax.neg(lax.div(one, b))))
|
||||
)
|
||||
return jnp.where(jnp.isnan(q) | (q < 0) | (q > 1), np.nan, ppf_val)
|
||||
@@ -0,0 +1,241 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact, ensure_arraylike
|
||||
from jax._src.scipy.special import xlogy, entr, gammaln, gammaincc
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
r"""Poisson log probability mass function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.poisson` ``logpmf``.
|
||||
|
||||
The Poisson probability mass function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(k) = e^{-\mu}\frac{\mu^k}{k!}
|
||||
|
||||
and is defined for :math:`k \ge 0` and :math:`\mu \ge 0`.
|
||||
|
||||
Args:
|
||||
k: arraylike, value at which to evaluate the PMF
|
||||
mu: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
|
||||
Returns:
|
||||
array of logpmf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.poisson.cdf`
|
||||
- :func:`jax.scipy.stats.poisson.pmf`
|
||||
"""
|
||||
k, mu, loc = promote_args_inexact("poisson.logpmf", k, mu, loc)
|
||||
zero = _lax_const(k, 0)
|
||||
x = lax.sub(k, loc)
|
||||
log_probs = xlogy(x, mu) - gammaln(x + 1) - mu
|
||||
return jnp.where(jnp.logical_or(lax.lt(x, zero),
|
||||
lax.ne(jnp.round(k), k)), -np.inf, log_probs)
|
||||
|
||||
|
||||
def pmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
r"""Poisson probability mass function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.poisson` ``pmf``.
|
||||
|
||||
The Poisson probability mass function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(k) = e^{-\mu}\frac{\mu^k}{k!}
|
||||
|
||||
and is defined for :math:`k \ge 0` and :math:`\mu \ge 0`.
|
||||
|
||||
Args:
|
||||
k: arraylike, value at which to evaluate the PMF
|
||||
mu: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
|
||||
Returns:
|
||||
array of pmf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.poisson.cdf`
|
||||
- :func:`jax.scipy.stats.poisson.logpmf`
|
||||
"""
|
||||
return jnp.exp(logpmf(k, mu, loc))
|
||||
|
||||
|
||||
def cdf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
r"""Poisson cumulative distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.poisson` ``cdf``.
|
||||
|
||||
The cumulative distribution function is defined as:
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(k, p) = \sum_{i=0}^k f_{pmf}(k, p)
|
||||
|
||||
where :math:`f_{pmf}(k, p)` is the probability mass function
|
||||
:func:`jax.scipy.stats.poisson.pmf`.
|
||||
|
||||
Args:
|
||||
k: arraylike, value at which to evaluate the CDF
|
||||
mu: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
|
||||
Returns:
|
||||
array of cdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.poisson.pmf`
|
||||
- :func:`jax.scipy.stats.poisson.logpmf`
|
||||
"""
|
||||
k, mu, loc = promote_args_inexact("poisson.logpmf", k, mu, loc)
|
||||
zero = _lax_const(k, 0)
|
||||
x = lax.sub(k, loc)
|
||||
p = gammaincc(jnp.floor(1 + x), mu)
|
||||
return jnp.where(lax.lt(x, zero), zero, p)
|
||||
|
||||
@api.jit
|
||||
def entropy(mu: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
r"""Shannon entropy of the Poisson distribution.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.poisson` ``entropy``.
|
||||
|
||||
The entropy :math:`H(X)` of a Poisson random variable
|
||||
:math:`X \sim \text{Poisson}(\mu)` is defined as:
|
||||
|
||||
.. math::
|
||||
|
||||
H(X) = -\sum_{k=0}^\infty p(k) \log p(k)
|
||||
|
||||
where :math:`p(k) = e^{-\mu} \mu^k / k!` for
|
||||
:math:`k \geq \max(0, \lfloor \text{loc} \rfloor)`.
|
||||
|
||||
This implementation uses **regime switching** for numerical stability
|
||||
and performance:
|
||||
|
||||
- **Small** :math:`\mu < 10`: Direct summation over PMF with adaptive
|
||||
upper bound :math:`k \leq \mu + 20`
|
||||
- **Medium** :math:`10 \leq \mu < 100`: Summation with bound
|
||||
:math:`k \leq \mu + 10\sqrt{\mu} + 20`
|
||||
- **Large** :math:`\mu \geq 100`: Asymptotic Stirling approximation:
|
||||
:math:`H(\mu) \approx \frac{1}{2} \log(2\pi e \mu) - \frac{1}{12\mu}`
|
||||
|
||||
Matches SciPy to relative error :math:`< 10^{-5}` across all regimes.
|
||||
|
||||
Args:
|
||||
mu: arraylike, mean parameter of the Poisson distribution.
|
||||
Must be ``> 0``.
|
||||
loc: arraylike, optional location parameter (default: 0).
|
||||
Accepted for API compatibility with scipy but does not
|
||||
affect the entropy
|
||||
|
||||
Returns:
|
||||
Array of entropy values with shape broadcast from ``mu`` and ``loc``.
|
||||
Returns ``NaN`` for ``mu <= 0``.
|
||||
|
||||
Examples:
|
||||
>>> from jax.scipy.stats import poisson
|
||||
>>> poisson.entropy(5.0)
|
||||
Array(2.204394, dtype=float32)
|
||||
>>> poisson.entropy(jax.numpy.array([1, 10, 100]))
|
||||
Array([1.3048419, 2.561407 , 3.7206903], dtype=float32)
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.poisson.pmf`
|
||||
- :func:`jax.scipy.stats.poisson.logpmf`
|
||||
- :obj:`scipy.stats.poisson`
|
||||
"""
|
||||
mu, loc = ensure_arraylike("poisson.entropy", mu, loc)
|
||||
promoted_mu, promoted_loc = promote_dtypes_inexact(mu, loc)
|
||||
|
||||
#Note: loc does not affect the entropy - translation invariant
|
||||
#it has only been taken to maintain compatibility with scipy api
|
||||
result_shape = jnp.broadcast_shapes(
|
||||
promoted_mu.shape,
|
||||
promoted_loc.shape
|
||||
)
|
||||
|
||||
mu_flat = jnp.ravel(promoted_mu)
|
||||
zero_result = jnp.zeros_like(mu_flat)
|
||||
|
||||
|
||||
# Choose the computation regime based on mu value
|
||||
result = jnp.where(
|
||||
mu_flat == 0,
|
||||
zero_result,
|
||||
jnp.where(
|
||||
mu_flat < 10,
|
||||
_entropy_small_mu(mu_flat),
|
||||
jnp.where(
|
||||
mu_flat < 100,
|
||||
_entropy_medium_mu(mu_flat),
|
||||
_entropy_large_mu(mu_flat)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
result_mu_shape = jnp.reshape(result, promoted_mu.shape)
|
||||
|
||||
# Restore original shape
|
||||
return jnp.broadcast_to(result_mu_shape, result_shape)
|
||||
|
||||
def _entropy_small_mu(mu: Array) -> Array:
|
||||
"""Entropy via direct PMF summation for small μ (< 10).
|
||||
Uses adaptive upper bound k ≤ μ + 20 to capture >99.999% of mass.
|
||||
"""
|
||||
max_k = 35
|
||||
|
||||
k = jnp.arange(max_k, dtype=mu.dtype)[:, None]
|
||||
probs = pmf(k, mu, 0)
|
||||
|
||||
# Mask: only compute up to mu + 20 for each value
|
||||
upper_bounds = jnp.ceil(mu + 20).astype(k.dtype)
|
||||
mask = k < upper_bounds[None, :]
|
||||
probs_masked = jnp.where(mask, probs, 0.0)
|
||||
|
||||
return jnp.sum(entr(probs_masked), axis=0)
|
||||
|
||||
def _entropy_medium_mu(mu: Array) -> Array:
|
||||
"""Entropy for medium mu (10-100): Adaptive bounds based on std dev.
|
||||
|
||||
Bounds: k ≤ μ + 10√μ + 20. Caps at k=250 for JIT compatibility.
|
||||
"""
|
||||
max_k = 250 # Static bound for JIT. For mu<100, upper bound < 220
|
||||
|
||||
k = jnp.arange(max_k, dtype=mu.dtype)[:, None]
|
||||
probs = pmf(k, mu, 0)
|
||||
|
||||
upper_bounds = jnp.ceil(mu + 10 * jnp.sqrt(mu) + 20).astype(k.dtype)
|
||||
mask = k < upper_bounds[None, :]
|
||||
probs_masked = jnp.where(mask, probs, 0.0)
|
||||
|
||||
return jnp.sum(entr(probs_masked), axis=0)
|
||||
|
||||
def _entropy_large_mu(mu: Array) -> Array:
|
||||
"""Entropy for large mu (>= 100): Asymptotic approximation.
|
||||
|
||||
Formula: H(λ) ≈ 0.5*log(2πeλ) - 1/(12λ) + O(λ^-2)
|
||||
"""
|
||||
return 0.5 * jnp.log(2 * np.pi * np.e * mu) - 1.0 / (12 * mu)
|
||||
@@ -0,0 +1,89 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Student's T log probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.t` ``logpdf``.
|
||||
|
||||
The Student's T probability distribution function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x, \nu) = \frac{\Gamma((\nu + 1)/2)}{\sqrt{\pi\nu}\Gamma(\nu/2)}(1 + x^2/\nu)^{(\nu+1)/2}
|
||||
|
||||
Where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function, and :math:`\nu > 0`
|
||||
is the degrees of freedom (JAX follows the scipy convention of naming this ``df``).
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
df: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logpdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.t.pdf`
|
||||
"""
|
||||
x, df, loc, scale = promote_args_inexact("t.logpdf", x, df, loc, scale)
|
||||
two = _lax_const(x, 2)
|
||||
scaled_x = lax.div(lax.sub(x, loc), scale)
|
||||
df_over_two = lax.div(df, two)
|
||||
df_plus_one_over_two = lax.add(df_over_two, _lax_const(x, 0.5))
|
||||
normalize_term_const = lax.mul(lax.mul(scale, scale), _lax_const(x, np.pi))
|
||||
normalize_term_tmp = lax.div(lax.log(lax.mul(normalize_term_const, df)), two)
|
||||
normalize_term = lax.sub(lax.add(lax.lgamma(df_over_two), normalize_term_tmp),
|
||||
lax.lgamma(df_plus_one_over_two))
|
||||
quadratic = lax.div(lax.mul(scaled_x, scaled_x), df)
|
||||
return lax.neg(lax.add(normalize_term, lax.mul(df_plus_one_over_two, lax.log1p(quadratic))))
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Student's T probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.t` ``pdf``.
|
||||
|
||||
The Student's T probability distribution function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x, \nu) = \frac{\Gamma((\nu + 1)/2)}{\sqrt{\pi\nu}\Gamma(\nu/2)}(1 + x^2/\nu)^{(\nu+1)/2}
|
||||
|
||||
Where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function, and :math:`\nu > 0`
|
||||
is the degrees of freedom (JAX follows the scipy convention of naming this ``df``).
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
df: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.t.logpdf`
|
||||
"""
|
||||
return lax.exp(logpdf(x, df, loc, scale))
|
||||
@@ -0,0 +1,296 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.scipy.stats import norm
|
||||
from jax._src.scipy.special import logsumexp, log_ndtr, ndtr
|
||||
|
||||
|
||||
def _log_diff(x, y):
|
||||
return logsumexp(
|
||||
jnp.array([x, y]),
|
||||
b=jnp.array([jnp.ones_like(x), -jnp.ones_like(y)]),
|
||||
axis=0
|
||||
)
|
||||
|
||||
|
||||
def _log_gauss_mass(a, b):
|
||||
"""Log of Gaussian probability mass within an interval"""
|
||||
a, b = jnp.array(a), jnp.array(b)
|
||||
a, b = jnp.broadcast_arrays(a, b)
|
||||
|
||||
# Note: Docstring carried over from scipy
|
||||
# Calculations in right tail are inaccurate, so we'll exploit the
|
||||
# symmetry and work only in the left tail
|
||||
case_left = b <= 0
|
||||
case_right = a > 0
|
||||
case_central = ~(case_left | case_right)
|
||||
|
||||
# By conditionally swapping arguments if we're in the right tail,
|
||||
# we only need to compile the mass_case_left graph once instead of twice.
|
||||
a_tail = jnp.where(case_right, -b, a)
|
||||
b_tail = jnp.where(case_right, -a, b)
|
||||
|
||||
mass_tail = _log_diff(log_ndtr(b_tail), log_ndtr(a_tail))
|
||||
|
||||
# Catastrophic cancellation occurs as np.exp(log_mass) approaches 1.
|
||||
# Correct for this with an alternative formulation.
|
||||
# We're not concerned with underflow here: if only one term
|
||||
# underflows, it was insignificant; if both terms underflow,
|
||||
# the result can't accurately be represented in logspace anyway
|
||||
# because sc.log1p(x) ~ x for small x.
|
||||
mass_central = jnp.log1p(-ndtr(a) - ndtr(-b))
|
||||
|
||||
out = jnp.where(case_central, mass_central, mass_tail)
|
||||
return out
|
||||
|
||||
|
||||
@api.jit
|
||||
def logpdf(x, a, b, loc=0, scale=1):
|
||||
r"""Truncated normal log probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.truncnorm` ``logpdf``.
|
||||
|
||||
The truncated normal probability distribution is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x, a, b) = \begin{cases}
|
||||
\frac{1}{\sqrt{2\pi}}e^{-x^2/2} & a \le x \le b \\
|
||||
0 & \mathrm{otherwise}
|
||||
\end{cases}
|
||||
|
||||
where :math:`a` and :math:`b` are effectively specified in number of
|
||||
standard deviations from zero. JAX uses the scipy nomenclature
|
||||
of ``loc`` for the centroid and ``scale`` for the standard deviation.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
a: arraylike, distribution shape parameter
|
||||
b: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logpdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.truncnorm.cdf`
|
||||
- :func:`jax.scipy.stats.truncnorm.pdf`
|
||||
- :func:`jax.scipy.stats.truncnorm.sf`
|
||||
- :func:`jax.scipy.stats.truncnorm.logcdf`
|
||||
- :func:`jax.scipy.stats.truncnorm.logsf`
|
||||
"""
|
||||
x, a, b, loc, scale = promote_args_inexact("truncnorm.logpdf", x, a, b, loc, scale)
|
||||
val = lax.sub(norm.logpdf(x, loc, scale), _log_gauss_mass(a, b))
|
||||
|
||||
x_scaled = lax.div(lax.sub(x, loc), scale)
|
||||
val = jnp.where((x_scaled < a) | (x_scaled > b), -np.inf, val)
|
||||
val = jnp.where(a >= b, np.nan, val)
|
||||
return val
|
||||
|
||||
|
||||
def pdf(x, a, b, loc=0, scale=1):
|
||||
r"""Truncated normal probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.truncnorm` ``pdf``.
|
||||
|
||||
The truncated normal probability distribution is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x, a, b) = \begin{cases}
|
||||
\frac{1}{\sqrt{2\pi}}e^{-x^2/2} & a \le x \le b \\
|
||||
0 & \mathrm{otherwise}
|
||||
\end{cases}
|
||||
|
||||
where :math:`a` and :math:`b` are effectively specified in number of
|
||||
standard deviations from the centroid. JAX uses the scipy nomenclature
|
||||
of ``loc`` for the centroid and ``scale`` for the standard deviation.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
a: arraylike, distribution shape parameter
|
||||
b: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.truncnorm.cdf`
|
||||
- :func:`jax.scipy.stats.truncnorm.sf`
|
||||
- :func:`jax.scipy.stats.truncnorm.logcdf`
|
||||
- :func:`jax.scipy.stats.truncnorm.logpdf`
|
||||
- :func:`jax.scipy.stats.truncnorm.logsf`
|
||||
"""
|
||||
return lax.exp(logpdf(x, a, b, loc, scale))
|
||||
|
||||
@api.jit
|
||||
def logsf(x, a, b, loc=0, scale=1):
|
||||
"""Truncated normal distribution log survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.truncnorm` ``logsf``
|
||||
|
||||
The survival function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{sf}(x) = 1 - f_{cdf}(x)
|
||||
|
||||
where :math:`f_{cdf}(x)` is the cumulative distribution function,
|
||||
:func:`jax.scipy.stats.truncnorm.cdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the SF
|
||||
a: arraylike, distribution shape parameter
|
||||
b: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logsf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.truncnorm.cdf`
|
||||
- :func:`jax.scipy.stats.truncnorm.pdf`
|
||||
- :func:`jax.scipy.stats.truncnorm.sf`
|
||||
- :func:`jax.scipy.stats.truncnorm.logcdf`
|
||||
- :func:`jax.scipy.stats.truncnorm.logpdf`
|
||||
"""
|
||||
x, a, b, loc, scale = promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale)
|
||||
return logcdf(-x, -b, -a, -loc, scale)
|
||||
|
||||
|
||||
@api.jit
|
||||
def sf(x, a, b, loc=0, scale=1):
|
||||
"""Truncated normal distribution log survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.truncnorm` ``logsf``
|
||||
|
||||
The survival function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{sf}(x) = 1 - f_{cdf}(x)
|
||||
|
||||
where :math:`f_{cdf}(x)` is the cumulative distribution function,
|
||||
:func:`jax.scipy.stats.truncnorm.cdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the SF
|
||||
a: arraylike, distribution shape parameter
|
||||
b: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of sf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.truncnorm.cdf`
|
||||
- :func:`jax.scipy.stats.truncnorm.pdf`
|
||||
- :func:`jax.scipy.stats.truncnorm.sf`
|
||||
- :func:`jax.scipy.stats.truncnorm.logcdf`
|
||||
- :func:`jax.scipy.stats.truncnorm.logpdf`
|
||||
"""
|
||||
return lax.exp(logsf(x, a, b, loc, scale))
|
||||
|
||||
|
||||
@api.jit
|
||||
def logcdf(x, a, b, loc=0, scale=1):
|
||||
r"""Truncated normal log cumulative distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.truncnorm` ``logcdf``.
|
||||
|
||||
The cdf is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y
|
||||
|
||||
where here :math:`f_{pdf}` is the probability distribution function,
|
||||
:func:`jax.scipy.stats.truncnorm.pdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the CDF
|
||||
a: arraylike, distribution shape parameter
|
||||
b: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logcdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.truncnorm.cdf`
|
||||
- :func:`jax.scipy.stats.truncnorm.pdf`
|
||||
- :func:`jax.scipy.stats.truncnorm.sf`
|
||||
- :func:`jax.scipy.stats.truncnorm.logpdf`
|
||||
- :func:`jax.scipy.stats.truncnorm.logsf`
|
||||
"""
|
||||
x, a, b, loc, scale = promote_args_inexact("truncnorm.logcdf", x, a, b, loc, scale)
|
||||
x, a, b = jnp.broadcast_arrays(x, a, b)
|
||||
x = lax.div(lax.sub(x, loc), scale)
|
||||
lgm_ab = _log_gauss_mass(a, b)
|
||||
logcdf = _log_gauss_mass(a, x) - lgm_ab
|
||||
logsf = _log_gauss_mass(x, b) - lgm_ab
|
||||
|
||||
logcdf = jnp.select(
|
||||
# third condition: avoid catastrophic cancellation (from scipy)
|
||||
[x >= b, x <= a, logcdf > -0.1, x > a],
|
||||
[0, -np.inf, jnp.log1p(-jnp.exp(logsf)), logcdf]
|
||||
)
|
||||
logcdf = jnp.where(a >= b, np.nan, logcdf)
|
||||
return logcdf
|
||||
|
||||
@api.jit
|
||||
def cdf(x, a, b, loc=0, scale=1):
|
||||
r"""Truncated normal cumulative distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.truncnorm` ``cdf``.
|
||||
|
||||
The cdf is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y
|
||||
|
||||
where here :math:`f_{pdf}` is the probability distribution function,
|
||||
:func:`jax.scipy.stats.truncnorm.pdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the CDF
|
||||
a: arraylike, distribution shape parameter
|
||||
b: arraylike, distribution shape parameter
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of cdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.truncnorm.pdf`
|
||||
- :func:`jax.scipy.stats.truncnorm.sf`
|
||||
- :func:`jax.scipy.stats.truncnorm.logcdf`
|
||||
- :func:`jax.scipy.stats.truncnorm.logpdf`
|
||||
- :func:`jax.scipy.stats.truncnorm.logsf`
|
||||
"""
|
||||
return lax.exp(logcdf(x, a, b, loc, scale))
|
||||
@@ -0,0 +1,148 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Uniform log probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.uniform` ``logpdf``.
|
||||
|
||||
The uniform distribution pdf is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x) = \begin{cases}
|
||||
1 & 0 \le x \le 1 \\
|
||||
0 & \mathrm{otherwise}
|
||||
\end{cases}
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of logpdf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.uniform.cdf`
|
||||
- :func:`jax.scipy.stats.uniform.pdf`
|
||||
- :func:`jax.scipy.stats.uniform.ppf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("uniform.logpdf", x, loc, scale)
|
||||
log_probs = lax.neg(lax.log(scale))
|
||||
return jnp.where(jnp.logical_or(lax.gt(x, lax.add(loc, scale)),
|
||||
lax.lt(x, loc)),
|
||||
-np.inf, log_probs)
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Uniform probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.uniform` ``pdf``.
|
||||
|
||||
The uniform distribution pdf is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x) = \begin{cases}
|
||||
1 & 0 \le x \le 1 \\
|
||||
0 & \mathrm{otherwise}
|
||||
\end{cases}
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.uniform.cdf`
|
||||
- :func:`jax.scipy.stats.uniform.logpdf`
|
||||
- :func:`jax.scipy.stats.uniform.ppf`
|
||||
"""
|
||||
return lax.exp(logpdf(x, loc, scale))
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Uniform cumulative distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.uniform` ``cdf``.
|
||||
|
||||
The cdf is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y
|
||||
|
||||
where here :math:`f_{pdf}` is the probability distribution function,
|
||||
:func:`jax.scipy.stats.uniform.pdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the CDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of cdf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.uniform.pdf`
|
||||
- :func:`jax.scipy.stats.uniform.logpdf`
|
||||
- :func:`jax.scipy.stats.uniform.ppf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("uniform.cdf", x, loc, scale)
|
||||
zero, one = jnp.array(0, x.dtype), jnp.array(1, x.dtype)
|
||||
conds = [lax.lt(x, loc), lax.gt(x, lax.add(loc, scale)), lax.ge(x, loc) & lax.le(x, lax.add(loc, scale))]
|
||||
vals = [zero, one, lax.div(lax.sub(x, loc), scale)]
|
||||
|
||||
return jnp.select(conds, vals)
|
||||
|
||||
|
||||
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
"""Uniform distribution percent point function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.uniform` ``ppf``.
|
||||
|
||||
The percent point function is defined as the inverse of the
|
||||
cumulative distribution function, :func:`jax.scipy.stats.uniform.cdf`.
|
||||
|
||||
Args:
|
||||
q: arraylike, value at which to evaluate the PPF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of ppf values.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.uniform.cdf`
|
||||
- :func:`jax.scipy.stats.uniform.pdf`
|
||||
- :func:`jax.scipy.stats.uniform.logpdf`
|
||||
"""
|
||||
q, loc, scale = promote_args_inexact("uniform.ppf", q, loc, scale)
|
||||
return jnp.where(
|
||||
jnp.isnan(q) | (q < 0) | (q > 1),
|
||||
np.nan,
|
||||
lax.add(loc, lax.mul(scale, q))
|
||||
)
|
||||
@@ -0,0 +1,79 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, kappa: ArrayLike) -> Array:
|
||||
r"""von Mises log probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.vonmises` ``logpdf``.
|
||||
|
||||
The von Mises probability distribution function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x, \kappa) = \frac{1}{2\pi I_0(\kappa)}e^{\kappa\cos x}
|
||||
|
||||
Where :math:`I_0` is the modified Bessel function :func:`~jax.scipy.special.i0`
|
||||
and :math:`\kappa\ge 0`, and the distribution is normalized in the interval
|
||||
:math:`-\pi \le x \le \pi`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
kappa: arraylike, distribution shape parameter
|
||||
|
||||
Returns:
|
||||
array of logpdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.vonmises.pdf`
|
||||
"""
|
||||
x, kappa = promote_args_inexact('vonmises.logpdf', x, kappa)
|
||||
zero = _lax_const(kappa, 0)
|
||||
return jnp.where(lax.gt(kappa, zero), kappa * (jnp.cos(x) - 1) - jnp.log(2 * np.pi * lax.bessel_i0e(kappa)), np.nan)
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, kappa: ArrayLike) -> Array:
|
||||
r"""von Mises probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.vonmises` ``pdf``.
|
||||
|
||||
The von Mises probability distribution function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x, \kappa) = \frac{1}{2\pi I_0(\kappa)}e^{\kappa\cos x}
|
||||
|
||||
Where :math:`I_0` is the modified Bessel function :func:`~jax.scipy.special.i0`
|
||||
and :math:`\kappa\ge 0`, and the distribution is normalized in the interval
|
||||
:math:`-\pi \le x \le \pi`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
kappa: arraylike, distribution shape parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.vonmises.logpdf`
|
||||
"""
|
||||
return lax.exp(logpdf(x, kappa))
|
||||
@@ -0,0 +1,82 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, c: ArrayLike) -> Array:
|
||||
r"""Wrapped Cauchy log probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.wrapcauchy` ``logpdf``.
|
||||
|
||||
The wrapped Cauchy probability distribution function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x, c) = \frac{1-c^2}{2\pi(1+c^2-2c\cos x)}
|
||||
|
||||
for :math:`0<c<1`, and where normalization is on the domain :math:`0\le x\le 2\pi`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
c: arraylike, distribution shape parameter
|
||||
|
||||
Returns:
|
||||
array of logpdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.wrapcauchy.pdf`
|
||||
"""
|
||||
x, c = promote_args_inexact('wrapcauchy.logpdf', x, c)
|
||||
return jnp.where(
|
||||
lax.gt(c, _lax_const(c, 0)) & lax.lt(c, _lax_const(c, 1)),
|
||||
jnp.where(
|
||||
lax.ge(x, _lax_const(x, 0)) & lax.le(x, _lax_const(x, np.pi * 2)),
|
||||
jnp.log(1 - c * c) - jnp.log(2 * np.pi) - jnp.log(1 + c * c - 2 * c * jnp.cos(x)),
|
||||
-np.inf,
|
||||
),
|
||||
np.nan,
|
||||
)
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, c: ArrayLike) -> Array:
|
||||
r"""Wrapped Cauchy probability distribution function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.wrapcauchy` ``pdf``.
|
||||
|
||||
The wrapped Cauchy probability distribution function is given by
|
||||
|
||||
.. math::
|
||||
|
||||
f(x, c) = \frac{1-c^2}{2\pi(1+c^2-2c\cos x)}
|
||||
|
||||
for :math:`0<c<1`, and where normalization is on the domain :math:`0\le x\le 2\pi`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
c: arraylike, distribution shape parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.wrapcauchy.logpdf`
|
||||
"""
|
||||
return lax.exp(logpdf(x, c))
|
||||
Reference in New Issue
Block a user