This commit is contained in:
2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,13 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,311 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import NamedTuple
import numpy as np
from jax._src import api
from jax._src import dtypes
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.numpy.util import check_arraylike, promote_args_inexact
from jax._src.typing import ArrayLike, Array
from jax._src.util import canonicalize_axis
class ModeResult(NamedTuple):
mode: Array
count: Array
@api.jit(static_argnames=['axis', 'nan_policy', 'keepdims'])
def mode(a: ArrayLike, axis: int | None = 0, nan_policy: str = "propagate", keepdims: bool = False) -> ModeResult:
"""Compute the mode (most common value) along an axis of an array.
JAX implementation of :func:`scipy.stats.mode`.
Args:
a: arraylike
axis: int, default=0. Axis along which to compute the mode.
nan_policy: str. JAX only supports ``"propagate"``.
keepdims: bool, default=False. If true, reduced axes are left in the result
with size 1.
Returns:
A tuple of arrays, ``(mode, count)``. ``mode`` is the array of modal values,
and ``count`` is the number of times each value appears in the input array.
Examples:
>>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
>>> mode, count = jax.scipy.stats.mode(x)
>>> mode, count
(Array(4, dtype=int32), Array(3, dtype=int32))
For multi dimensional arrays, ``jax.scipy.stats.mode`` computes the ``mode``
and the corresponding ``count`` along ``axis=0``:
>>> x1 = jnp.array([[1, 2, 1, 3, 2, 1],
... [3, 1, 3, 2, 1, 3],
... [1, 2, 2, 3, 1, 2]])
>>> mode, count = jax.scipy.stats.mode(x1)
>>> mode, count
(Array([1, 2, 1, 3, 1, 1], dtype=int32), Array([2, 2, 1, 2, 2, 1], dtype=int32))
If ``axis=1``, ``mode`` and ``count`` will be computed along ``axis 1``.
>>> mode, count = jax.scipy.stats.mode(x1, axis=1)
>>> mode, count
(Array([1, 3, 2], dtype=int32), Array([3, 3, 3], dtype=int32))
By default, ``jax.scipy.stats.mode`` reduces the dimension of the result.
To keep the dimensions same as that of the input array, the argument
``keepdims`` must be set to ``True``.
>>> mode, count = jax.scipy.stats.mode(x1, axis=1, keepdims=True)
>>> mode, count
(Array([[1],
[3],
[2]], dtype=int32), Array([[3],
[3],
[3]], dtype=int32))
"""
check_arraylike("mode", a)
x = jnp.atleast_1d(a)
if nan_policy not in ["propagate", "omit", "raise"]:
raise ValueError(
f"Illegal nan_policy value {nan_policy!r}; expected one of "
"{'propagate', 'omit', 'raise'}"
)
if nan_policy == "omit":
# TODO: return answer without nans included.
raise NotImplementedError(
f"Logic for `nan_policy` of {nan_policy} is not implemented"
)
if nan_policy == "raise":
raise NotImplementedError(
"In order to best JIT compile `mode`, we cannot know whether `x` contains nans. "
"Please check if nans exist in `x` outside of the `mode` function."
)
if axis is not None:
axis = canonicalize_axis(axis, x.ndim)
input_shape = x.shape
if keepdims:
if axis is None:
output_shape = tuple(1 for i in input_shape)
else:
output_shape = tuple(1 if i == axis else s for i, s in enumerate(input_shape))
else:
if axis is None:
output_shape = ()
else:
output_shape = tuple(s for i, s in enumerate(input_shape) if i != axis)
if axis is None:
axis = 0
x = x.ravel()
def _mode_helper(x: Array) -> tuple[Array, Array]:
"""Helper function to return mode and count of a given array."""
if x.size == 0:
return (jnp.array(np.nan, dtype=dtypes.default_float_dtype()),
jnp.array(0, dtype=dtypes.default_float_dtype()))
else:
vals, counts = jnp.unique(x, return_counts=True, size=x.size)
return vals[jnp.argmax(counts)], counts.max()
x = jnp.moveaxis(x, axis, 0)
x = x.reshape(x.shape[0], math.prod(x.shape[1:]))
vals, counts = api.vmap(_mode_helper, in_axes=1)(x)
return ModeResult(vals.reshape(output_shape), counts.reshape(output_shape))
def invert_permutation(i: Array) -> Array:
"""Helper function that inverts a permutation array."""
return jnp.empty_like(i).at[i].set(jnp.arange(i.size, dtype=i.dtype))
@api.jit(static_argnames=["method", "axis", "nan_policy"])
def rankdata(
a: ArrayLike,
method: str = "average",
*,
axis: int | None = None,
nan_policy: str = "propagate",
) -> Array:
"""Compute the rank of data along an array axis.
JAX implementation of :func:`scipy.stats.rankdata`.
Ranks begin at 1, and the *method* argument controls how ties are handled.
Args:
a: arraylike
method: str, default="average". Supported methods are
``("average", "min", "max", "dense", "ordinal")``
For details, see the :func:`scipy.stats.rankdata` documentation.
axis: optional integer. If not specified, the input array is flattened.
nan_policy: str, JAX's implementation only supports ``"propagate"``.
Returns:
array of ranks along the specified axis.
Examples:
>>> x = jnp.array([10, 30, 20])
>>> rankdata(x)
Array([1., 3., 2.], dtype=float32)
>>> x = jnp.array([1, 3, 2, 3])
>>> rankdata(x)
Array([1. , 3.5, 2. , 3.5], dtype=float32)
"""
check_arraylike("rankdata", a)
if nan_policy not in ["propagate", "omit", "raise"]:
raise ValueError(
f"Illegal nan_policy value {nan_policy!r}; expected one of "
"{'propagate', 'omit', 'raise'}"
)
if nan_policy == "omit":
raise NotImplementedError(
f"Logic for `nan_policy` of {nan_policy} is not implemented"
)
if nan_policy == "raise":
raise NotImplementedError(
"In order to best JIT compile `rankdata`, we cannot know whether `x` "
"contains nans. Please check if nans exist in `x` outside of the "
"`rankdata` function."
)
if method not in ("average", "min", "max", "dense", "ordinal"):
raise ValueError(f"unknown method '{method}'")
if axis is not None:
return jnp.apply_along_axis(rankdata, axis, a, method)
a = jnp.ravel(a)
out_dtype = dtypes.default_float_dtype()
def _rankdata(a: Array) -> Array:
arr, sorter = lax.sort_key_val(a, jnp.arange(a.size))
inv = invert_permutation(sorter)
if method == "ordinal":
return (inv + 1).astype(out_dtype)
obs = jnp.concatenate([jnp.array([True]), arr[1:] != arr[:-1]])
dense = obs.cumsum()[inv]
if method == "dense":
return dense.astype(out_dtype)
count = jnp.nonzero(obs, size=arr.size + 1, fill_value=obs.size)[0].astype(out_dtype)
if method == "max":
return count[dense]
if method == "min":
return count[dense - 1] + 1
if method == "average":
return .5 * (count[dense] + count[dense - 1] + 1)
raise ValueError(f"unknown method '{method}'")
return lax.cond(jnp.any(jnp.isnan(a)),
lambda a: jnp.full_like(a, jnp.nan, out_dtype),
_rankdata, a)
@api.jit(static_argnames=['axis', 'nan_policy', 'keepdims'])
def sem(a: ArrayLike, axis: int | None = 0, ddof: int = 1, nan_policy: str = "propagate", *, keepdims: bool = False) -> Array:
"""Compute the standard error of the mean.
JAX implementation of :func:`scipy.stats.sem`.
Args:
a: arraylike
axis: optional integer. If not specified, the input array is flattened.
ddof: integer, default=1. The degrees of freedom in the SEM computation.
nan_policy: str, default="propagate". JAX supports only "propagate" and
"omit".
keepdims: bool, default=False. If true, reduced axes are left in the result
with size 1.
Returns:
array
Examples:
>>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
>>> with jnp.printoptions(precision=2, suppress=True):
... jax.scipy.stats.sem(x)
Array(0.41, dtype=float32)
For multi dimensional arrays, ``sem`` computes standard error of mean along
``axis=0``:
>>> x1 = jnp.array([[1, 2, 1, 3, 2, 1],
... [3, 1, 3, 2, 1, 3],
... [1, 2, 2, 3, 1, 2]])
>>> with jnp.printoptions(precision=2, suppress=True):
... jax.scipy.stats.sem(x1)
Array([0.67, 0.33, 0.58, 0.33, 0.33, 0.58], dtype=float32)
If ``axis=1``, standard error of mean will be computed along ``axis 1``.
>>> with jnp.printoptions(precision=2, suppress=True):
... jax.scipy.stats.sem(x1, axis=1)
Array([0.33, 0.4 , 0.31], dtype=float32)
If ``axis=None``, standard error of mean will be computed along all the axes.
>>> with jnp.printoptions(precision=2, suppress=True):
... jax.scipy.stats.sem(x1, axis=None)
Array(0.2, dtype=float32)
By default, ``sem`` reduces the dimension of the result. To keep the
dimensions same as that of the input array, the argument ``keepdims`` must
be set to ``True``.
>>> with jnp.printoptions(precision=2, suppress=True):
... jax.scipy.stats.sem(x1, axis=1, keepdims=True)
Array([[0.33],
[0.4 ],
[0.31]], dtype=float32)
Since, by default, ``nan_policy='propagate'``, ``sem`` propagates the ``nan``
values in the result.
>>> nan = np.nan
>>> x2 = jnp.array([[1, 2, 3, nan, 4, 2],
... [4, 5, 4, 3, nan, 1],
... [7, nan, 8, 7, 9, nan]])
>>> with jnp.printoptions(precision=2, suppress=True):
... jax.scipy.stats.sem(x2)
Array([1.73, nan, 1.53, nan, nan, nan], dtype=float32)
If ``nan_policy='omit```, ``sem`` omits the ``nan`` values and computes the error
for the remaining values along the specified axis.
>>> with jnp.printoptions(precision=2, suppress=True):
... jax.scipy.stats.sem(x2, nan_policy='omit')
Array([1.73, 1.5 , 1.53, 2. , 2.5 , 0.5 ], dtype=float32)
"""
b, = promote_args_inexact("sem", a)
if nan_policy == "propagate":
size = b.size if axis is None else b.shape[axis]
return b.std(axis, ddof=ddof, keepdims=keepdims) / jnp.sqrt(size).astype(b.dtype)
elif nan_policy == "omit":
count = (~jnp.isnan(b)).sum(axis, keepdims=keepdims)
return jnp.nanstd(b, axis, ddof=ddof, keepdims=keepdims) / jnp.sqrt(count).astype(b.dtype)
else:
raise ValueError(f"{nan_policy} is not supported")
@@ -0,0 +1,159 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.scipy.special import xlogy, xlog1py
from jax._src.typing import Array, ArrayLike
def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
r"""Bernoulli log probability mass function.
JAX implementation of :obj:`scipy.stats.bernoulli` ``logpmf``
The Bernoulli probability mass function is defined as
.. math::
f(k) = \begin{cases}
1 - p, & k = 0 \\
p, & k = 1 \\
0, & \mathrm{otherwise}
\end{cases}
Args:
k: arraylike, value at which to evaluate the PMF
p: arraylike, distribution shape parameter
loc: arraylike, distribution offset
Returns:
array of logpmf values
See Also:
- :func:`jax.scipy.stats.bernoulli.cdf`
- :func:`jax.scipy.stats.bernoulli.pmf`
- :func:`jax.scipy.stats.bernoulli.ppf`
"""
k, p, loc = promote_args_inexact("bernoulli.logpmf", k, p, loc)
zero = _lax_const(k, 0)
one = _lax_const(k, 1)
x = lax.sub(k, loc)
log_probs = xlogy(x, p) + xlog1py(lax.sub(one, x), -p)
return jnp.where(jnp.logical_or(lax.lt(x, zero), lax.gt(x, one)),
-np.inf, log_probs)
def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
r"""Bernoulli probability mass function.
JAX implementation of :obj:`scipy.stats.bernoulli` ``pmf``
The Bernoulli probability mass function is defined as
.. math::
f(k) = \begin{cases}
1 - p, & k = 0 \\
p, & k = 1 \\
0, & \mathrm{otherwise}
\end{cases}
Args:
k: arraylike, value at which to evaluate the PMF
p: arraylike, distribution shape parameter
loc: arraylike, distribution offset
Returns:
array of pmf values
See Also:
- :func:`jax.scipy.stats.bernoulli.cdf`
- :func:`jax.scipy.stats.bernoulli.logpmf`
- :func:`jax.scipy.stats.bernoulli.ppf`
"""
return jnp.exp(logpmf(k, p, loc))
def cdf(k: ArrayLike, p: ArrayLike) -> Array:
r"""Bernoulli cumulative distribution function.
JAX implementation of :obj:`scipy.stats.bernoulli` ``cdf``
The Bernoulli cumulative distribution function is defined as:
.. math::
f_{cdf}(k, p) = \sum_{i=0}^k f_{pmf}(k, p)
where :math:`f_{pmf}(k, p)` is the Bernoulli probability mass function
:func:`jax.scipy.stats.bernoulli.pmf`.
Args:
k: arraylike, value at which to evaluate the CDF
p: arraylike, distribution shape parameter
loc: arraylike, distribution offset
Returns:
array of cdf values
See Also:
- :func:`jax.scipy.stats.bernoulli.logpmf`
- :func:`jax.scipy.stats.bernoulli.pmf`
- :func:`jax.scipy.stats.bernoulli.ppf`
"""
k, p = promote_args_inexact('bernoulli.cdf', k, p)
zero, one = _lax_const(k, 0), _lax_const(k, 1)
conds = [
jnp.isnan(k) | jnp.isnan(p) | (p < zero) | (p > one),
lax.lt(k, zero),
jnp.logical_and(lax.ge(k, zero), lax.lt(k, one)),
lax.ge(k, one)
]
vals = [jnp.nan, zero, one - p, one]
return jnp.select(conds, vals)
def ppf(q: ArrayLike, p: ArrayLike) -> Array:
"""Bernoulli percent point function.
JAX implementation of :obj:`scipy.stats.bernoulli` ``ppf``
The percent point function is the inverse of the cumulative
distribution function, :func:`jax.scipy.stats.bernoulli.cdf`.
Args:
k: arraylike, value at which to evaluate the PPF
p: arraylike, distribution shape parameter
loc: arraylike, distribution offset
Returns:
array of ppf values
See Also:
- :func:`jax.scipy.stats.bernoulli.cdf`
- :func:`jax.scipy.stats.bernoulli.logpmf`
- :func:`jax.scipy.stats.bernoulli.pmf`
"""
q, p = promote_args_inexact('bernoulli.ppf', q, p)
zero, one = _lax_const(q, 0), _lax_const(q, 1)
return jnp.where(
jnp.isnan(q) | jnp.isnan(p) | (p < zero) | (p > one) | (q < zero) | (q > one),
jnp.nan,
jnp.where(lax.le(q, one - p), zero, one)
)
@@ -0,0 +1,262 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.scipy.special import betaln, betainc, xlogy, xlog1py
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Beta log probability distribution function.
JAX implementation of :obj:`scipy.stats.beta` ``logpdf``.
The pdf of the beta function is:
.. math::
f(x, a, b) = \frac{\Gamma(a + b)}{\Gamma(a)\Gamma(b)} x^{a-1}(1-x)^{b-1}
where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function,
It is defined for :math:`0\le x\le 1` and :math:`b>0`.
Args:
x: arraylike, value at which to evaluate the PDF
a: arraylike, distribution shape parameter
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logpdf values
See Also:
- :func:`jax.scipy.stats.beta.cdf`
- :func:`jax.scipy.stats.beta.pdf`
- :func:`jax.scipy.stats.beta.sf`
- :func:`jax.scipy.stats.beta.logcdf`
- :func:`jax.scipy.stats.beta.logsf`
"""
x, a, b, loc, scale = promote_args_inexact("beta.logpdf", x, a, b, loc, scale)
one = _lax_const(x, 1)
zero = _lax_const(a, 0)
shape_term = lax.neg(betaln(a, b))
y = lax.div(lax.sub(x, loc), scale)
log_linear_term = lax.add(xlogy(lax.sub(a, one), y),
xlog1py(lax.sub(b, one), lax.neg(y)))
log_probs = lax.sub(lax.add(shape_term, log_linear_term), lax.log(scale))
result = jnp.where(jnp.logical_or(lax.gt(x, lax.add(loc, scale)),
lax.lt(x, loc)), -np.inf, log_probs)
result_positive_constants = jnp.where(jnp.logical_or(jnp.logical_or(lax.le(a, zero), lax.le(b, zero)),
lax.le(scale, zero)), np.nan, result)
return result_positive_constants
def pdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Beta probability distribution function.
JAX implementation of :obj:`scipy.stats.beta` ``pdf``.
The pdf of the beta function is:
.. math::
f(x, a, b) = \frac{\Gamma(a + b)}{\Gamma(a)\Gamma(b)} x^{a-1}(1-x)^{b-1}
where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function.
It is defined for :math:`0\le x\le 1` and :math:`b>0`.
Args:
x: arraylike, value at which to evaluate the PDF
a: arraylike, distribution shape parameter
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values
See Also:
- :func:`jax.scipy.stats.beta.cdf`
- :func:`jax.scipy.stats.beta.sf`
- :func:`jax.scipy.stats.beta.logcdf`
- :func:`jax.scipy.stats.beta.logpdf`
- :func:`jax.scipy.stats.beta.logsf`
"""
return lax.exp(logpdf(x, a, b, loc, scale))
def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Beta cumulative distribution function
JAX implementation of :obj:`scipy.stats.beta` ``cdf``.
The cdf is defined as
.. math::
f_{cdf}(x, a, b) = \int_{-\infty}^x f_{pdf}(y, a, b)\mathrm{d}y
where :math:`f_{pdf}` is the beta distribution probability density function,
:func:`jax.scipy.stats.beta.pdf`.
Args:
x: arraylike, value at which to evaluate the CDF
a: arraylike, distribution shape parameter
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of cdf values
See Also:
- :func:`jax.scipy.stats.beta.pdf`
- :func:`jax.scipy.stats.beta.sf`
- :func:`jax.scipy.stats.beta.logcdf`
- :func:`jax.scipy.stats.beta.logpdf`
- :func:`jax.scipy.stats.beta.logsf`
"""
x, a, b, loc, scale = promote_args_inexact("beta.cdf", x, a, b, loc, scale)
return betainc(
a,
b,
lax.clamp(
_lax_const(x, 0),
lax.div(lax.sub(x, loc), scale),
_lax_const(x, 1),
)
)
def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Beta log cumulative distribution function.
JAX implementation of :obj:`scipy.stats.beta` ``logcdf``.
The cdf is defined as
.. math::
f_{cdf}(x, a, b) = \int_{-\infty}^x f_{pdf}(y, a, b)\mathrm{d}y
where :math:`f_{pdf}` is the beta distribution probability density function,
:func:`jax.scipy.stats.beta.pdf`.
Args:
x: arraylike, value at which to evaluate the CDF
a: arraylike, distribution shape parameter
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logcdf values
See Also:
- :func:`jax.scipy.stats.beta.cdf`
- :func:`jax.scipy.stats.beta.pdf`
- :func:`jax.scipy.stats.beta.sf`
- :func:`jax.scipy.stats.beta.logpdf`
- :func:`jax.scipy.stats.beta.logsf`
"""
return lax.log(cdf(x, a, b, loc, scale))
def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Beta distribution survival function.
JAX implementation of :obj:`scipy.stats.beta` ``sf``.
The survival function is defined as
.. math::
f_{sf}(x, a, b) = 1 - f_{cdf}(x, a, b)
where :math:`f_{cdf}(x, a, b)` is the beta cumulative distribution function,
:func:`jax.scipy.stats.beta.cdf`.
Args:
x: arraylike, value at which to evaluate the SF
a: arraylike, distribution shape parameter
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of sf values.
See Also:
- :func:`jax.scipy.stats.beta.cdf`
- :func:`jax.scipy.stats.beta.pdf`
- :func:`jax.scipy.stats.beta.logcdf`
- :func:`jax.scipy.stats.beta.logpdf`
- :func:`jax.scipy.stats.beta.logsf`
"""
x, a, b, loc, scale = promote_args_inexact("beta.sf", x, a, b, loc, scale)
return betainc(
b,
a,
1 - lax.clamp(
_lax_const(x, 0),
lax.div(lax.sub(x, loc), scale),
_lax_const(x, 1),
)
)
def logsf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Beta distribution log survival function.
JAX implementation of :obj:`scipy.stats.beta` ``logsf``.
The survival function is defined as
.. math::
f_{sf}(x, a, b) = 1 - f_{cdf}(x, a, b)
where :math:`f_{cdf}(x, a, b)` is the beta cumulative distribution function,
:func:`jax.scipy.stats.beta.cdf`.
Args:
x: arraylike, value at which to evaluate the SF
a: arraylike, distribution shape parameter
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logsf values.
See Also:
- :func:`jax.scipy.stats.beta.cdf`
- :func:`jax.scipy.stats.beta.pdf`
- :func:`jax.scipy.stats.beta.sf`
- :func:`jax.scipy.stats.beta.logcdf`
- :func:`jax.scipy.stats.beta.logpdf`
"""
return lax.log(sf(x, a, b, loc, scale))
@@ -0,0 +1,98 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import numpy as np
from jax._src import api
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.scipy.special import betaln
from jax._src.typing import Array, ArrayLike
@api.jit
def logpmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0) -> Array:
r"""Beta-binomial log probability mass function.
JAX implementation of :obj:`scipy.stats.betabinom` ``logpmf``
The beta-binomial distribution's probability mass function is defined as
.. math::
f(k, n, a, b) = {n \choose k}\frac{B(k+a,n-k-b)}{B(a,b)}
where :math:`B(a, b)` is the :func:`~jax.scipy.special.beta` function. It is
defined for :math:`n\ge 0`, :math:`a>0`, :math:`b>0`, and non-negative integers `k`.
Args:
k: arraylike, value at which to evaluate the PMF
n: arraylike, distribution shape parameter
a: arraylike, distribution shape parameter
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
Returns:
array of logpmf values
See Also:
:func:`jax.scipy.stats.betabinom.pmf`
"""
k, n, a, b, loc = promote_args_inexact("betabinom.logpmf", k, n, a, b, loc)
y = lax.sub(lax.floor(k), loc)
one = _lax_const(y, 1)
zero = _lax_const(y, 0)
combiln = lax.neg(lax.add(lax.log1p(n), betaln(lax.add(lax.sub(n,y), one), lax.add(y,one))))
beta_lns = lax.sub(betaln(lax.add(y,a), lax.add(lax.sub(n,y),b)), betaln(a,b))
log_probs = lax.add(combiln, beta_lns)
log_probs = jnp.where(jnp.logical_and(lax.eq(y, zero), lax.eq(n, zero)), 0., log_probs)
y_cond = jnp.logical_or(jnp.logical_or(lax.lt(y, lax.neg(loc)), lax.gt(y, n)),
lax.le(lax.add(y, a), zero))
log_probs = jnp.where(y_cond, -np.inf, log_probs)
n_a_b_cond = jnp.logical_or(jnp.logical_or(lax.lt(n, zero), lax.le(a, zero)), lax.le(b, zero))
return jnp.where(n_a_b_cond, np.nan, log_probs)
def pmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0) -> Array:
r"""Beta-binomial probability mass function.
JAX implementation of :obj:`scipy.stats.betabinom` ``pmf``.
The beta-binomial distribution's probability mass function is defined as
.. math::
f(k, n, a, b) = {n \choose k}\frac{B(k+a,n-k-b)}{B(a,b)}
where :math:`B(a, b)` is the :func:`~jax.scipy.special.beta` function. It is
defined for :math:`n\ge 0`, :math:`a>0`, :math:`b>0`, and non-negative integers `k`.
Args:
k: arraylike, value at which to evaluate the PMF
n: arraylike, distribution shape parameter
a: arraylike, distribution shape parameter
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
Returns:
array of pmf values
See Also:
:func:`jax.scipy.stats.betabinom.logpmf`
"""
return lax.exp(logpmf(k, n, a, b, loc))
@@ -0,0 +1,90 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.numpy.util import promote_args_inexact
from jax._src.lax.lax import _const as _lax_const
from jax._src.scipy.special import gammaln, xlogy, xlog1py
from jax._src.typing import Array, ArrayLike
def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
r"""Binomial log probability mass function.
JAX implementation of :obj:`scipy.stats.binom` ``logpmf``.
The binomial probability mass function is defined as
.. math::
f(k, n, p) = {n \choose k}p^k(1-p)^{n-k}
for :math:`0\le p\le 1` and non-negative integers :math:`k`.
Args:
k: arraylike, value at which to evaluate the PMF
n: arraylike, distribution shape parameter
p: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
Returns:
array of logpmf values.
See Also:
:func:`jax.scipy.stats.binom.pmf`
"""
k, n, p, loc = promote_args_inexact("binom.logpmf", k, n, p, loc)
y = lax.sub(k, loc)
zero = _lax_const(y, 0)
comb_term = lax.sub(
gammaln(n + 1),
lax.add(gammaln(y + 1), gammaln(n - y + 1))
)
log_linear_term = lax.add(xlogy(y, p), xlog1py(lax.sub(n, y), lax.neg(p)))
log_probs = lax.add(comb_term, log_linear_term)
y_n_cond = jnp.logical_or(jnp.logical_and(lax.eq(y, zero), lax.eq(n, zero)),
lax.eq(log_linear_term, zero))
log_probs = jnp.where(y_n_cond, 0., log_probs)
return jnp.where(lax.ge(k, loc) & lax.lt(k, loc + n + 1), log_probs, -np.inf)
def pmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
r"""Binomial probability mass function.
JAX implementation of :obj:`scipy.stats.binom` ``pmf``.
The binomial probability mass function is defined as
.. math::
f(k, n, p) = {n \choose k}p^k(1-p)^{n-k}
for :math:`0\le p\le 1` and non-negative integers :math:`k`.
Args:
k: arraylike, value at which to evaluate the PMF
n: arraylike, distribution shape parameter
p: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
Returns:
array of pmf values.
See Also:
:func:`jax.scipy.stats.binom.logpmf`
"""
return lax.exp(logpmf(k, n, p, loc))
@@ -0,0 +1,293 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.ufuncs import arctan
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Cauchy log probability distribution function.
JAX implementation of :obj:`scipy.stats.cauchy` ``logpdf``.
The Cauchy probability distribution function is
.. math::
f(x) = \frac{1}{\pi(1 + x^2)}
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logpdf values
See Also:
- :func:`jax.scipy.stats.cauchy.cdf`
- :func:`jax.scipy.stats.cauchy.pdf`
- :func:`jax.scipy.stats.cauchy.sf`
- :func:`jax.scipy.stats.cauchy.logcdf`
- :func:`jax.scipy.stats.cauchy.logsf`
- :func:`jax.scipy.stats.cauchy.isf`
- :func:`jax.scipy.stats.cauchy.ppf`
"""
x, loc, scale = promote_args_inexact("cauchy.logpdf", x, loc, scale)
pi = _lax_const(x, np.pi)
scaled_x = lax.div(lax.sub(x, loc), scale)
normalize_term = lax.log(lax.mul(pi, scale))
return lax.neg(lax.add(normalize_term, lax.log1p(lax.mul(scaled_x, scaled_x))))
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Cauchy probability distribution function.
JAX implementation of :obj:`scipy.stats.cauchy` ``pdf``.
The Cauchy probability distribution function is
.. math::
f(x) = \frac{1}{\pi(1 + x^2)}
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values
See Also:
- :func:`jax.scipy.stats.cauchy.cdf`
- :func:`jax.scipy.stats.cauchy.sf`
- :func:`jax.scipy.stats.cauchy.logcdf`
- :func:`jax.scipy.stats.cauchy.logpdf`
- :func:`jax.scipy.stats.cauchy.logsf`
- :func:`jax.scipy.stats.cauchy.isf`
- :func:`jax.scipy.stats.cauchy.ppf`
"""
return lax.exp(logpdf(x, loc, scale))
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Cauchy cumulative distribution function.
JAX implementation of :obj:`scipy.stats.cauchy` ``cdf``.
The cdf is defined as
.. math::
f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y
where here :math:`f_{pdf}` is the Cauchy probability distribution function,
:func:`jax.scipy.stats.cauchy.pdf`.
Args:
x: arraylike, value at which to evaluate the CDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of cdf values.
See Also:
- :func:`jax.scipy.stats.cauchy.pdf`
- :func:`jax.scipy.stats.cauchy.sf`
- :func:`jax.scipy.stats.cauchy.logcdf`
- :func:`jax.scipy.stats.cauchy.logpdf`
- :func:`jax.scipy.stats.cauchy.logsf`
- :func:`jax.scipy.stats.cauchy.isf`
- :func:`jax.scipy.stats.cauchy.ppf`
"""
x, loc, scale = promote_args_inexact("cauchy.cdf", x, loc, scale)
pi = _lax_const(x, np.pi)
scaled_x = lax.div(lax.sub(x, loc), scale)
return lax.add(_lax_const(x, 0.5), lax.mul(lax.div(_lax_const(x, 1.), pi), arctan(scaled_x)))
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Cauchy log cumulative distribution function.
JAX implementation of :obj:`scipy.stats.cauchy` ``logcdf``
The cdf is defined as
.. math::
f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y
where here :math:`f_{pdf}` is the Cauchy probability distribution function,
:func:`jax.scipy.stats.cauchy.pdf`.
Args:
x: arraylike, value at which to evaluate the CDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logcdf values.
See Also:
- :func:`jax.scipy.stats.cauchy.cdf`
- :func:`jax.scipy.stats.cauchy.pdf`
- :func:`jax.scipy.stats.cauchy.sf`
- :func:`jax.scipy.stats.cauchy.logpdf`
- :func:`jax.scipy.stats.cauchy.logsf`
- :func:`jax.scipy.stats.cauchy.isf`
- :func:`jax.scipy.stats.cauchy.ppf`
"""
return lax.log(cdf(x, loc, scale))
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Cauchy distribution log survival function.
JAX implementation of :obj:`scipy.stats.cauchy` ``sf``.
The survival function is defined as
.. math::
f_{sf}(x) = 1 - f_{cdf}(x)
where :math:`f_{cdf}(x)` is the cumulative distribution function,
:func:`jax.scipy.stats.cauchy.cdf`.
Args:
x: arraylike, value at which to evaluate the SF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of sf values
See Also:
- :func:`jax.scipy.stats.cauchy.cdf`
- :func:`jax.scipy.stats.cauchy.pdf`
- :func:`jax.scipy.stats.cauchy.logcdf`
- :func:`jax.scipy.stats.cauchy.logpdf`
- :func:`jax.scipy.stats.cauchy.logsf`
- :func:`jax.scipy.stats.cauchy.isf`
- :func:`jax.scipy.stats.cauchy.ppf`
"""
x, loc, scale = promote_args_inexact("cauchy.sf", x, loc, scale)
return cdf(-x, -loc, scale)
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Cauchy distribution log survival function.
JAX implementation of :obj:`scipy.stats.cauchy` ``logsf``
The survival function is defined as
.. math::
f_{sf}(x) = 1 - f_{cdf}(x)
where :math:`f_{cdf}(x)` is the cumulative distribution function,
:func:`jax.scipy.stats.cauchy.cdf`.
Args:
x: arraylike, value at which to evaluate the SF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logsf values.
See Also:
- :func:`jax.scipy.stats.cauchy.cdf`
- :func:`jax.scipy.stats.cauchy.pdf`
- :func:`jax.scipy.stats.cauchy.sf`
- :func:`jax.scipy.stats.cauchy.logcdf`
- :func:`jax.scipy.stats.cauchy.logpdf`
- :func:`jax.scipy.stats.cauchy.isf`
- :func:`jax.scipy.stats.cauchy.ppf`
"""
x, loc, scale = promote_args_inexact("cauchy.logsf", x, loc, scale)
return logcdf(-x, -loc, scale)
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Cauchy distribution inverse survival function.
JAX implementation of :obj:`scipy.stats.cauchy` ``isf``.
Returns the inverse of the survival function,
:func:`jax.scipy.stats.cauchy.sf`.
Args:
q: arraylike, value at which to evaluate the ISF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of isf values.
See Also:
- :func:`jax.scipy.stats.cauchy.cdf`
- :func:`jax.scipy.stats.cauchy.pdf`
- :func:`jax.scipy.stats.cauchy.sf`
- :func:`jax.scipy.stats.cauchy.logcdf`
- :func:`jax.scipy.stats.cauchy.logpdf`
- :func:`jax.scipy.stats.cauchy.logsf`
- :func:`jax.scipy.stats.cauchy.ppf`
"""
q, loc, scale = promote_args_inexact("cauchy.isf", q, loc, scale)
pi = _lax_const(q, np.pi)
half_pi = _lax_const(q, np.pi / 2)
unscaled = lax.tan(lax.sub(half_pi, lax.mul(pi, q)))
return lax.add(lax.mul(unscaled, scale), loc)
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Cauchy distribution percent point function.
JAX implementation of :obj:`scipy.stats.cauchy` ``ppf``.
The percent point function is defined as the inverse of the
cumulative distribution function, :func:`jax.scipy.stats.cauchy.cdf`.
Args:
q: arraylike, value at which to evaluate the PPF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of ppf values.
See Also:
- :func:`jax.scipy.stats.cauchy.cdf`
- :func:`jax.scipy.stats.cauchy.pdf`
- :func:`jax.scipy.stats.cauchy.sf`
- :func:`jax.scipy.stats.cauchy.logcdf`
- :func:`jax.scipy.stats.cauchy.logpdf`
- :func:`jax.scipy.stats.cauchy.logsf`
- :func:`jax.scipy.stats.cauchy.isf`
"""
q, loc, scale = promote_args_inexact("cauchy.ppf", q, loc, scale)
pi = _lax_const(q, np.pi)
half_pi = _lax_const(q, np.pi / 2)
unscaled = lax.tan(lax.sub(lax.mul(pi, q), half_pi))
return lax.add(lax.mul(unscaled, scale), loc)
@@ -0,0 +1,267 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.scipy.special import gammainc, gammaincc
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Chi-square log probability distribution function.
JAX implementation of :obj:`scipy.stats.chi2` ``logpdf``.
The chi-square probability distribution function is given by:
.. math::
f(x, k) = \begin{cases}
\frac{x^{k/2-1}e^{-x/2}}{2^{k/2}\Gamma(k/2)} & x \ge 0 \\
0 & \mathrm{otherwise}
\end{cases}
for :math:`k` degrees of freedom, and where :math:`\Gamma` is the
:func:`~jax.scipy.special.gamma` function. JAX follows the scipy
convention of using ``df`` to denote degrees of freedom.
Args:
x: arraylike, value at which to evaluate the PDF
df: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logpdf values.
See Also:
- :func:`jax.scipy.stats.chi2.cdf`
- :func:`jax.scipy.stats.chi2.pdf`
- :func:`jax.scipy.stats.chi2.sf`
- :func:`jax.scipy.stats.chi2.logcdf`
- :func:`jax.scipy.stats.chi2.logsf`
"""
x, df, loc, scale = promote_args_inexact("chi2.logpdf", x, df, loc, scale)
one = _lax_const(x, 1)
two = _lax_const(x, 2)
y = lax.div(lax.sub(x, loc), scale)
df_on_two = lax.div(df, two)
kernel = lax.sub(lax.mul(lax.sub(df_on_two, one), lax.log(y)), lax.div(y,two))
nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two)))
log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
return jnp.where(lax.lt(x, loc), -np.inf, log_probs)
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Chi-square probability distribution function.
JAX implementation of :obj:`scipy.stats.chi2` ``pdf``.
The chi-square probability distribution function is given by:
.. math::
f(x, k) = \begin{cases}
\frac{x^{k/2-1}e^{-x/2}}{2^{k/2}\Gamma(k/2)} & x \ge 0 \\
0 & \mathrm{otherwise}
\end{cases}
for :math:`k` degrees of freedom, and where :math:`\Gamma` is the
:func:`~jax.scipy.special.gamma` function. JAX follows the scipy
convention of using ``df`` to denote degrees of freedom.
Args:
x: arraylike, value at which to evaluate the PDF
df: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
- :func:`jax.scipy.stats.chi2.cdf`
- :func:`jax.scipy.stats.chi2.sf`
- :func:`jax.scipy.stats.chi2.logcdf`
- :func:`jax.scipy.stats.chi2.logpdf`
- :func:`jax.scipy.stats.chi2.logsf`
"""
return lax.exp(logpdf(x, df, loc, scale))
def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Chi-square cumulative distribution function.
JAX implementation of :obj:`scipy.stats.chi2` ``cdf``.
The cdf is defined as
.. math::
f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y
where :math:`f_{pdf}` is the probability density function,
:func:`jax.scipy.stats.chi2.pdf`. JAX follows the scipy
convention of using ``df`` to denote degrees of freedom.
Args:
x: arraylike, value at which to evaluate the CDF
df: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of cdf values.
See Also:
- :func:`jax.scipy.stats.chi2.pdf`
- :func:`jax.scipy.stats.chi2.sf`
- :func:`jax.scipy.stats.chi2.logcdf`
- :func:`jax.scipy.stats.chi2.logpdf`
- :func:`jax.scipy.stats.chi2.logsf`
"""
x, df, loc, scale = promote_args_inexact("chi2.cdf", x, df, loc, scale)
two = _lax_const(scale, 2)
return gammainc(
lax.div(df, two),
lax.clamp(
_lax_const(x, 0),
lax.div(
lax.sub(x, loc),
lax.mul(scale, two),
),
_lax_const(x, np.inf),
),
)
def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Chi-square log cumulative distribution function.
JAX implementation of :obj:`scipy.stats.chi2` ``logcdf``.
The cdf is defined as
.. math::
f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y
where :math:`f_{pdf}` is the probability density function,
:func:`jax.scipy.stats.chi2.pdf`. JAX follows the scipy
convention of using ``df`` to denote degrees of freedom.
Args:
x: arraylike, value at which to evaluate the CDF
df: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logcdf values
See Also:
- :func:`jax.scipy.stats.chi2.cdf`
- :func:`jax.scipy.stats.chi2.pdf`
- :func:`jax.scipy.stats.chi2.sf`
- :func:`jax.scipy.stats.chi2.logpdf`
- :func:`jax.scipy.stats.chi2.logsf`
"""
return lax.log(cdf(x, df, loc, scale))
def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Chi-square survival function.
JAX implementation of :obj:`scipy.stats.chi2` ``sf``.
The survival function is defined as
.. math::
f_{sf}(x, k) = 1 - f_{cdf}(x, k)
where :math:`f_{cdf}(x, k)` is the cumulative distribution function,
:func:`jax.scipy.stats.chi2.cdf`. JAX follows the scipy
convention of using ``df`` to denote degrees of freedom.
Args:
x: arraylike, value at which to evaluate the SF
df: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of sf values.
See Also:
- :func:`jax.scipy.stats.chi2.cdf`
- :func:`jax.scipy.stats.chi2.pdf`
- :func:`jax.scipy.stats.chi2.logcdf`
- :func:`jax.scipy.stats.chi2.logpdf`
- :func:`jax.scipy.stats.chi2.logsf`
"""
x, df, loc, scale = promote_args_inexact("chi2.sf", x, df, loc, scale)
two = _lax_const(scale, 2)
return gammaincc(
lax.div(df, two),
lax.clamp(
_lax_const(x, 0),
lax.div(
lax.sub(x, loc),
lax.mul(scale, two),
),
_lax_const(x, np.inf),
),
)
def logsf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Chi-square log survival function.
JAX implementation of :obj:`scipy.stats.chi2` ``logsf``.
The survival function is defined as
.. math::
f_{sf}(x, k) = 1 - f_{cdf}(x, k)
where :math:`f_{cdf}(x, k)` is the cumulative distribution function,
:func:`jax.scipy.stats.chi2.cdf`. JAX follows the scipy
convention of using ``df`` to denote degrees of freedom.
Args:
x: arraylike, value at which to evaluate the SF
df: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logsf values.
See Also:
- :func:`jax.scipy.stats.chi2.cdf`
- :func:`jax.scipy.stats.chi2.pdf`
- :func:`jax.scipy.stats.chi2.sf`
- :func:`jax.scipy.stats.chi2.logcdf`
- :func:`jax.scipy.stats.chi2.logpdf`
"""
return lax.log(sf(x, df, loc, scale))
@@ -0,0 +1,100 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_dtypes_inexact
from jax._src.scipy.special import gammaln, xlogy
from jax._src.typing import Array, ArrayLike
def _is_simplex(x: Array) -> Array:
x_sum = jnp.sum(x, axis=0)
return jnp.all(x > 0, axis=0) & (abs(x_sum - 1) < 1E-6)
def logpdf(x: ArrayLike, alpha: ArrayLike) -> Array:
r"""Dirichlet log probability distribution function.
JAX implementation of :obj:`scipy.stats.dirichlet` ``logpdf``.
The Dirichlet probability density function is
.. math::
f(\mathbf{x}) = \frac{1}{B(\mathbf{\alpha})} \prod_{i=1}^K x_i^{\alpha_i - 1}
where :math:`B(\mathbf{\alpha})` is the :func:`~jax.scipy.special.beta` function
in a :math:`K`-dimensional vector space.
Args:
x: arraylike, value at which to evaluate the PDF
alpha: arraylike, distribution shape parameter
Returns:
array of logpdf values.
See Also:
:func:`jax.scipy.stats.dirichlet.pdf`
"""
return _logpdf(*promote_dtypes_inexact(x, alpha))
def _logpdf(x: Array, alpha: Array) -> Array:
if alpha.ndim != 1:
raise ValueError(
f"`alpha` must be one-dimensional; got alpha.shape={alpha.shape}"
)
if x.shape[0] not in (alpha.shape[0], alpha.shape[0] - 1):
raise ValueError(
"`x` must have either the same number of entries as `alpha` "
f"or one entry fewer; got x.shape={x.shape}, alpha.shape={alpha.shape}"
)
one = _lax_const(x, 1)
if x.shape[0] != alpha.shape[0]:
x = jnp.concatenate([x, lax.sub(one, x.sum(0, keepdims=True))], axis=0)
normalize_term = jnp.sum(gammaln(alpha)) - gammaln(jnp.sum(alpha))
if x.ndim > 1:
alpha = lax.broadcast_in_dim(alpha, alpha.shape + (1,) * (x.ndim - 1), (0,))
log_probs = lax.sub(jnp.sum(xlogy(lax.sub(alpha, one), x), axis=0), normalize_term)
return jnp.where(_is_simplex(x), log_probs, -np.inf)
def pdf(x: ArrayLike, alpha: ArrayLike) -> Array:
r"""Dirichlet probability distribution function.
JAX implementation of :obj:`scipy.stats.dirichlet` ``pdf``.
The Dirichlet probability density function is
.. math::
f(\mathbf{x}) = \frac{1}{B(\mathbf{\alpha})} \prod_{i=1}^K x_i^{\alpha_i - 1}
where :math:`B(\mathbf{\alpha})` is the :func:`~jax.scipy.special.beta` function
in a :math:`K`-dimensional vector space.
Args:
x: arraylike, value at which to evaluate the PDF
alpha: arraylike, distribution shape parameter
Returns:
array of pdf values.
See Also:
:func:`jax.scipy.stats.dirichlet.logpdf`
"""
return lax.exp(logpdf(x, alpha))
@@ -0,0 +1,270 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential log probability distribution function.
JAX implementation of :obj:`scipy.stats.expon` ``logpdf``.
The Exponential probability distribution function is
.. math::
f(x) = \begin{cases}
e^{-x} & x \ge 0 \\
0 & \mathrm{otherwise}
\end{cases}
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logpdf values.
See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
x, loc, scale = promote_args_inexact("expon.logpdf", x, loc, scale)
log_scale = lax.log(scale)
linear_term = lax.div(lax.sub(x, loc), scale)
log_probs = lax.neg(lax.add(linear_term, log_scale))
return jnp.where(lax.lt(x, loc), -np.inf, log_probs)
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential probability distribution function.
JAX implementation of :obj:`scipy.stats.expon` ``pdf``.
The Exponential probability distribution function is
.. math::
f(x) = \begin{cases}
e^{-x} & x \ge 0 \\
0 & \mathrm{otherwise}
\end{cases}
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
return lax.exp(logpdf(x, loc, scale))
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential cumulative density function.
JAX implementation of :obj:`scipy.stats.expon` ``cdf``.
The cdf is defined as
.. math::
f_{cdf}(x) = \int_{-\infty}^x f_{pdf}(y)\mathrm{d}y
where :math:`f_{pdf}` is the exponential distribution probability density function,
:func:`jax.scipy.stats.expon.pdf`.
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
x, loc, scale = promote_args_inexact("expon.cdf", x, loc, scale)
neg_scaled_x = lax.div(lax.sub(loc, x), scale)
return jnp.where(
lax.lt(x, loc),
jnp.zeros_like(neg_scaled_x),
lax.neg(lax.expm1(neg_scaled_x)),
)
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential log cumulative density function.
JAX implementation of :obj:`scipy.stats.expon` ``logcdf``.
The cdf is defined as
.. math::
f_{cdf}(x) = \int_{-\infty}^x f_{pdf}(y)\mathrm{d}y
where :math:`f_{pdf}` is the exponential distribution probability density function,
:func:`jax.scipy.stats.expon.pdf`.
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
return lax.log1p(lax.neg(sf(x, loc, scale)))
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential log survival function.
JAX implementation of :obj:`scipy.stats.expon` ``logsf``.
The survival function is defined as
.. math::
f_{sf}(x) = 1 - f_{cdf}(x)
where :math:`f_{cdf}(x)` is the exponential cumulative distribution function,
:func:`jax.scipy.stats.expon.cdf`.
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
x, loc, scale = promote_args_inexact("expon.sf", x, loc, scale)
neg_scaled_x = lax.div(lax.sub(loc, x), scale)
return jnp.where(lax.lt(x, loc), jnp.zeros_like(neg_scaled_x), neg_scaled_x)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential survival function.
JAX implementation of :obj:`scipy.stats.expon` ``sf``.
The survival function is defined as
.. math::
f_{sf}(x) = 1 - f_{cdf}(x)
where :math:`f_{cdf}(x)` is the exponential cumulative distribution function,
:func:`jax.scipy.stats.expon.cdf`.
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
return lax.exp(logsf(x, loc, scale))
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential survival function.
JAX implementation of :obj:`scipy.stats.expon` ``ppf``.
The percent point function is defined as the inverse of the
cumulative distribution function, :func:`jax.scipy.stats.expon.cdf`.
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
q, loc, scale = promote_args_inexact("expon.ppf", q, loc, scale)
neg_scaled_q = lax.div(lax.sub(loc, q), scale)
return jnp.where(
jnp.isnan(q) | (q < 0) | (q > 1),
np.nan,
lax.neg(lax.log1p(neg_scaled_q)),
)
@@ -0,0 +1,237 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.scipy.special import gammaln, xlogy, gammainc, gammaincc
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Gamma log probability distribution function.
JAX implementation of :obj:`scipy.stats.gamma` ``logpdf``.
The Gamma probability distribution is given by
.. math::
f(x, a) = \frac{1}{\Gamma(a)}x^{a-1}e^{-x}
Where :math:`\Gamma(a)` is the :func:`~jax.scipy.special.gamma` function.
It is defined for :math:`x \ge 0` and :math:`a > 0`.
Args:
x: arraylike, value at which to evaluate the PDF
a: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logpdf values.
See Also:
- :func:`jax.scipy.stats.gamma.cdf`
- :func:`jax.scipy.stats.gamma.pdf`
- :func:`jax.scipy.stats.gamma.sf`
- :func:`jax.scipy.stats.gamma.logcdf`
- :func:`jax.scipy.stats.gamma.logsf`
"""
x, a, loc, scale = promote_args_inexact("gamma.logpdf", x, a, loc, scale)
ok = lax.ge(x, loc)
one = _lax_const(x, 1)
y = jnp.where(ok, lax.div(lax.sub(x, loc), scale), one)
log_linear_term = lax.sub(xlogy(lax.sub(a, one), y), y)
shape_terms = lax.add(gammaln(a), lax.log(scale))
log_probs = lax.sub(log_linear_term, shape_terms)
return jnp.where(ok, log_probs, -np.inf)
def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Gamma probability distribution function.
JAX implementation of :obj:`scipy.stats.gamma` ``pdf``.
The Gamma probability distribution is given by
.. math::
f(x, a) = \frac{1}{\Gamma(a)}x^{a-1}e^{-x}
Where :math:`\Gamma(a)` is the :func:`~jax.scipy.special.gamma` function.
It is defined for :math:`x \ge 0` and :math:`a > 0`.
Args:
x: arraylike, value at which to evaluate the PDF
a: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
- :func:`jax.scipy.stats.gamma.cdf`
- :func:`jax.scipy.stats.gamma.sf`
- :func:`jax.scipy.stats.gamma.logcdf`
- :func:`jax.scipy.stats.gamma.logpdf`
- :func:`jax.scipy.stats.gamma.logsf`
"""
return lax.exp(logpdf(x, a, loc, scale))
def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Gamma cumulative distribution function.
JAX implementation of :obj:`scipy.stats.gamma` ``cdf``.
The cdf is defined as
.. math::
f_{cdf}(x, a) = \int_{-\infty}^x f_{pdf}(y, a)\mathrm{d}y
where :math:`f_{pdf}` is the probability density function,
:func:`jax.scipy.stats.gamma.pdf`.
Args:
x: arraylike, value at which to evaluate the CDF
a: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of cdf values.
See Also:
- :func:`jax.scipy.stats.gamma.pdf`
- :func:`jax.scipy.stats.gamma.sf`
- :func:`jax.scipy.stats.gamma.logcdf`
- :func:`jax.scipy.stats.gamma.logpdf`
- :func:`jax.scipy.stats.gamma.logsf`
"""
x, a, loc, scale = promote_args_inexact("gamma.cdf", x, a, loc, scale)
return gammainc(
a,
lax.clamp(
_lax_const(x, 0),
lax.div(lax.sub(x, loc), scale),
_lax_const(x, np.inf),
)
)
def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Gamma log cumulative distribution function.
JAX implementation of :obj:`scipy.stats.gamma` ``logcdf``.
The cdf is defined as
.. math::
f_{cdf}(x, a) = \int_{-\infty}^x f_{pdf}(y, a)\mathrm{d}y
where :math:`f_{pdf}` is the probability density function,
:func:`jax.scipy.stats.gamma.pdf`.
Args:
x: arraylike, value at which to evaluate the CDF
a: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logcdf values.
See Also:
- :func:`jax.scipy.stats.gamma.cdf`
- :func:`jax.scipy.stats.gamma.pdf`
- :func:`jax.scipy.stats.gamma.sf`
- :func:`jax.scipy.stats.gamma.logpdf`
- :func:`jax.scipy.stats.gamma.logsf`
"""
return lax.log(cdf(x, a, loc, scale))
def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Gamma survival function.
JAX implementation of :obj:`scipy.stats.gamma` ``sf``.
The survival function is defined as
.. math::
f_{sf}(x, k) = 1 - f_{cdf}(x, k)
where :math:`f_{cdf}(x, k)` is the cumulative distribution function,
:func:`jax.scipy.stats.gamma.cdf`.
Args:
x: arraylike, value at which to evaluate the SF
a: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of sf values.
See Also:
- :func:`jax.scipy.stats.gamma.cdf`
- :func:`jax.scipy.stats.gamma.pdf`
- :func:`jax.scipy.stats.gamma.logcdf`
- :func:`jax.scipy.stats.gamma.logpdf`
- :func:`jax.scipy.stats.gamma.logsf`
"""
x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale)
y = lax.div(lax.sub(x, loc), scale)
return jnp.where(lax.lt(y, _lax_const(y, 0)), 1, gammaincc(a, y))
def logsf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Gamma log survival function.
JAX implementation of :obj:`scipy.stats.gamma` ``logsf``.
The survival function is defined as
.. math::
f_{sf}(x, k) = 1 - f_{cdf}(x, k)
where :math:`f_{cdf}(x, k)` is the cumulative distribution function,
:func:`jax.scipy.stats.gamma.cdf`.
Args:
x: arraylike, value at which to evaluate the SF
a: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logsf values.
See Also:
- :func:`jax.scipy.stats.gamma.cdf`
- :func:`jax.scipy.stats.gamma.pdf`
- :func:`jax.scipy.stats.gamma.sf`
- :func:`jax.scipy.stats.gamma.logcdf`
- :func:`jax.scipy.stats.gamma.logpdf`
"""
return lax.log(sf(x, a, loc, scale))
@@ -0,0 +1,103 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src import lax
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, beta: ArrayLike) -> Array:
r"""Generalized normal log probability distribution function.
JAX implementation of :obj:`scipy.stats.gennorm` ``logpdf``.
The generalized normal probability distribution function is defined as
.. math::
f(x, \beta) = \frac{\beta}{2\Gamma(1/\beta)}\exp(-|x|^\beta)
where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function, and
:math:`\beta > 0`.
Args:
x: arraylike, value at which to evaluate the PDF
beta: arraylike, distribution shape parameter
Returns:
array of logpdf values.
See Also:
- :func:`jax.scipy.stats.gennorm.cdf`
- :func:`jax.scipy.stats.gennorm.pdf`
"""
x, beta = promote_args_inexact("gennorm.logpdf", x, beta)
return lax.log(.5 * beta) - lax.lgamma(1/beta) - lax.abs(x)**beta
def cdf(x: ArrayLike, beta: ArrayLike) -> Array:
r"""Generalized normal cumulative distribution function.
JAX implementation of :obj:`scipy.stats.gennorm` ``cdf``.
The cdf is defined as
.. math::
f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y
where :math:`f_{pdf}` is the probability density function,
:func:`jax.scipy.stats.gennorm.pdf`.
Args:
x: arraylike, value at which to evaluate the CDF
beta: arraylike, distribution shape parameter
Returns:
array of cdf values.
See Also:
- :func:`jax.scipy.stats.gennorm.pdf`
- :func:`jax.scipy.stats.gennorm.logpdf`
"""
x, beta = promote_args_inexact("gennorm.cdf", x, beta)
return .5 * (1 + lax.sign(x) * lax.igamma(1/beta, lax.abs(x)**beta))
def pdf(x: ArrayLike, beta: ArrayLike) -> Array:
r"""Generalized normal probability distribution function.
JAX implementation of :obj:`scipy.stats.gennorm` ``pdf``.
The generalized normal probability distribution function is defined as
.. math::
f(x, \beta) = \frac{\beta}{2\Gamma(1/\beta)}\exp(-|x|^\beta)
where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function, and
:math:`\beta > 0`.
Args:
x: arraylike, value at which to evaluate the PDF
beta: arraylike, distribution shape parameter
Returns:
array of pdf values.
See Also:
- :func:`jax.scipy.stats.gennorm.cdf`
- :func:`jax.scipy.stats.gennorm.logpdf`
"""
return lax.exp(logpdf(x, beta))
@@ -0,0 +1,81 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.scipy.special import xlog1py
from jax._src.typing import Array, ArrayLike
def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
r"""Geometric log probability mass function.
JAX implementation of :obj:`scipy.stats.geom` ``logpmf``.
The Geometric probability mass function is given by
.. math::
f(k) = (1 - p)^{k-1}p
for :math:`k\ge 1` and :math:`0 \le p \le 1`.
Args:
k: arraylike, value at which to evaluate the PMF
p: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
Returns:
array of logpmf values.
See Also:
:func:`jax.scipy.stats.geom.pmf`
"""
k, p, loc = promote_args_inexact("geom.logpmf", k, p, loc)
zero = _lax_const(k, 0)
one = _lax_const(k, 1)
x = lax.sub(k, loc)
log_probs = xlog1py(lax.sub(x, one), -p) + lax.log(p)
return jnp.where(lax.le(x, zero), -np.inf, log_probs)
def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
r"""Geometric probability mass function.
JAX implementation of :obj:`scipy.stats.geom` ``pmf``.
The Geometric probability mass function is given by
.. math::
f(k) = (1 - p)^{k-1}p
for :math:`k\ge 1` and :math:`0 \le p \le 1`.
Args:
k: arraylike, value at which to evaluate the PMF
p: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
Returns:
array of pmf values.
See Also:
:func:`jax.scipy.stats.geom.logpmf`
"""
return jnp.exp(logpmf(k, p, loc))
@@ -0,0 +1,256 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax._src.scipy.special import xlogy, xlog1py
def logpdf(x: ArrayLike,
loc: ArrayLike = 0,
scale: ArrayLike = 1) -> Array:
r"""
Gumbel Distribution (Left Skewed) log probability distribution function.
JAX implementation of :obj:`scipy.stats.gumbel_l` ``logpdf``.
.. math::
f_{pdf}(x; \mu, \beta) = \frac{1}{\beta} \exp\left( \frac{x - \mu}{\beta} - \exp\left( \frac{x - \mu}{\beta} \right) \right)
Args:
x: ArrayLike, value at which to evaluate log(pdf)
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
Returns:
array of logpdf values
See Also:
- :func:`jax.scipy.stats.gumbel_l.pdf`
- :func:`jax.scipy.stats.gumbel_l.logcdf`
- :func:`jax.scipy.stats.gumbel_l.cdf`
- :func:`jax.scipy.stats.gumbel_l.ppf`
- :func:`jax.scipy.stats.gumbel_l.logsf`
- :func:`jax.scipy.stats.gumbel_l.sf`
"""
x, loc, scale = promote_args_inexact("gumbel_l.logpdf", x, loc, scale)
ok = lax.gt(scale, _lax_const(scale, 0))
# logpdf = -log(scale) + z - exp(z)
z = lax.div(lax.sub(x, loc), scale)
neg_log_scale = xlogy(-1, scale)
t2 = lax.sub(z, lax.exp(z))
log_pdf = lax.add(neg_log_scale, t2)
return jnp.where(ok, log_pdf, np.nan)
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""
Gumbel Distribution (Left Skewed) probability distribution function.
JAX implementation of :obj:`scipy.stats.gumbel_l` ``pdf``.
.. math::
f_{pdf}(x; \mu, \beta) = \frac{1}{\beta} \exp\left( \frac{x - \mu}{\beta} - \exp\left( \frac{x - \mu}{\beta} \right) \right)
Args:
x: ArrayLike, value at which to evaluate pdf
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
Returns:
array of pdf values
See Also:
- :func:`jax.scipy.stats.gumbel_l.logpdf`
- :func:`jax.scipy.stats.gumbel_l.logcdf`
- :func:`jax.scipy.stats.gumbel_l.cdf`
- :func:`jax.scipy.stats.gumbel_l.ppf`
- :func:`jax.scipy.stats.gumbel_l.logsf`
- :func:`jax.scipy.stats.gumbel_l.sf`
"""
return lax.exp(logpdf(x, loc, scale))
def logcdf(x: ArrayLike,
loc: ArrayLike = 0,
scale: ArrayLike = 1) -> Array:
r"""
Gumbel Distribution (Left Skewed) log cumulative density function.
JAX implementation of :obj:`scipy.stats.gumbel_l` ``logcdf``.
.. math::
f_{cdf}(x; \mu, \beta) = 1 - \exp\left( -\exp\left( \frac{x - \mu}{\beta} \right) \right)
Args:
x: ArrayLike, value at which to evaluate log(cdf)
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
Returns:
array of logcdf values
See Also:
- :func:`jax.scipy.stats.gumbel_l.logpdf`
- :func:`jax.scipy.stats.gumbel_l.pdf`
- :func:`jax.scipy.stats.gumbel_l.cdf`
- :func:`jax.scipy.stats.gumbel_l.ppf`
- :func:`jax.scipy.stats.gumbel_l.logsf`
- :func:`jax.scipy.stats.gumbel_l.sf`
"""
x, loc, scale = promote_args_inexact("gumbel_l.logcdf", x, loc, scale)
ok = lax.gt(scale, _lax_const(scale, 0))
z = lax.div(lax.sub(x, loc), scale)
neg_exp_z = lax.neg(lax.exp(z))
# xlog1p fails here, that's why log1p is used here
# even log1p fails for some cases when using float64 mode
# so we're using this formula which is stable
log_cdf = lax.log(-lax.expm1(neg_exp_z))
return jnp.where(ok, log_cdf, np.nan)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""
Gumbel Distribution (Left Skewed) cumulative density function.
JAX implementation of :obj:`scipy.stats.gumbel_l` ``cdf``.
.. math::
f_{cdf}(x; \mu, \beta) = 1 - \exp\left( -\exp\left( \frac{x - \mu}{\beta} \right) \right)
Args:
x: ArrayLike, value at which to evaluate cdf
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
Returns:
array of cdf values
See Also:
- :func:`jax.scipy.stats.gumbel_l.logpdf`
- :func:`jax.scipy.stats.gumbel_l.pdf`
- :func:`jax.scipy.stats.gumbel_l.logcdf`
- :func:`jax.scipy.stats.gumbel_l.ppf`
- :func:`jax.scipy.stats.gumbel_l.logsf`
- :func:`jax.scipy.stats.gumbel_l.sf`
"""
return lax.exp(logcdf(x, loc, scale))
def ppf(p: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""
Gumbel Distribution (Left Skewed) percent point function (inverse of CDF)
JAX implementation of :obj:`scipy.stats.gumbel_l` ``ppf``.
.. math::
F_{ppf}}(p; \mu, \beta) = \mu + \beta \log\left( -\log(1 - p) \right)
Args:
p: ArrayLike, probability value (quantile) at which to evaluate ppf
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
Returns:
array of ppf values
See Also:
- :func:`jax.scipy.stats.gumbel_l.logpdf`
- :func:`jax.scipy.stats.gumbel_l.pdf`
- :func:`jax.scipy.stats.gumbel_l.logcdf`
- :func:`jax.scipy.stats.gumbel_l.cdf`
- :func:`jax.scipy.stats.gumbel_l.logsf`
- :func:`jax.scipy.stats.gumbel_l.sf`
"""
p, loc, scale = promote_args_inexact("gumbel_l.ppf", p, loc, scale)
ok = lax.bitwise_and(lax.gt(p, _lax_const(p, 0)),
lax.lt(p, _lax_const(p, 1)))
# quantile = loc + (scale)*log(-log(1 - p))
t1 = xlog1py(-1, lax.neg(p))
# xlogp failed here too, that's why log is used
t = lax.mul(scale, lax.log(t1))
quantile = lax.add(loc, t)
return jnp.where(ok, quantile, np.nan)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""
Gumbel Distribution (Left Skewed) survival function.
JAX implementation of :obj:`scipy.stats.gumbel_l` ``sf``.
.. math::
f_{sf}(x; \mu, \beta) = 1 - f_{cdf}(x, \mu, \beta)
Args:
x: ArrayLike, value at which to evaluate survival function
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
Returns:
array of sf values (1 - cdf)
See Also:
- :func:`jax.scipy.stats.gumbel_l.logpdf`
- :func:`jax.scipy.stats.gumbel_l.pdf`
- :func:`jax.scipy.stats.gumbel_l.logcdf`
- :func:`jax.scipy.stats.gumbel_l.cdf`
- :func:`jax.scipy.stats.gumbel_l.logsf`
"""
return jnp.exp(logsf(x, loc, scale))
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""
Gumbel Distribution (Left Skewed) log survival function.
JAX implementation of :obj:`scipy.stats.gumbel_l` ``logsf``.
.. math::
f_{sf}(x; \mu, \beta) = 1 - f_{cdf}(x, \mu, \beta)
Args:
x: ArrayLike, value at which to evaluate log survival function
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
Returns:
array of logsf values
See Also:
- :func:`jax.scipy.stats.gumbel_l.logpdf`
- :func:`jax.scipy.stats.gumbel_l.pdf`
- :func:`jax.scipy.stats.gumbel_l.logcdf`
- :func:`jax.scipy.stats.gumbel_l.cdf`
- :func:`jax.scipy.stats.gumbel_l.sf`
"""
x, loc, scale = promote_args_inexact("gumbel_l.logsf", x, loc, scale)
ok = lax.gt(scale, _lax_const(scale, 0))
# logsf = -exp(z)
z = lax.div(lax.sub(x, loc), scale)
log_sf = lax.neg(lax.exp(z))
return jnp.where(ok, log_sf, np.nan)
@@ -0,0 +1,257 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax._src.scipy.special import xlogy
from jax._src.nn.functions import log1mexp
def logpdf(x: ArrayLike,
loc: ArrayLike = 0,
scale: ArrayLike = 1) -> Array:
r"""
Gumbel Distribution (Right Skewed) log probability distribution function.
JAX implementation of :obj:`scipy.stats.gumbel_l` ``logpdf``.
.. math::
f_{pdf}(x; \mu, \beta) = \frac{1}{\beta} \exp\left( -\frac{x - \mu}{\beta} - \exp\left( -\frac{x - \mu}{\beta} \right) \right)
Args:
x: ArrayLike, value at which to evaluate log(pdf)
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
Returns:
array of logpdf values
See Also:
- :func:`jax.scipy.stats.gumbel_r.pdf`
- :func:`jax.scipy.stats.gumbel_r.logcdf`
- :func:`jax.scipy.stats.gumbel_r.cdf`
- :func:`jax.scipy.stats.gumbel_r.ppf`
- :func:`jax.scipy.stats.gumbel_r.sf`
- :func:`jax.scipy.stats.gumbel_r.logsf`
"""
x, loc, scale = promote_args_inexact("gumbel_r.logpdf", x, loc, scale)
ok = lax.gt(scale, _lax_const(scale, 0))
z = lax.div(lax.sub(x, loc), scale)
# logpdf = -log(beta) - (z + exp(-z))
neg_log_scale = xlogy(-1, scale)
t2 = lax.neg(lax.add(z, lax.exp(lax.neg(z))))
log_pdf = lax.add(neg_log_scale, t2)
return jnp.where(ok, log_pdf, np.nan)
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""
Gumbel Distribution (Right Skewed) probability distribution function.
JAX implementation of :obj:`scipy.stats.gumbel_r` ``pdf``.
.. math::
f_{pdf}(x; \mu, \beta) = \frac{1}{\beta} \exp\left( -\frac{x - \mu}{\beta} - \exp\left( -\frac{x - \mu}{\beta} \right) \right)
Args:
x: ArrayLike, value at which to evaluate pdf
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
Returns:
array of pdf values
See Also:
- :func:`jax.scipy.stats.gumbel_r.logpdf`
- :func:`jax.scipy.stats.gumbel_r.logcdf`
- :func:`jax.scipy.stats.gumbel_r.cdf`
- :func:`jax.scipy.stats.gumbel_r.ppf`
- :func:`jax.scipy.stats.gumbel_r.sf`
- :func:`jax.scipy.stats.gumbel_r.logsf`
"""
return lax.exp(logpdf(x, loc, scale))
def logcdf(x: ArrayLike,
loc: ArrayLike = 0,
scale: ArrayLike = 1) -> Array:
r"""
Gumbel Distribution (Right Skewed) log cumulative density function.
JAX implementation of :obj:`scipy.stats.gumbel_r` ``logcdf``.
.. math::
f_{cdf}(x; \mu, \beta) = \exp\left( -\exp\left( -\frac{x - \mu}{\beta} \right) \right)
Args:
x: ArrayLike, value at which to evaluate log(cdf)
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
Returns:
array of logcdf values
See Also:
- :func:`jax.scipy.stats.gumbel_r.logpdf`
- :func:`jax.scipy.stats.gumbel_r.pdf`
- :func:`jax.scipy.stats.gumbel_r.cdf`
- :func:`jax.scipy.stats.gumbel_r.ppf`
- :func:`jax.scipy.stats.gumbel_r.sf`
- :func:`jax.scipy.stats.gumbel_r.logsf`
"""
x, loc, scale = promote_args_inexact("gumbel_r.logcdf", x, loc, scale)
ok = lax.gt(scale, _lax_const(scale, 0))
z = lax.div(lax.sub(x, loc), scale)
# log cdf = -exp(-z)
log_cdf = lax.neg(lax.exp(lax.neg(z)))
return jnp.where(ok, log_cdf, np.nan)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""
Gumbel Distribution (Right Skewed) cumulative density function.
JAX implementation of :obj:`scipy.stats.gumbel_r` ``cdf``.
.. math::
f_{cdf}(x; \mu, \beta) = \exp\left( -\exp\left( -\frac{x - \mu}{\beta} \right) \right)
Args:
x: ArrayLike, value at which to evaluate cdf
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
Returns:
array of cdf values
See Also:
- :func:`jax.scipy.stats.gumbel_r.logpdf`
- :func:`jax.scipy.stats.gumbel_r.pdf`
- :func:`jax.scipy.stats.gumbel_r.logcdf`
- :func:`jax.scipy.stats.gumbel_r.ppf`
- :func:`jax.scipy.stats.gumbel_r.sf`
- :func:`jax.scipy.stats.gumbel_r.logsf`
"""
return lax.exp(logcdf(x, loc, scale))
def ppf(p: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""
Gumbel Distribution (Right Skewed) percent point function.
JAX implementation of :obj:`scipy.stats.gumbel_r` ``ppf``.
.. math::
F(p; \mu, \beta) = \mu - \beta \log\left( -\log(p) \right)
Args:
p: ArrayLike, probability value (quantile) at which to evaluate ppf
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
Returns:
array of ppf values
See Also:
- :func:`jax.scipy.stats.gumbel_r.logpdf`
- :func:`jax.scipy.stats.gumbel_r.pdf`
- :func:`jax.scipy.stats.gumbel_r.logcdf`
- :func:`jax.scipy.stats.gumbel_r.cdf`
- :func:`jax.scipy.stats.gumbel_r.sf`
- :func:`jax.scipy.stats.gumbel_r.logsf`
"""
p, loc, scale = promote_args_inexact("gumbel_r.ppf", p, loc, scale)
# 0 < p < 1
ok = lax.bitwise_and(lax.gt(p, _lax_const(p, 0)),
lax.lt(p, _lax_const(p, 1)))
# quantile = loc - (scale)*log(-log(p))
t1 = xlogy(-1, p)
t = lax.mul(scale, lax.log(t1))
quantile = lax.sub(loc, t)
return jnp.where(ok, quantile, np.nan)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""
Gumbel Distribution (Right Skewed) survival function.
JAX implementation of :obj:`scipy.stats.gumbel_r` ``sf``.
.. math::
f_{sf}(x; \mu, \beta) = 1 - F_{cdf}(x; \mu, \beta)
Args:
x: ArrayLike, value at which to evaluate survival function
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
Returns:
array of sf values (1 - cdf)
See Also:
- :func:`jax.scipy.stats.gumbel_r.logpdf`
- :func:`jax.scipy.stats.gumbel_r.pdf`
- :func:`jax.scipy.stats.gumbel_r.logcdf`
- :func:`jax.scipy.stats.gumbel_r.cdf`
- :func:`jax.scipy.stats.gumbel_r.logsf`
"""
x, loc, scale = promote_args_inexact("gumbel_r.sf", x, loc, scale)
ok = lax.gt(scale, _lax_const(scale, 0))
# sf = 1 - exp(-exp(-z))
neg_z = lax.div(lax.sub(loc, x), scale)
t1 = lax.exp(lax.neg(lax.exp(neg_z)))
_sf = lax.sub(_lax_const(x, 1), t1)
return jnp.where(ok, _sf, np.nan)
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""
Gumbel Distribution (Right Skewed) log survival function.
JAX implementation of :obj:`scipy.stats.gumbel_r` ``logsf``.
Args:
x: ArrayLike, value at which to evaluate log survival function
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
Returns:
array of logsf values
See Also:
- :func:`jax.scipy.stats.gumbel_r.logpdf`
- :func:`jax.scipy.stats.gumbel_r.pdf`
- :func:`jax.scipy.stats.gumbel_r.logcdf`
- :func:`jax.scipy.stats.gumbel_r.cdf`
- :func:`jax.scipy.stats.gumbel_r.sf`
"""
x, loc, scale = promote_args_inexact("gumbel_r.logsf", x, loc, scale)
ok = lax.gt(scale, _lax_const(scale, 0))
# logsf = log(1 - exp(-exp(-z)))
neg_z = lax.div(lax.sub(loc, x), scale)
log_sf = log1mexp(lax.exp(neg_z))
return jnp.where(ok, log_sf, np.nan)
@@ -0,0 +1,284 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, cast
import numpy as np
from jax._src import api
from jax._src import dtypes
from jax._src import lax
from jax._src import numpy as jnp
from jax._src import random
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact
from jax._src.scipy import linalg, special
from jax._src.tree_util import register_pytree_node_class
from jax._src.typing import Array
BwMethod = None | str | Array | Callable[[Any], Array]
@register_pytree_node_class
@dataclass(frozen=True, init=False)
class gaussian_kde:
"""Gaussian Kernel Density Estimator
JAX implementation of :class:`scipy.stats.gaussian_kde`.
Parameters:
dataset: arraylike, real-valued. Data from which to estimate the distribution.
If 1D, shape is (n_data,). If 2D, shape is (n_dimensions, n_data).
bw_method: string, scalar, or callable. Either "scott", "silverman", a scalar
value, or a callable function which takes ``self`` as a parameter.
weights: arraylike, optional. Weights of the same shape as the dataset.
"""
neff: Any
dataset: Any
weights: Any
covariance: Any
inv_cov: Any
def __init__(self, dataset, bw_method: BwMethod = None, weights=None):
check_arraylike("gaussian_kde", dataset)
dataset = jnp.atleast_2d(dataset)
if dtypes.issubdtype(lax.dtype(dataset), np.complexfloating):
raise NotImplementedError("gaussian_kde does not support complex data")
if not dataset.size > 1:
raise ValueError("`dataset` input should have multiple elements.")
d, n = dataset.shape
if weights is not None:
check_arraylike("gaussian_kde", weights)
dataset, weights = promote_dtypes_inexact(dataset, weights)
weights = jnp.atleast_1d(weights)
weights /= jnp.sum(weights)
if weights.ndim != 1:
raise ValueError("`weights` input should be one-dimensional.")
if len(weights) != n:
raise ValueError("`weights` input should be of length n")
else:
dataset, = promote_dtypes_inexact(dataset)
weights = jnp.full(n, 1.0 / n, dtype=dataset.dtype)
self._setattr("dataset", dataset)
self._setattr("weights", weights)
neff = self._setattr("neff", 1 / jnp.sum(weights**2))
bw_method = "scott" if bw_method is None else bw_method
if bw_method == "scott":
factor = jnp.power(neff, -1. / (d + 4))
elif bw_method == "silverman":
factor = jnp.power(neff * (d + 2) / 4.0, -1. / (d + 4))
elif jnp.isscalar(bw_method) and not isinstance(bw_method, str):
factor = cast(Array, bw_method)
elif callable(bw_method):
factor = bw_method(self)
else:
raise ValueError(
"`bw_method` should be 'scott', 'silverman', a scalar, or a callable."
)
data_covariance = jnp.atleast_2d(
jnp.cov(dataset, rowvar=1, bias=False, aweights=weights))
data_inv_cov = jnp.linalg.inv(data_covariance)
covariance = data_covariance * factor**2
inv_cov = data_inv_cov / factor**2
self._setattr("covariance", covariance)
self._setattr("inv_cov", inv_cov)
def _setattr(self, name, value):
# Frozen dataclasses don't support setting attributes so we have to
# overload that operation here as they do in the dataclass implementation
object.__setattr__(self, name, value)
return value
def tree_flatten(self):
return ((self.neff, self.dataset, self.weights, self.covariance,
self.inv_cov), None)
@classmethod
def tree_unflatten(cls, aux_data, children):
del aux_data
kde = cls.__new__(cls)
kde._setattr("neff", children[0])
kde._setattr("dataset", children[1])
kde._setattr("weights", children[2])
kde._setattr("covariance", children[3])
kde._setattr("inv_cov", children[4])
return kde
@property
def d(self):
return self.dataset.shape[0]
@property
def n(self):
return self.dataset.shape[1]
def evaluate(self, points):
"""Evaluate the Gaussian KDE on the given points."""
check_arraylike("evaluate", points)
points = self._reshape_points(points)
result = _gaussian_kernel_eval(False, self.dataset.T, self.weights[:, None],
points.T, self.inv_cov)
return result[:, 0]
def __call__(self, points):
return self.evaluate(points)
def integrate_gaussian(self, mean, cov):
"""Integrate the distribution weighted by a Gaussian."""
mean = jnp.atleast_1d(jnp.squeeze(mean))
cov = jnp.atleast_2d(cov)
if mean.shape != (self.d,):
raise ValueError(f"mean does not have dimension {self.d}")
if cov.shape != (self.d, self.d):
raise ValueError(f"covariance does not have dimension {self.d}")
chol = linalg.cho_factor(self.covariance + cov)
norm = jnp.sqrt(2 * np.pi)**self.d * jnp.prod(jnp.diag(chol[0]))
norm = 1.0 / norm
return _gaussian_kernel_convolve(chol, norm, self.dataset, self.weights,
mean)
@api.jit
def integrate_box_1d(self, low, high):
"""Integrate the distribution over the given limits."""
if self.d != 1:
raise ValueError("integrate_box_1d() only handles 1D pdfs")
if np.ndim(low) != 0 or np.ndim(high) != 0:
raise ValueError(
"the limits of integration in integrate_box_1d must be scalars")
sigma = jnp.squeeze(jnp.sqrt(self.covariance))
low = jnp.squeeze((low - self.dataset) / sigma)
high = jnp.squeeze((high - self.dataset) / sigma)
return jnp.sum(self.weights * (special.ndtr(high) - special.ndtr(low)))
def integrate_kde(self, other):
"""Integrate the product of two Gaussian KDE distributions."""
if other.d != self.d:
raise ValueError("KDEs are not the same dimensionality")
chol = linalg.cho_factor(self.covariance + other.covariance)
norm = jnp.sqrt(2 * np.pi)**self.d * jnp.prod(jnp.diag(chol[0]))
norm = 1.0 / norm
sm, lg = (self, other) if self.n < other.n else (other, self)
result = api.vmap(partial(_gaussian_kernel_convolve, chol, norm, lg.dataset,
lg.weights),
in_axes=1)(sm.dataset)
return jnp.sum(result * sm.weights)
@partial(api.jit, static_argnames=("shape",))
def resample(self, key, shape=()):
r"""Randomly sample a dataset from the estimated pdf
Args:
key: a PRNG key used as the random key.
shape: optional, a tuple of nonnegative integers specifying the result
batch shape; that is, the prefix of the result shape excluding the last
axis.
Returns:
The resampled dataset as an array with shape `(d,) + shape`.
"""
ind_key, eps_key = random.split(key)
ind = random.choice(ind_key, self.n, shape=shape, p=self.weights)
eps = random.multivariate_normal(eps_key,
jnp.zeros(self.d, self.covariance.dtype),
self.covariance,
shape=shape,
dtype=self.dataset.dtype).T
return self.dataset[:, ind] + eps
def pdf(self, x):
"""Probability density function"""
return self.evaluate(x)
def logpdf(self, x):
"""Log probability density function"""
check_arraylike("logpdf", x)
x = self._reshape_points(x)
result = _gaussian_kernel_eval(True, self.dataset.T, self.weights[:, None],
x.T, self.inv_cov)
return result[:, 0]
def integrate_box(self, low_bounds, high_bounds, maxpts=None):
"""This method is not implemented in the JAX interface."""
del low_bounds, high_bounds, maxpts
raise NotImplementedError(
"only 1D box integrations are supported; use `integrate_box_1d`")
def set_bandwidth(self, bw_method=None):
"""This method is not implemented in the JAX interface."""
del bw_method
raise NotImplementedError(
"dynamically changing the bandwidth method is not supported")
def _reshape_points(self, points):
if dtypes.issubdtype(lax.dtype(points), np.complexfloating):
raise NotImplementedError(
"gaussian_kde does not support complex coordinates")
points = jnp.atleast_2d(points)
d, m = points.shape
if d != self.d:
if d == 1 and m == self.d:
points = jnp.reshape(points, (self.d, 1))
else:
raise ValueError(
"points have dimension {}, dataset has dimension {}".format(
d, self.d))
return points
def _gaussian_kernel_convolve(chol, norm, target, weights, mean):
diff = target - mean[:, None]
alpha = linalg.cho_solve(chol, diff)
arg = 0.5 * jnp.sum(diff * alpha, axis=0)
return norm * jnp.sum(jnp.exp(-arg) * weights)
@api.jit(static_argnums=0)
def _gaussian_kernel_eval(in_log, points, values, xi, precision):
points, values, xi, precision = promote_dtypes_inexact(
points, values, xi, precision)
d = points.shape[1]
if xi.shape[1] != d:
raise ValueError("points and xi must have same trailing dim")
if precision.shape != (d, d):
raise ValueError("precision matrix must match data dims")
whitening = linalg.cholesky(precision, lower=True)
points = jnp.dot(points, whitening)
xi = jnp.dot(xi, whitening)
log_norm = jnp.sum(jnp.log(
jnp.diag(whitening))) - 0.5 * d * jnp.log(2 * np.pi)
def kernel(x_test, x_train, y_train):
arg = log_norm - 0.5 * jnp.sum(jnp.square(x_train - x_test))
if in_log:
return jnp.log(y_train) + arg
else:
return y_train * jnp.exp(arg)
reduce = special.logsumexp if in_log else jnp.sum
reduced_kernel = lambda x: reduce(api.vmap(kernel, in_axes=(None, 0, 0))
(x, points, values),
axis=0)
mapped_kernel = api.vmap(reduced_kernel)
return mapped_kernel(xi)
@@ -0,0 +1,109 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Laplace log probability distribution function.
JAX implementation of :obj:`scipy.stats.laplace` ``logpdf``.
The Laplace probability distribution function is given by
.. math::
f(x) = \frac{1}{2} e^{-|x|}
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logpdf values.
See Also:
- :func:`jax.scipy.stats.laplace.cdf`
- :func:`jax.scipy.stats.laplace.pdf`
"""
x, loc, scale = promote_args_inexact("laplace.logpdf", x, loc, scale)
two = _lax_const(x, 2)
linear_term = lax.div(lax.abs(lax.sub(x, loc)), scale)
return lax.neg(lax.add(linear_term, lax.log(lax.mul(two, scale))))
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Laplace probability distribution function.
JAX implementation of :obj:`scipy.stats.laplace` ``pdf``.
The Laplace probability distribution function is given by
.. math::
f(x) = \frac{1}{2} e^{-|x|}
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
- :func:`jax.scipy.stats.laplace.cdf`
- :func:`jax.scipy.stats.laplace.logpdf`
"""
return lax.exp(logpdf(x, loc, scale))
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Laplace cumulative distribution function.
JAX implementation of :obj:`scipy.stats.laplace` ``cdf``.
The cdf is defined as
.. math::
f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y
where :math:`f_{pdf}` is the probability density function,
:func:`jax.scipy.stats.laplace.pdf`.
Args:
x: arraylike, value at which to evaluate the CDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of cdf values.
See Also:
- :func:`jax.scipy.stats.laplace.pdf`
- :func:`jax.scipy.stats.laplace.logpdf`
"""
x, loc, scale = promote_args_inexact("laplace.cdf", x, loc, scale)
half = _lax_const(x, 0.5)
one = _lax_const(x, 1)
zero = _lax_const(x, 0)
diff = lax.div(lax.sub(x, loc), scale)
return lax.select(lax.le(diff, zero),
lax.mul(half, lax.exp(diff)),
lax.sub(one, lax.mul(half, lax.exp(lax.neg(diff)))))
@@ -0,0 +1,204 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.scipy.special import expit, logit
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Logistic log probability distribution function.
JAX implementation of :obj:`scipy.stats.logistic` ``logpdf``.
The logistic probability distribution function is given by
.. math::
f(x) = \frac{e^{-x}}{(1 + e^{-x})^2}
Args:
x: arraylike, value at which to evaluate the PDF
a: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logpdf values.
See Also:
- :func:`jax.scipy.stats.logistic.cdf`
- :func:`jax.scipy.stats.logistic.pdf`
- :func:`jax.scipy.stats.logistic.sf`
- :func:`jax.scipy.stats.logistic.isf`
- :func:`jax.scipy.stats.logistic.ppf`
"""
x, loc, scale = promote_args_inexact("logistic.logpdf", x, loc, scale)
x = lax.div(lax.sub(x, loc), scale)
two = _lax_const(x, 2)
half_x = lax.div(x, two)
return lax.sub(lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))), lax.log(scale))
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Logistic probability distribution function.
JAX implementation of :obj:`scipy.stats.logistic` ``pdf``.
The logistic probability distribution function is given by
.. math::
f(x) = \frac{e^{-x}}{(1 + e^{-x})^2}
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
- :func:`jax.scipy.stats.logistic.cdf`
- :func:`jax.scipy.stats.logistic.sf`
- :func:`jax.scipy.stats.logistic.isf`
- :func:`jax.scipy.stats.logistic.logpdf`
- :func:`jax.scipy.stats.logistic.ppf`
"""
return lax.exp(logpdf(x, loc, scale))
def ppf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
"""Logistic distribution percent point function.
JAX implementation of :obj:`scipy.stats.logistic` ``ppf``.
The percent point function is defined as the inverse of the
cumulative distribution function, :func:`jax.scipy.stats.logistic.cdf`.
Args:
x: arraylike, value at which to evaluate the PPF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of ppf values.
See Also:
- :func:`jax.scipy.stats.logistic.cdf`
- :func:`jax.scipy.stats.logistic.pdf`
- :func:`jax.scipy.stats.logistic.sf`
- :func:`jax.scipy.stats.logistic.isf`
- :func:`jax.scipy.stats.logistic.logpdf`
"""
x, loc, scale = promote_args_inexact("logistic.ppf", x, loc, scale)
return lax.add(lax.mul(logit(x), scale), loc)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
"""Logistic distribution survival function.
JAX implementation of :obj:`scipy.stats.logistic` ``sf``
The survival function is defined as
.. math::
f_{sf}(x, k) = 1 - f_{cdf}(x, k)
where :math:`f_{cdf}(x, k)` is the cumulative distribution function,
:func:`jax.scipy.stats.logistic.cdf`.
Args:
x: arraylike, value at which to evaluate the SF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of sf values.
See Also:
- :func:`jax.scipy.stats.logistic.cdf`
- :func:`jax.scipy.stats.logistic.pdf`
- :func:`jax.scipy.stats.logistic.isf`
- :func:`jax.scipy.stats.logistic.logpdf`
- :func:`jax.scipy.stats.logistic.ppf`
"""
x, loc, scale = promote_args_inexact("logistic.sf", x, loc, scale)
return expit(lax.neg(lax.div(lax.sub(x, loc), scale)))
def isf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
"""Logistic distribution inverse survival function.
JAX implementation of :obj:`scipy.stats.logistic` ``isf``.
Returns the inverse of the survival function,
:func:`jax.scipy.stats.logistic.sf`.
Args:
x: arraylike, value at which to evaluate the ISF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of isf values.
See Also:
- :func:`jax.scipy.stats.logistic.cdf`
- :func:`jax.scipy.stats.logistic.pdf`
- :func:`jax.scipy.stats.logistic.sf`
- :func:`jax.scipy.stats.logistic.logpdf`
- :func:`jax.scipy.stats.logistic.ppf`
"""
x, loc, scale = promote_args_inexact("logistic.isf", x, loc, scale)
return lax.add(lax.mul(lax.neg(logit(x)), scale), loc)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Logistic cumulative distribution function.
JAX implementation of :obj:`scipy.stats.logistic` ``cdf``.
The cdf is defined as
.. math::
f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y
where :math:`f_{pdf}` is the probability density function,
:func:`jax.scipy.stats.logistic.pdf`.
Args:
x: arraylike, value at which to evaluate the CDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of cdf values.
See Also:
- :func:`jax.scipy.stats.logistic.pdf`
- :func:`jax.scipy.stats.logistic.sf`
- :func:`jax.scipy.stats.logistic.isf`
- :func:`jax.scipy.stats.logistic.logpdf`
- :func:`jax.scipy.stats.logistic.ppf`
"""
x, loc, scale = promote_args_inexact("logistic.cdf", x, loc, scale)
return expit(lax.div(lax.sub(x, loc), scale))
@@ -0,0 +1,83 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import dtypes
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.numpy.util import promote_args_inexact, promote_args_numeric
from jax._src.scipy.special import gammaln, xlogy
from jax._src.typing import Array, ArrayLike
def logpmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array:
r"""Multinomial log probability mass function.
JAX implementation of :obj:`scipy.stats.multinomial` ``logpdf``.
The multinomial probability distribution is given by
.. math::
f(x, n, p) = n! \prod_{i=1}^k \frac{p_i^{x_i}}{x_i!}
with :math:`n = \sum_i x_i`.
Args:
x: arraylike, value at which to evaluate the PMF
n: arraylike, distribution shape parameter
p: arraylike, distribution shape parameter
Returns:
array of logpmf values.
See Also:
:func:`jax.scipy.stats.multinomial.pmf`
"""
p, = promote_args_inexact("multinomial.logpmf", p)
x, n = promote_args_numeric("multinomial.logpmf", x, n)
if not dtypes.issubdtype(x.dtype, np.integer):
raise ValueError(f"x and n must be of integer type; got x.dtype={x.dtype}, n.dtype={n.dtype}")
x = x.astype(p.dtype)
n = n.astype(p.dtype)
logprobs = gammaln(n + 1) + jnp.sum(xlogy(x, p) - gammaln(x + 1), axis=-1)
return jnp.where(jnp.equal(jnp.sum(x), n), logprobs, -np.inf)
def pmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array:
r"""Multinomial probability mass function.
JAX implementation of :obj:`scipy.stats.multinomial` ``pmf``.
The multinomial probability distribution is given by
.. math::
f(x, n, p) = n! \prod_{i=1}^k \frac{p_i^{x_i}}{x_i!}
with :math:`n = \sum_i x_i`.
Args:
x: arraylike, value at which to evaluate the PMF
n: arraylike, distribution shape parameter
p: arraylike, distribution shape parameter
Returns:
array of pmf values
See Also:
:func:`jax.scipy.stats.multinomial.logpmf`
"""
return lax.exp(logpmf(x, n, p))
@@ -0,0 +1,103 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.numpy import einsum as jnp_einsum
from jax._src.numpy import vectorize as jnp_vectorize
from jax._src.numpy.util import promote_dtypes_inexact
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike, allow_singular: None = None) -> ArrayLike:
r"""Multivariate normal log probability distribution function.
JAX implementation of :obj:`scipy.stats.multivariate_normal` ``logpdf``.
The multivariate normal PDF is defined as
.. math::
f(x) = \frac{1}{(2\pi)^k\det\Sigma}\exp\left(-\frac{(x-\mu)^T\Sigma^{-1}(x-\mu)}{2} \right)
where :math:`\mu` is the ``mean``, :math:`\Sigma` is the covariance matrix (``cov``), and
:math:`k` is the rank of :math:`\Sigma`.
Args:
x: arraylike, value at which to evaluate the PDF
mean: arraylike, centroid of distribution
cov: arraylike, covariance matrix of distribution
allow_singular: not supported
Returns:
array of logpdf values.
See Also:
:func:`jax.scipy.stats.multivariate_normal.pdf`
"""
if allow_singular is not None:
raise NotImplementedError("allow_singular argument of multivariate_normal.logpdf")
x, mean, cov = promote_dtypes_inexact(x, mean, cov)
if not mean.shape:
return (-1/2 * jnp.square(x - mean) / cov
- 1/2 * (jnp.log(2*np.pi) + jnp.log(cov)))
else:
n = mean.shape[-1]
if not np.shape(cov):
y = x - mean
return (-1/2 * jnp_einsum.einsum('...i,...i->...', y, y) / cov
- n/2 * (jnp.log(2*np.pi) + jnp.log(cov)))
else:
if cov.ndim < 2 or cov.shape[-2:] != (n, n):
raise ValueError("multivariate_normal.logpdf got incompatible shapes")
L = lax.linalg.cholesky(cov)
y = jnp_vectorize.vectorize(
partial(lax.linalg.triangular_solve, lower=True, transpose_a=True),
signature="(n,n),(n)->(n)"
)(L, x - mean)
return (-1/2 * jnp_einsum.einsum('...i,...i->...', y, y) - n/2 * jnp.log(2*np.pi)
- jnp.log(L.diagonal(axis1=-1, axis2=-2)).sum(-1))
def pdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike) -> Array:
r"""Multivariate normal probability distribution function.
JAX implementation of :obj:`scipy.stats.multivariate_normal` ``pdf``.
The multivariate normal PDF is defined as
.. math::
f(x) = \frac{1}{(2\pi)^k\det\Sigma}\exp\left(-\frac{(x-\mu)^T\Sigma^{-1}(x-\mu)}{2} \right)
where :math:`\mu` is the ``mean``, :math:`\Sigma` is the covariance matrix (``cov``), and
:math:`k` is the rank of :math:`\Sigma`.
Args:
x: arraylike, value at which to evaluate the PDF
mean: arraylike, centroid of distribution
cov: arraylike, covariance matrix of distribution
allow_singular: not supported
Returns:
array of pdf values.
See Also:
:func:`jax.scipy.stats.multivariate_normal.logpdf`
"""
return lax.exp(logpdf(x, mean, cov))
@@ -0,0 +1,86 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.scipy.special import gammaln, xlogy
from jax._src.typing import Array, ArrayLike
def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
r"""Negative-binomial log probability mass function.
JAX implementation of :obj:`scipy.stats.nbinom` ``logpmf``.
The negative-binomial probability mass function is given by
.. math::
f(k) = {{k+n-1} \choose {n-1}}p^n(1-p)^k
for :math:`k \ge 0` and :math:`0 \le p \le 1`.
Args:
k: arraylike, value at which to evaluate the PMF
n: arraylike, distribution shape parameter
p: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
Returns:
array of logpdf values.
See Also:
:func:`jax.scipy.stats.nbinom.pmf`
"""
k, n, p, loc = promote_args_inexact("nbinom.logpmf", k, n, p, loc)
one = _lax_const(k, 1)
y = lax.sub(k, loc)
comb_term = lax.sub(
lax.sub(gammaln(lax.add(y, n)), gammaln(n)), gammaln(lax.add(y, one))
)
log_linear_term = lax.add(xlogy(n, p), xlogy(y, lax.sub(one, p)))
log_probs = lax.add(comb_term, log_linear_term)
return jnp.where(lax.lt(k, loc), -np.inf, log_probs)
def pmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
r"""Negative-binomial probability mass function.
JAX implementation of :obj:`scipy.stats.nbinom` ``pmf``.
The negative-binomial probability mass function is given by
.. math::
f(k) = {{k+n-1} \choose {n-1}}p^n(1-p)^k
for :math:`k \ge 0` and :math:`0 \le p \le 1`.
Args:
k: arraylike, value at which to evaluate the PMF
n: arraylike, distribution shape parameter
p: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
Returns:
array of pmf values.
See Also:
:func:`jax.scipy.stats.nbinom.logpmf`
"""
return lax.exp(logpmf(k, n, p, loc))
@@ -0,0 +1,284 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.scipy import special
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Normal log probability distribution function.
JAX implementation of :obj:`scipy.stats.norm` ``logpdf``.
The normal distribution pdf is given by
.. math::
f(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2}
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logpdf values.
See Also:
- :func:`jax.scipy.stats.norm.cdf`
- :func:`jax.scipy.stats.norm.pdf`
- :func:`jax.scipy.stats.norm.sf`
- :func:`jax.scipy.stats.norm.logcdf`
- :func:`jax.scipy.stats.norm.logsf`
- :func:`jax.scipy.stats.norm.isf`
- :func:`jax.scipy.stats.norm.ppf`
"""
x, loc, scale = promote_args_inexact("norm.logpdf", x, loc, scale)
scale_sqrd = lax.square(scale)
log_normalizer = lax.log(lax.mul(_lax_const(x, 2 * np.pi), scale_sqrd))
quadratic = lax.div(lax.square(lax.sub(x, loc)), scale_sqrd)
return lax.div(lax.add(log_normalizer, quadratic), _lax_const(x, -2))
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Normal probability distribution function.
JAX implementation of :obj:`scipy.stats.norm` ``pdf``.
The normal distribution pdf is given by
.. math::
f(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2}
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
- :func:`jax.scipy.stats.norm.cdf`
- :func:`jax.scipy.stats.norm.sf`
- :func:`jax.scipy.stats.norm.logcdf`
- :func:`jax.scipy.stats.norm.logpdf`
- :func:`jax.scipy.stats.norm.logsf`
- :func:`jax.scipy.stats.norm.isf`
- :func:`jax.scipy.stats.norm.ppf`
"""
return lax.exp(logpdf(x, loc, scale))
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Normal cumulative distribution function.
JAX implementation of :obj:`scipy.stats.norm` ``cdf``.
The cdf is defined as
.. math::
f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y
where :math:`f_{pdf}` is the probability density function,
:func:`jax.scipy.stats.norm.pdf`.
Args:
x: arraylike, value at which to evaluate the CDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of cdf values.
See Also:
- :func:`jax.scipy.stats.norm.pdf`
- :func:`jax.scipy.stats.norm.sf`
- :func:`jax.scipy.stats.norm.logcdf`
- :func:`jax.scipy.stats.norm.logpdf`
- :func:`jax.scipy.stats.norm.logsf`
- :func:`jax.scipy.stats.norm.isf`
- :func:`jax.scipy.stats.norm.ppf`
"""
x, loc, scale = promote_args_inexact("norm.cdf", x, loc, scale)
return special.ndtr(lax.div(lax.sub(x, loc), scale))
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Normal log cumulative distribution function.
JAX implementation of :obj:`scipy.stats.norm` ``logcdf``.
The cdf is defined as
.. math::
f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y
where :math:`f_{pdf}` is the probability density function,
:func:`jax.scipy.stats.norm.pdf`.
Args:
x: arraylike, value at which to evaluate the CDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logcdf values.
See Also:
- :func:`jax.scipy.stats.norm.cdf`
- :func:`jax.scipy.stats.norm.pdf`
- :func:`jax.scipy.stats.norm.sf`
- :func:`jax.scipy.stats.norm.logpdf`
- :func:`jax.scipy.stats.norm.logsf`
- :func:`jax.scipy.stats.norm.isf`
- :func:`jax.scipy.stats.norm.ppf`
"""
x, loc, scale = promote_args_inexact("norm.logcdf", x, loc, scale)
return special.log_ndtr(lax.div(lax.sub(x, loc), scale))
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
"""Normal distribution percent point function.
JAX implementation of :obj:`scipy.stats.norm` ``ppf``.
The percent point function is defined as the inverse of the
cumulative distribution function, :func:`jax.scipy.stats.norm.cdf`.
Args:
q: arraylike, value at which to evaluate the PPF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of ppf values.
See Also:
- :func:`jax.scipy.stats.norm.cdf`
- :func:`jax.scipy.stats.norm.pdf`
- :func:`jax.scipy.stats.norm.sf`
- :func:`jax.scipy.stats.norm.logcdf`
- :func:`jax.scipy.stats.norm.logpdf`
- :func:`jax.scipy.stats.norm.logsf`
- :func:`jax.scipy.stats.norm.isf`
"""
return jnp.asarray(special.ndtri(q) * scale + loc, float)
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
"""Normal distribution log survival function.
JAX implementation of :obj:`scipy.stats.norm` ``logsf``.
The survival function is defined as
.. math::
f_{sf}(x) = 1 - f_{cdf}(x)
where :math:`f_{cdf}(x)` is the cumulative distribution function,
:func:`jax.scipy.stats.norm.cdf`.
Args:
x: arraylike, value at which to evaluate the SF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logsf values.
See Also:
- :func:`jax.scipy.stats.norm.cdf`
- :func:`jax.scipy.stats.norm.pdf`
- :func:`jax.scipy.stats.norm.sf`
- :func:`jax.scipy.stats.norm.logcdf`
- :func:`jax.scipy.stats.norm.logpdf`
- :func:`jax.scipy.stats.norm.isf`
- :func:`jax.scipy.stats.norm.ppf`
"""
x, loc, scale = promote_args_inexact("norm.logsf", x, loc, scale)
return logcdf(-x, -loc, scale)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
"""Normal distribution survival function.
JAX implementation of :obj:`scipy.stats.norm` ``sf``.
The survival function is defined as
.. math::
f_{sf}(x) = 1 - f_{cdf}(x)
where :math:`f_{cdf}(x)` is the cumulative distribution function,
:func:`jax.scipy.stats.norm.cdf`.
Args:
x: arraylike, value at which to evaluate the SF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of sf values.
See Also:
- :func:`jax.scipy.stats.norm.cdf`
- :func:`jax.scipy.stats.norm.pdf`
- :func:`jax.scipy.stats.norm.logcdf`
- :func:`jax.scipy.stats.norm.logpdf`
- :func:`jax.scipy.stats.norm.logsf`
- :func:`jax.scipy.stats.norm.isf`
- :func:`jax.scipy.stats.norm.ppf`
"""
x, loc, scale = promote_args_inexact("norm.sf", x, loc, scale)
return cdf(-x, -loc, scale)
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
"""Normal distribution inverse survival function.
JAX implementation of :obj:`scipy.stats.norm` ``isf``.
Returns the inverse of the survival function,
:func:`jax.scipy.stats.norm.sf`.
Args:
x: arraylike, value at which to evaluate the ISF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of isf values.
See Also:
- :func:`jax.scipy.stats.norm.cdf`
- :func:`jax.scipy.stats.norm.pdf`
- :func:`jax.scipy.stats.norm.sf`
- :func:`jax.scipy.stats.norm.logcdf`
- :func:`jax.scipy.stats.norm.logpdf`
- :func:`jax.scipy.stats.norm.logsf`
- :func:`jax.scipy.stats.norm.ppf`
"""
return ppf(lax.sub(_lax_const(q, 1), q), loc, scale)
@@ -0,0 +1,313 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
def logpdf(
x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
) -> Array:
r"""Pareto log probability distribution function.
JAX implementation of :obj:`scipy.stats.pareto` ``logpdf``.
The Pareto probability density function is given by
.. math::
f(x, b) = \begin{cases}
bx^{-(b+1)} & x \ge 1\\
0 & x < 1
\end{cases}
and is defined for :math:`b > 0`.
Args:
x: arraylike, value at which to evaluate the PDF
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logpdf values.
See Also:
- :func:`jax.scipy.stats.pareto.logcdf`
- :func:`jax.scipy.stats.pareto.logsf`
- :func:`jax.scipy.stats.pareto.cdf`
- :func:`jax.scipy.stats.pareto.pdf`
- :func:`jax.scipy.stats.pareto.ppf`
- :func:`jax.scipy.stats.pareto.sf`
"""
x, b, loc, scale = promote_args_inexact("pareto.logpdf", x, b, loc, scale)
one = _lax_const(x, 1)
scaled_x = lax.div(lax.sub(x, loc), scale)
normalize_term = lax.log(lax.div(scale, b))
log_probs = lax.neg(
lax.add(normalize_term, lax.mul(lax.add(b, one), lax.log(scaled_x)))
)
return jnp.where(lax.lt(x, lax.add(loc, scale)), -np.inf, log_probs)
def pdf(
x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
) -> Array:
r"""Pareto probability distribution function.
JAX implementation of :obj:`scipy.stats.pareto` ``pdf``.
The Pareto probability density function is given by
.. math::
f(x, b) = \begin{cases}
bx^{-(b+1)} & x \ge 1\\
0 & x < 1
\end{cases}
and is defined for :math:`b > 0`.
Args:
x: arraylike, value at which to evaluate the PDF
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
- :func:`jax.scipy.stats.pareto.logcdf`
- :func:`jax.scipy.stats.pareto.logpdf`
- :func:`jax.scipy.stats.pareto.logsf`
- :func:`jax.scipy.stats.pareto.cdf`
- :func:`jax.scipy.stats.pareto.ppf`
- :func:`jax.scipy.stats.pareto.sf`
"""
return lax.exp(logpdf(x, b, loc, scale))
def cdf(
x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
) -> Array:
r"""Pareto cumulative distribution function.
JAX implementation of :obj:`scipy.stats.pareto` ``cdf``.
The Pareto cumulative distribution function is given by
.. math::
F(x, b) = \begin{cases}
1 - x^{-b} & x \ge 1\\
0 & x < 1
\end{cases}
and is defined for :math:`b > 0`.
Args:
x: arraylike, value at which to evaluate the CDF
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of CDF values.
See Also:
- :func:`jax.scipy.stats.pareto.logcdf`
- :func:`jax.scipy.stats.pareto.logpdf`
- :func:`jax.scipy.stats.pareto.logsf`
- :func:`jax.scipy.stats.pareto.pdf`
- :func:`jax.scipy.stats.pareto.ppf`
- :func:`jax.scipy.stats.pareto.sf`
"""
x, b, loc, scale = promote_args_inexact("pareto.cdf", x, b, loc, scale)
one = _lax_const(x, 1)
zero = _lax_const(x, 0)
scaled_x = lax.div(lax.sub(x, loc), scale)
cdf = lax.sub(one, lax.pow(scaled_x, lax.neg(b)))
return jnp.where(lax.lt(x, lax.add(loc, scale)), zero, cdf)
def logcdf(
x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
) -> Array:
r"""Pareto log cumulative distribution function.
JAX implementation of :obj:`scipy.stats.pareto` ``logcdf``.
The Pareto cumulative distribution function is given by
.. math::
F(x, b) = \begin{cases}
1 - x^{-b} & x \ge 1\\
0 & x < 1
\end{cases}
and is defined for :math:`b > 0`.
Args:
x: arraylike, value at which to evaluate the CDF
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logCDF values.
See Also:
- :func:`jax.scipy.stats.pareto.logpdf`
- :func:`jax.scipy.stats.pareto.logsf`
- :func:`jax.scipy.stats.pareto.cdf`
- :func:`jax.scipy.stats.pareto.pdf`
- :func:`jax.scipy.stats.pareto.ppf`
- :func:`jax.scipy.stats.pareto.sf`
"""
x, b, loc, scale = promote_args_inexact("pareto.logcdf", x, b, loc, scale)
scaled_x = lax.div(lax.sub(x, loc), scale)
logcdf_val = lax.log1p(lax.neg(lax.pow(scaled_x, lax.neg(b))))
return jnp.where(lax.lt(x, lax.add(loc, scale)), -np.inf, logcdf_val)
def logsf(
x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
) -> Array:
r"""Pareto log survival function.
JAX implementation of :obj:`scipy.stats.pareto` ``logsf``.
The Pareto survival function is given by
.. math::
S(x, b) = \begin{cases}
x^{-b} & x \ge 1\\
1 & x < 1
\end{cases}
and is defined for :math:`b > 0`.
Args:
x: arraylike, value at which to evaluate the survival function
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of log survival function values.
See Also:
- :func:`jax.scipy.stats.pareto.logcdf`
- :func:`jax.scipy.stats.pareto.logpdf`
- :func:`jax.scipy.stats.pareto.cdf`
- :func:`jax.scipy.stats.pareto.pdf`
- :func:`jax.scipy.stats.pareto.ppf`
- :func:`jax.scipy.stats.pareto.sf`
"""
x, b, loc, scale = promote_args_inexact("pareto.logsf", x, b, loc, scale)
zero = _lax_const(x, 0)
scaled_x = lax.div(lax.sub(x, loc), scale)
logsf_val = lax.neg(lax.mul(b, lax.log(scaled_x)))
return jnp.where(lax.lt(x, lax.add(loc, scale)), zero, logsf_val)
def sf(
x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
) -> Array:
r"""Pareto survival function.
JAX implementation of :obj:`scipy.stats.pareto` ``sf``.
The Pareto survival function is given by
.. math::
S(x, b) = \begin{cases}
x^{-b} & x \ge 1\\
1 & x < 1
\end{cases}
and is defined for :math:`b > 0`.
Args:
x: arraylike, value at which to evaluate the survival function
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of survival function values.
See Also:
- :func:`jax.scipy.stats.pareto.logcdf`
- :func:`jax.scipy.stats.pareto.logpdf`
- :func:`jax.scipy.stats.pareto.logsf`
- :func:`jax.scipy.stats.pareto.cdf`
- :func:`jax.scipy.stats.pareto.pdf`
- :func:`jax.scipy.stats.pareto.ppf`
"""
return lax.exp(logsf(x, b, loc, scale))
def ppf(
q: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
) -> Array:
r"""Pareto percent point function (inverse CDF).
JAX implementation of :obj:`scipy.stats.pareto` ``ppf``.
The Pareto percent point function is the inverse of the Pareto CDF, and is
given by
.. math::
F^{-1}(q, b) = \begin{cases}
(1 - q)^{-1/b} & 0 \le q < 1\\
\text{NaN} & \text{otherwise}
\end{cases}
and is defined for :math:`b > 0`.
Args:
q: arraylike, value at which to evaluate the inverse CDF
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of percent point function values.
See Also:
- :func:`jax.scipy.stats.pareto.logcdf`
- :func:`jax.scipy.stats.pareto.logpdf`
- :func:`jax.scipy.stats.pareto.logsf`
- :func:`jax.scipy.stats.pareto.cdf`
- :func:`jax.scipy.stats.pareto.pdf`
- :func:`jax.scipy.stats.pareto.sf`
"""
q, b, loc, scale = promote_args_inexact("pareto.ppf", q, b, loc, scale)
one = _lax_const(q, 1)
ppf_val = lax.add(
loc, lax.mul(scale, lax.pow(lax.sub(one, q), lax.neg(lax.div(one, b))))
)
return jnp.where(jnp.isnan(q) | (q < 0) | (q > 1), np.nan, ppf_val)
@@ -0,0 +1,241 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import api
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact, ensure_arraylike
from jax._src.scipy.special import xlogy, entr, gammaln, gammaincc
from jax._src.typing import Array, ArrayLike
def logpmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
r"""Poisson log probability mass function.
JAX implementation of :obj:`scipy.stats.poisson` ``logpmf``.
The Poisson probability mass function is given by
.. math::
f(k) = e^{-\mu}\frac{\mu^k}{k!}
and is defined for :math:`k \ge 0` and :math:`\mu \ge 0`.
Args:
k: arraylike, value at which to evaluate the PMF
mu: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
Returns:
array of logpmf values.
See Also:
- :func:`jax.scipy.stats.poisson.cdf`
- :func:`jax.scipy.stats.poisson.pmf`
"""
k, mu, loc = promote_args_inexact("poisson.logpmf", k, mu, loc)
zero = _lax_const(k, 0)
x = lax.sub(k, loc)
log_probs = xlogy(x, mu) - gammaln(x + 1) - mu
return jnp.where(jnp.logical_or(lax.lt(x, zero),
lax.ne(jnp.round(k), k)), -np.inf, log_probs)
def pmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
r"""Poisson probability mass function.
JAX implementation of :obj:`scipy.stats.poisson` ``pmf``.
The Poisson probability mass function is given by
.. math::
f(k) = e^{-\mu}\frac{\mu^k}{k!}
and is defined for :math:`k \ge 0` and :math:`\mu \ge 0`.
Args:
k: arraylike, value at which to evaluate the PMF
mu: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
Returns:
array of pmf values.
See Also:
- :func:`jax.scipy.stats.poisson.cdf`
- :func:`jax.scipy.stats.poisson.logpmf`
"""
return jnp.exp(logpmf(k, mu, loc))
def cdf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
r"""Poisson cumulative distribution function.
JAX implementation of :obj:`scipy.stats.poisson` ``cdf``.
The cumulative distribution function is defined as:
.. math::
f_{cdf}(k, p) = \sum_{i=0}^k f_{pmf}(k, p)
where :math:`f_{pmf}(k, p)` is the probability mass function
:func:`jax.scipy.stats.poisson.pmf`.
Args:
k: arraylike, value at which to evaluate the CDF
mu: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
Returns:
array of cdf values.
See Also:
- :func:`jax.scipy.stats.poisson.pmf`
- :func:`jax.scipy.stats.poisson.logpmf`
"""
k, mu, loc = promote_args_inexact("poisson.logpmf", k, mu, loc)
zero = _lax_const(k, 0)
x = lax.sub(k, loc)
p = gammaincc(jnp.floor(1 + x), mu)
return jnp.where(lax.lt(x, zero), zero, p)
@api.jit
def entropy(mu: ArrayLike, loc: ArrayLike = 0) -> Array:
r"""Shannon entropy of the Poisson distribution.
JAX implementation of :obj:`scipy.stats.poisson` ``entropy``.
The entropy :math:`H(X)` of a Poisson random variable
:math:`X \sim \text{Poisson}(\mu)` is defined as:
.. math::
H(X) = -\sum_{k=0}^\infty p(k) \log p(k)
where :math:`p(k) = e^{-\mu} \mu^k / k!` for
:math:`k \geq \max(0, \lfloor \text{loc} \rfloor)`.
This implementation uses **regime switching** for numerical stability
and performance:
- **Small** :math:`\mu < 10`: Direct summation over PMF with adaptive
upper bound :math:`k \leq \mu + 20`
- **Medium** :math:`10 \leq \mu < 100`: Summation with bound
:math:`k \leq \mu + 10\sqrt{\mu} + 20`
- **Large** :math:`\mu \geq 100`: Asymptotic Stirling approximation:
:math:`H(\mu) \approx \frac{1}{2} \log(2\pi e \mu) - \frac{1}{12\mu}`
Matches SciPy to relative error :math:`< 10^{-5}` across all regimes.
Args:
mu: arraylike, mean parameter of the Poisson distribution.
Must be ``> 0``.
loc: arraylike, optional location parameter (default: 0).
Accepted for API compatibility with scipy but does not
affect the entropy
Returns:
Array of entropy values with shape broadcast from ``mu`` and ``loc``.
Returns ``NaN`` for ``mu <= 0``.
Examples:
>>> from jax.scipy.stats import poisson
>>> poisson.entropy(5.0)
Array(2.204394, dtype=float32)
>>> poisson.entropy(jax.numpy.array([1, 10, 100]))
Array([1.3048419, 2.561407 , 3.7206903], dtype=float32)
See Also:
- :func:`jax.scipy.stats.poisson.pmf`
- :func:`jax.scipy.stats.poisson.logpmf`
- :obj:`scipy.stats.poisson`
"""
mu, loc = ensure_arraylike("poisson.entropy", mu, loc)
promoted_mu, promoted_loc = promote_dtypes_inexact(mu, loc)
#Note: loc does not affect the entropy - translation invariant
#it has only been taken to maintain compatibility with scipy api
result_shape = jnp.broadcast_shapes(
promoted_mu.shape,
promoted_loc.shape
)
mu_flat = jnp.ravel(promoted_mu)
zero_result = jnp.zeros_like(mu_flat)
# Choose the computation regime based on mu value
result = jnp.where(
mu_flat == 0,
zero_result,
jnp.where(
mu_flat < 10,
_entropy_small_mu(mu_flat),
jnp.where(
mu_flat < 100,
_entropy_medium_mu(mu_flat),
_entropy_large_mu(mu_flat)
)
)
)
result_mu_shape = jnp.reshape(result, promoted_mu.shape)
# Restore original shape
return jnp.broadcast_to(result_mu_shape, result_shape)
def _entropy_small_mu(mu: Array) -> Array:
"""Entropy via direct PMF summation for small μ (< 10).
Uses adaptive upper bound k ≤ μ + 20 to capture >99.999% of mass.
"""
max_k = 35
k = jnp.arange(max_k, dtype=mu.dtype)[:, None]
probs = pmf(k, mu, 0)
# Mask: only compute up to mu + 20 for each value
upper_bounds = jnp.ceil(mu + 20).astype(k.dtype)
mask = k < upper_bounds[None, :]
probs_masked = jnp.where(mask, probs, 0.0)
return jnp.sum(entr(probs_masked), axis=0)
def _entropy_medium_mu(mu: Array) -> Array:
"""Entropy for medium mu (10-100): Adaptive bounds based on std dev.
Bounds: k ≤ μ + 10√μ + 20. Caps at k=250 for JIT compatibility.
"""
max_k = 250 # Static bound for JIT. For mu<100, upper bound < 220
k = jnp.arange(max_k, dtype=mu.dtype)[:, None]
probs = pmf(k, mu, 0)
upper_bounds = jnp.ceil(mu + 10 * jnp.sqrt(mu) + 20).astype(k.dtype)
mask = k < upper_bounds[None, :]
probs_masked = jnp.where(mask, probs, 0.0)
return jnp.sum(entr(probs_masked), axis=0)
def _entropy_large_mu(mu: Array) -> Array:
"""Entropy for large mu (>= 100): Asymptotic approximation.
Formula: H(λ) ≈ 0.5*log(2πeλ) - 1/(12λ) + O(λ^-2)
"""
return 0.5 * jnp.log(2 * np.pi * np.e * mu) - 1.0 / (12 * mu)
@@ -0,0 +1,89 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Student's T log probability distribution function.
JAX implementation of :obj:`scipy.stats.t` ``logpdf``.
The Student's T probability distribution function is given by
.. math::
f(x, \nu) = \frac{\Gamma((\nu + 1)/2)}{\sqrt{\pi\nu}\Gamma(\nu/2)}(1 + x^2/\nu)^{(\nu+1)/2}
Where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function, and :math:`\nu > 0`
is the degrees of freedom (JAX follows the scipy convention of naming this ``df``).
Args:
x: arraylike, value at which to evaluate the PDF
df: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logpdf values.
See Also:
:func:`jax.scipy.stats.t.pdf`
"""
x, df, loc, scale = promote_args_inexact("t.logpdf", x, df, loc, scale)
two = _lax_const(x, 2)
scaled_x = lax.div(lax.sub(x, loc), scale)
df_over_two = lax.div(df, two)
df_plus_one_over_two = lax.add(df_over_two, _lax_const(x, 0.5))
normalize_term_const = lax.mul(lax.mul(scale, scale), _lax_const(x, np.pi))
normalize_term_tmp = lax.div(lax.log(lax.mul(normalize_term_const, df)), two)
normalize_term = lax.sub(lax.add(lax.lgamma(df_over_two), normalize_term_tmp),
lax.lgamma(df_plus_one_over_two))
quadratic = lax.div(lax.mul(scaled_x, scaled_x), df)
return lax.neg(lax.add(normalize_term, lax.mul(df_plus_one_over_two, lax.log1p(quadratic))))
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Student's T probability distribution function.
JAX implementation of :obj:`scipy.stats.t` ``pdf``.
The Student's T probability distribution function is given by
.. math::
f(x, \nu) = \frac{\Gamma((\nu + 1)/2)}{\sqrt{\pi\nu}\Gamma(\nu/2)}(1 + x^2/\nu)^{(\nu+1)/2}
Where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function, and :math:`\nu > 0`
is the degrees of freedom (JAX follows the scipy convention of naming this ``df``).
Args:
x: arraylike, value at which to evaluate the PDF
df: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array
See Also:
:func:`jax.scipy.stats.t.logpdf`
"""
return lax.exp(logpdf(x, df, loc, scale))
@@ -0,0 +1,296 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import api
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.numpy.util import promote_args_inexact
from jax._src.scipy.stats import norm
from jax._src.scipy.special import logsumexp, log_ndtr, ndtr
def _log_diff(x, y):
return logsumexp(
jnp.array([x, y]),
b=jnp.array([jnp.ones_like(x), -jnp.ones_like(y)]),
axis=0
)
def _log_gauss_mass(a, b):
"""Log of Gaussian probability mass within an interval"""
a, b = jnp.array(a), jnp.array(b)
a, b = jnp.broadcast_arrays(a, b)
# Note: Docstring carried over from scipy
# Calculations in right tail are inaccurate, so we'll exploit the
# symmetry and work only in the left tail
case_left = b <= 0
case_right = a > 0
case_central = ~(case_left | case_right)
# By conditionally swapping arguments if we're in the right tail,
# we only need to compile the mass_case_left graph once instead of twice.
a_tail = jnp.where(case_right, -b, a)
b_tail = jnp.where(case_right, -a, b)
mass_tail = _log_diff(log_ndtr(b_tail), log_ndtr(a_tail))
# Catastrophic cancellation occurs as np.exp(log_mass) approaches 1.
# Correct for this with an alternative formulation.
# We're not concerned with underflow here: if only one term
# underflows, it was insignificant; if both terms underflow,
# the result can't accurately be represented in logspace anyway
# because sc.log1p(x) ~ x for small x.
mass_central = jnp.log1p(-ndtr(a) - ndtr(-b))
out = jnp.where(case_central, mass_central, mass_tail)
return out
@api.jit
def logpdf(x, a, b, loc=0, scale=1):
r"""Truncated normal log probability distribution function.
JAX implementation of :obj:`scipy.stats.truncnorm` ``logpdf``.
The truncated normal probability distribution is given by
.. math::
f(x, a, b) = \begin{cases}
\frac{1}{\sqrt{2\pi}}e^{-x^2/2} & a \le x \le b \\
0 & \mathrm{otherwise}
\end{cases}
where :math:`a` and :math:`b` are effectively specified in number of
standard deviations from zero. JAX uses the scipy nomenclature
of ``loc`` for the centroid and ``scale`` for the standard deviation.
Args:
x: arraylike, value at which to evaluate the PDF
a: arraylike, distribution shape parameter
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logpdf values.
See Also:
- :func:`jax.scipy.stats.truncnorm.cdf`
- :func:`jax.scipy.stats.truncnorm.pdf`
- :func:`jax.scipy.stats.truncnorm.sf`
- :func:`jax.scipy.stats.truncnorm.logcdf`
- :func:`jax.scipy.stats.truncnorm.logsf`
"""
x, a, b, loc, scale = promote_args_inexact("truncnorm.logpdf", x, a, b, loc, scale)
val = lax.sub(norm.logpdf(x, loc, scale), _log_gauss_mass(a, b))
x_scaled = lax.div(lax.sub(x, loc), scale)
val = jnp.where((x_scaled < a) | (x_scaled > b), -np.inf, val)
val = jnp.where(a >= b, np.nan, val)
return val
def pdf(x, a, b, loc=0, scale=1):
r"""Truncated normal probability distribution function.
JAX implementation of :obj:`scipy.stats.truncnorm` ``pdf``.
The truncated normal probability distribution is given by
.. math::
f(x, a, b) = \begin{cases}
\frac{1}{\sqrt{2\pi}}e^{-x^2/2} & a \le x \le b \\
0 & \mathrm{otherwise}
\end{cases}
where :math:`a` and :math:`b` are effectively specified in number of
standard deviations from the centroid. JAX uses the scipy nomenclature
of ``loc`` for the centroid and ``scale`` for the standard deviation.
Args:
x: arraylike, value at which to evaluate the PDF
a: arraylike, distribution shape parameter
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
- :func:`jax.scipy.stats.truncnorm.cdf`
- :func:`jax.scipy.stats.truncnorm.sf`
- :func:`jax.scipy.stats.truncnorm.logcdf`
- :func:`jax.scipy.stats.truncnorm.logpdf`
- :func:`jax.scipy.stats.truncnorm.logsf`
"""
return lax.exp(logpdf(x, a, b, loc, scale))
@api.jit
def logsf(x, a, b, loc=0, scale=1):
"""Truncated normal distribution log survival function.
JAX implementation of :obj:`scipy.stats.truncnorm` ``logsf``
The survival function is defined as
.. math::
f_{sf}(x) = 1 - f_{cdf}(x)
where :math:`f_{cdf}(x)` is the cumulative distribution function,
:func:`jax.scipy.stats.truncnorm.cdf`.
Args:
x: arraylike, value at which to evaluate the SF
a: arraylike, distribution shape parameter
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logsf values.
See Also:
- :func:`jax.scipy.stats.truncnorm.cdf`
- :func:`jax.scipy.stats.truncnorm.pdf`
- :func:`jax.scipy.stats.truncnorm.sf`
- :func:`jax.scipy.stats.truncnorm.logcdf`
- :func:`jax.scipy.stats.truncnorm.logpdf`
"""
x, a, b, loc, scale = promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale)
return logcdf(-x, -b, -a, -loc, scale)
@api.jit
def sf(x, a, b, loc=0, scale=1):
"""Truncated normal distribution log survival function.
JAX implementation of :obj:`scipy.stats.truncnorm` ``logsf``
The survival function is defined as
.. math::
f_{sf}(x) = 1 - f_{cdf}(x)
where :math:`f_{cdf}(x)` is the cumulative distribution function,
:func:`jax.scipy.stats.truncnorm.cdf`.
Args:
x: arraylike, value at which to evaluate the SF
a: arraylike, distribution shape parameter
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of sf values.
See Also:
- :func:`jax.scipy.stats.truncnorm.cdf`
- :func:`jax.scipy.stats.truncnorm.pdf`
- :func:`jax.scipy.stats.truncnorm.sf`
- :func:`jax.scipy.stats.truncnorm.logcdf`
- :func:`jax.scipy.stats.truncnorm.logpdf`
"""
return lax.exp(logsf(x, a, b, loc, scale))
@api.jit
def logcdf(x, a, b, loc=0, scale=1):
r"""Truncated normal log cumulative distribution function.
JAX implementation of :obj:`scipy.stats.truncnorm` ``logcdf``.
The cdf is defined as
.. math::
f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y
where here :math:`f_{pdf}` is the probability distribution function,
:func:`jax.scipy.stats.truncnorm.pdf`.
Args:
x: arraylike, value at which to evaluate the CDF
a: arraylike, distribution shape parameter
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logcdf values.
See Also:
- :func:`jax.scipy.stats.truncnorm.cdf`
- :func:`jax.scipy.stats.truncnorm.pdf`
- :func:`jax.scipy.stats.truncnorm.sf`
- :func:`jax.scipy.stats.truncnorm.logpdf`
- :func:`jax.scipy.stats.truncnorm.logsf`
"""
x, a, b, loc, scale = promote_args_inexact("truncnorm.logcdf", x, a, b, loc, scale)
x, a, b = jnp.broadcast_arrays(x, a, b)
x = lax.div(lax.sub(x, loc), scale)
lgm_ab = _log_gauss_mass(a, b)
logcdf = _log_gauss_mass(a, x) - lgm_ab
logsf = _log_gauss_mass(x, b) - lgm_ab
logcdf = jnp.select(
# third condition: avoid catastrophic cancellation (from scipy)
[x >= b, x <= a, logcdf > -0.1, x > a],
[0, -np.inf, jnp.log1p(-jnp.exp(logsf)), logcdf]
)
logcdf = jnp.where(a >= b, np.nan, logcdf)
return logcdf
@api.jit
def cdf(x, a, b, loc=0, scale=1):
r"""Truncated normal cumulative distribution function.
JAX implementation of :obj:`scipy.stats.truncnorm` ``cdf``.
The cdf is defined as
.. math::
f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y
where here :math:`f_{pdf}` is the probability distribution function,
:func:`jax.scipy.stats.truncnorm.pdf`.
Args:
x: arraylike, value at which to evaluate the CDF
a: arraylike, distribution shape parameter
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of cdf values.
See Also:
- :func:`jax.scipy.stats.truncnorm.pdf`
- :func:`jax.scipy.stats.truncnorm.sf`
- :func:`jax.scipy.stats.truncnorm.logcdf`
- :func:`jax.scipy.stats.truncnorm.logpdf`
- :func:`jax.scipy.stats.truncnorm.logsf`
"""
return lax.exp(logcdf(x, a, b, loc, scale))
@@ -0,0 +1,148 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.typing import Array, ArrayLike
from jax._src.numpy.util import promote_args_inexact
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Uniform log probability distribution function.
JAX implementation of :obj:`scipy.stats.uniform` ``logpdf``.
The uniform distribution pdf is given by
.. math::
f(x) = \begin{cases}
1 & 0 \le x \le 1 \\
0 & \mathrm{otherwise}
\end{cases}
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of logpdf values
See Also:
- :func:`jax.scipy.stats.uniform.cdf`
- :func:`jax.scipy.stats.uniform.pdf`
- :func:`jax.scipy.stats.uniform.ppf`
"""
x, loc, scale = promote_args_inexact("uniform.logpdf", x, loc, scale)
log_probs = lax.neg(lax.log(scale))
return jnp.where(jnp.logical_or(lax.gt(x, lax.add(loc, scale)),
lax.lt(x, loc)),
-np.inf, log_probs)
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Uniform probability distribution function.
JAX implementation of :obj:`scipy.stats.uniform` ``pdf``.
The uniform distribution pdf is given by
.. math::
f(x) = \begin{cases}
1 & 0 \le x \le 1 \\
0 & \mathrm{otherwise}
\end{cases}
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
- :func:`jax.scipy.stats.uniform.cdf`
- :func:`jax.scipy.stats.uniform.logpdf`
- :func:`jax.scipy.stats.uniform.ppf`
"""
return lax.exp(logpdf(x, loc, scale))
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Uniform cumulative distribution function.
JAX implementation of :obj:`scipy.stats.uniform` ``cdf``.
The cdf is defined as
.. math::
f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y
where here :math:`f_{pdf}` is the probability distribution function,
:func:`jax.scipy.stats.uniform.pdf`.
Args:
x: arraylike, value at which to evaluate the CDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of cdf values.
See Also:
- :func:`jax.scipy.stats.uniform.pdf`
- :func:`jax.scipy.stats.uniform.logpdf`
- :func:`jax.scipy.stats.uniform.ppf`
"""
x, loc, scale = promote_args_inexact("uniform.cdf", x, loc, scale)
zero, one = jnp.array(0, x.dtype), jnp.array(1, x.dtype)
conds = [lax.lt(x, loc), lax.gt(x, lax.add(loc, scale)), lax.ge(x, loc) & lax.le(x, lax.add(loc, scale))]
vals = [zero, one, lax.div(lax.sub(x, loc), scale)]
return jnp.select(conds, vals)
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
"""Uniform distribution percent point function.
JAX implementation of :obj:`scipy.stats.uniform` ``ppf``.
The percent point function is defined as the inverse of the
cumulative distribution function, :func:`jax.scipy.stats.uniform.cdf`.
Args:
q: arraylike, value at which to evaluate the PPF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of ppf values.
See Also:
- :func:`jax.scipy.stats.uniform.cdf`
- :func:`jax.scipy.stats.uniform.pdf`
- :func:`jax.scipy.stats.uniform.logpdf`
"""
q, loc, scale = promote_args_inexact("uniform.ppf", q, loc, scale)
return jnp.where(
jnp.isnan(q) | (q < 0) | (q > 1),
np.nan,
lax.add(loc, lax.mul(scale, q))
)
@@ -0,0 +1,79 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, kappa: ArrayLike) -> Array:
r"""von Mises log probability distribution function.
JAX implementation of :obj:`scipy.stats.vonmises` ``logpdf``.
The von Mises probability distribution function is given by
.. math::
f(x, \kappa) = \frac{1}{2\pi I_0(\kappa)}e^{\kappa\cos x}
Where :math:`I_0` is the modified Bessel function :func:`~jax.scipy.special.i0`
and :math:`\kappa\ge 0`, and the distribution is normalized in the interval
:math:`-\pi \le x \le \pi`.
Args:
x: arraylike, value at which to evaluate the PDF
kappa: arraylike, distribution shape parameter
Returns:
array of logpdf values.
See Also:
:func:`jax.scipy.stats.vonmises.pdf`
"""
x, kappa = promote_args_inexact('vonmises.logpdf', x, kappa)
zero = _lax_const(kappa, 0)
return jnp.where(lax.gt(kappa, zero), kappa * (jnp.cos(x) - 1) - jnp.log(2 * np.pi * lax.bessel_i0e(kappa)), np.nan)
def pdf(x: ArrayLike, kappa: ArrayLike) -> Array:
r"""von Mises probability distribution function.
JAX implementation of :obj:`scipy.stats.vonmises` ``pdf``.
The von Mises probability distribution function is given by
.. math::
f(x, \kappa) = \frac{1}{2\pi I_0(\kappa)}e^{\kappa\cos x}
Where :math:`I_0` is the modified Bessel function :func:`~jax.scipy.special.i0`
and :math:`\kappa\ge 0`, and the distribution is normalized in the interval
:math:`-\pi \le x \le \pi`.
Args:
x: arraylike, value at which to evaluate the PDF
kappa: arraylike, distribution shape parameter
Returns:
array of pdf values.
See Also:
:func:`jax.scipy.stats.vonmises.logpdf`
"""
return lax.exp(logpdf(x, kappa))
@@ -0,0 +1,82 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, c: ArrayLike) -> Array:
r"""Wrapped Cauchy log probability distribution function.
JAX implementation of :obj:`scipy.stats.wrapcauchy` ``logpdf``.
The wrapped Cauchy probability distribution function is given by
.. math::
f(x, c) = \frac{1-c^2}{2\pi(1+c^2-2c\cos x)}
for :math:`0<c<1`, and where normalization is on the domain :math:`0\le x\le 2\pi`.
Args:
x: arraylike, value at which to evaluate the PDF
c: arraylike, distribution shape parameter
Returns:
array of logpdf values.
See Also:
:func:`jax.scipy.stats.wrapcauchy.pdf`
"""
x, c = promote_args_inexact('wrapcauchy.logpdf', x, c)
return jnp.where(
lax.gt(c, _lax_const(c, 0)) & lax.lt(c, _lax_const(c, 1)),
jnp.where(
lax.ge(x, _lax_const(x, 0)) & lax.le(x, _lax_const(x, np.pi * 2)),
jnp.log(1 - c * c) - jnp.log(2 * np.pi) - jnp.log(1 + c * c - 2 * c * jnp.cos(x)),
-np.inf,
),
np.nan,
)
def pdf(x: ArrayLike, c: ArrayLike) -> Array:
r"""Wrapped Cauchy probability distribution function.
JAX implementation of :obj:`scipy.stats.wrapcauchy` ``pdf``.
The wrapped Cauchy probability distribution function is given by
.. math::
f(x, c) = \frac{1-c^2}{2\pi(1+c^2-2c\cos x)}
for :math:`0<c<1`, and where normalization is on the domain :math:`0\le x\le 2\pi`.
Args:
x: arraylike, value at which to evaluate the PDF
c: arraylike, distribution shape parameter
Returns:
array of pdf values.
See Also:
:func:`jax.scipy.stats.wrapcauchy.logpdf`
"""
return lax.exp(logpdf(x, c))